// Copyright 2021
// ThatWorks.xyz Limited

import { Box, Button, DataTable, DateInput, Select, Spinner, Text, TextInput } from 'grommet';
import { useEffect, useMemo, useRef, useState } from 'react';
import { Scatter } from 'react-chartjs-2';
import 'chart.js/auto';
import { useLazyQuery, useQuery } from '@apollo/client';
import { useStatsigClient } from '@statsig/react-bindings';
import { Colors } from '@thatworks/colors';
import debounce from 'lodash.debounce';
import { DateTime } from 'luxon';
import { gql } from '../../../../__generated__';
import { AdminUsersSearchQuery } from '../../../../__generated__/graphql';
import { useTelemetryContext } from '../../../../components/TelemetryContext';
import { useUserStateContext } from '../../../../components/UserContext';
import { GET_ORGS } from './shared-queries';
// Has to be after chart.js import
import 'chartjs-adapter-luxon';

// Fetches one user's token-usage records between two ISO timestamps. Each record
// carries input/output token counts, the model, the use case and a timestamp.
// Consumed by TokenChart, which groups the records by model + use case and prices them.
// NOTE(review): the generated `gql` helper presumably resolves types by the document's
// exact text, so change this query together with a codegen run, not by hand-reformatting.
const GET_USER_TOKENS = gql(/* GraphQL */ `
    query GetUserTokens($userId: String!, $fromDateIso: String!, $toDateIso: String!) {
        adminUserTokenUsage(userId: $userId, fromDateIso: $fromDateIso, toDateIso: $toDateIso) {
            inputTokens
            outputTokens
            model
            useCase
            timeIso
        }
    }
`);

// Searches users of one organization by free-text query and returns name, email and id.
// Used by UserSelection to build the email suggestion list.
// NOTE(review): the generated `gql` helper presumably resolves types by the document's
// exact text, so change this query together with a codegen run, not by hand-reformatting.
const ADMIN_USERS_SEARCH = gql(/* GraphQL */ `
    query AdminUsersSearch($organizationId: String!, $query: String!) {
        adminUserSearch(organizationId: $organizationId, query: $query) {
            name
            email
            id
        }
    }
`);

/**
 * Groups token-usage records by model + use case, prices every record and totals
 * the result for the summary table and the chart.
 *
 * The grouping key (`"<model> (<useCase>)"`) is used both as the chart legend label
 * and as the table's "Model+Use Case" cell, so the two views name each series the same.
 *
 * @param usage - records returned by the `adminUserTokenUsage` query
 * @param pricingFor - returns a model's price per one million input/output tokens;
 *   a missing price counts as 0 so one incomplete config entry cannot turn the
 *   totals into NaN
 * @returns the per-series points (each series sorted by time), the sorted distinct
 *   timestamps, one aggregated table row per series, and the overall cost
 */
function summarizeTokenUsage(
    usage: ReadonlyArray<{ inputTokens: number; outputTokens: number; model: string; useCase: string; timeIso: string }>,
    pricingFor: (model: string) => { input?: number; output?: number },
) {
    type UsagePoint = {
        timeIso: string;
        inputTokens: number;
        outputTokens: number;
        totalTokens: number;
        cost: number;
    };
    const toMillis = (iso: string) => DateTime.fromISO(iso).toMillis();

    const dimensions = new Map<string, UsagePoint[]>();
    usage.forEach((item) => {
        const key = `${item.model} (${item.useCase})`;
        const pricingPerMillion = pricingFor(item.model);
        const inputCost = item.inputTokens * ((pricingPerMillion.input ?? 0) / 1000000);
        const outputCost = item.outputTokens * ((pricingPerMillion.output ?? 0) / 1000000);
        const point: UsagePoint = {
            timeIso: item.timeIso,
            inputTokens: item.inputTokens,
            outputTokens: item.outputTokens,
            totalTokens: item.inputTokens + item.outputTokens,
            cost: inputCost + outputCost,
        };

        const series = dimensions.get(key);
        if (series) {
            series.push(point);
        } else {
            dimensions.set(key, [point]);
        }
    });

    // Sort every series chronologically and collect the distinct timestamps across all of them.
    const timeLabelsSet = new Set<string>();
    dimensions.forEach((series) => {
        series.sort((a, b) => toMillis(a.timeIso) - toMillis(b.timeIso));
        series.forEach((point) => timeLabelsSet.add(point.timeIso));
    });
    const timeLabels = Array.from(timeLabelsSet).sort((a, b) => toMillis(a) - toMillis(b));

    // One row per model/use case with its summed tokens and cost.
    const tableData = Array.from(dimensions.entries()).map(([modelAndUseCase, series]) => {
        const sum = (pick: (point: UsagePoint) => number) => series.reduce((acc, point) => acc + pick(point), 0);
        return {
            modelAndUseCase,
            inputTokens: sum((p) => p.inputTokens),
            outputTokens: sum((p) => p.outputTokens),
            totalTokens: sum((p) => p.totalTokens),
            cost: sum((p) => p.cost),
        };
    });

    const totalCost = tableData.reduce((acc, row) => acc + row.cost, 0);

    return { dimensions, timeLabels, tableData, totalCost };
}

/**
 * Table + scatter chart of one user's token usage and estimated USD cost between
 * `fromDateIso` and `toDateIso`.
 *
 * Shows a spinner while loading and renders nothing when no data is available;
 * query failures are reported through `postErrorMessage` and logged.
 */
function TokenChart(props: { userId: string; fromDateIso: string; toDateIso: string }): JSX.Element | null {
    const { postErrorMessage } = useUserStateContext();
    const { logger } = useTelemetryContext();
    const { getDynamicConfig } = useStatsigClient();

    const { data, loading } = useQuery(GET_USER_TOKENS, {
        variables: {
            userId: props.userId,
            fromDateIso: props.fromDateIso,
            toDateIso: props.toDateIso,
        },
        onError: (error) => {
            postErrorMessage({
                title: `Error`,
                shortDesc: `Failed to get chart data`,
            });
            logger.error(error.message);
        },
    });

    const groupedData = useMemo(() => {
        const usage = data?.adminUserTokenUsage ?? [];
        // Resolve the pricing config once per summary instead of once per usage row,
        // and not at all when there is nothing to price.
        const modelPricing = usage.length > 0 ? getDynamicConfig('llm_model_pricing') : undefined;
        return summarizeTokenUsage(usage, (model) =>
            modelPricing
                ? (modelPricing.get<{ input: number; output: number }>(model, {
                      input: 0,
                      output: 0,
                  }) as { input?: number; output?: number })
                : {},
        );
    }, [data?.adminUserTokenUsage, getDynamicConfig]);

    if (loading) {
        return <Spinner />;
    }

    if (!data) {
        return null;
    }

    const chartData = {
        labels: groupedData.timeLabels,
        datasets: Array.from(groupedData.dimensions.entries()).map(([label, series], index) => ({
            label,
            data: series.map((point) => ({ x: point.timeIso, y: point.totalTokens })),
            // Golden-angle hue step keeps neighbouring series visually distinct.
            borderColor: `hsl(${index * 137.5}, 70%, 50%)`,
            backgroundColor: `hsla(${index * 137.5}, 70%, 50%, 0.5)`,
            tension: 0.1,
        })),
    };

    return (
        <Box gap="xsmall" background={Colors.background_back} round="10px" pad="xxsmall">
            <Box pad={{ horizontal: 'xsmall', top: 'xsmall' }} direction="row" align="center">
                <Text weight="bold">
                    {`${dateFormat.format(new Date(props.fromDateIso))} - ${dateFormat.format(
                        new Date(props.toDateIso),
                    )}`}{' '}
                    • Total Cost {groupedData.totalCost.toLocaleString('en-US', { style: 'currency', currency: 'USD' })}
                </Text>
            </Box>
            <DataTable
                data={groupedData.tableData}
                columns={[
                    {
                        property: 'modelAndUseCase',
                        header: <Text size="14px">Model+Use Case</Text>,
                        render: (datum) => <Text size="14px">{datum.modelAndUseCase}</Text>,
                    },
                    {
                        property: 'inputTokens',
                        header: <Text size="14px">Input Tokens</Text>,
                        render: (datum) => <Text size="14px">{datum.inputTokens.toLocaleString()}</Text>,
                    },
                    {
                        property: 'outputTokens',
                        header: <Text size="14px">Output Tokens</Text>,
                        render: (datum) => <Text size="14px">{datum.outputTokens.toLocaleString()}</Text>,
                    },
                    {
                        property: 'totalTokens',
                        header: <Text size="14px">Total Tokens</Text>,
                        render: (datum) => <Text size="14px">{datum.totalTokens.toLocaleString()}</Text>,
                    },
                    {
                        property: 'cost',
                        header: <Text size="14px">Cost</Text>,
                        render: (datum) => (
                            <Text size="14px">
                                {datum.cost.toLocaleString('en-US', { style: 'currency', currency: 'USD' })}
                            </Text>
                        ),
                    },
                ]}
            />
            <Scatter
                data={chartData}
                options={{
                    scales: {
                        x: {
                            type: 'time',
                            time: {
                                unit: 'day',
                            },
                        },
                    },
                    responsive: true,
                }}
            />
        </Box>
    );
}

/** Month + day formatter (no year) in the runtime's default locale, used for the chart's date-range heading. */
const dateFormat = new Intl.DateTimeFormat(undefined, { day: 'numeric', month: 'short' });

/**
 * Email search box for picking a user within one organization.
 *
 * Typing runs a debounced (300 ms) `ADMIN_USERS_SEARCH` query scoped to
 * `props.organizationId`; matches are offered as suggestions labelled by email.
 * Picking one fills the box with that email and calls `props.onUserSelection`
 * with the user's id.
 */
function UserSelection(props: { organizationId: string; onUserSelection: (userId: string) => void }): JSX.Element {
    const [usersSearch, { loading }] = useLazyQuery(ADMIN_USERS_SEARCH);
    const [searchResults, setSearchResults] = useState<AdminUsersSearchQuery['adminUserSearch']>([]);
    const [searchValue, setSearchValue] = useState<string>('');
    // Id of the most recent search. Responses belonging to an older (superseded or
    // abandoned) search are dropped, so a slow earlier request cannot overwrite newer
    // results or re-open the suggestion list after a user has been picked.
    const latestSearchId = useRef(0);

    const debouncedResults = useMemo(() => {
        return debounce((query: string) => {
            latestSearchId.current += 1;
            const searchId = latestSearchId.current;
            void usersSearch({ variables: { query, organizationId: props.organizationId } }).then((r) => {
                if (searchId !== latestSearchId.current) {
                    return;
                }
                setSearchResults(r.data ? r.data.adminUserSearch : []);
            });
        }, 300);
    }, [props.organizationId, usersSearch]);

    useEffect(() => {
        // Cancel a search still waiting on the debounce timer when the debounced
        // function is recreated (e.g. the organization changes) or on unmount.
        return () => debouncedResults.cancel();
    }, [debouncedResults]);

    return (
        <Box direction="row" gap="xxsmall">
            <TextInput
                placeholder={
                    <Box direction="row" gap="xxsmall">
                        <Text size="16px">Search for user email</Text>
                    </Box>
                }
                size="16px"
                onChange={(e) => {
                    const searchText = e.target.value;
                    setSearchValue(searchText);
                    debouncedResults(searchText);
                }}
                value={searchValue}
                suggestions={searchResults.map((s) => ({ label: `${s.email}`, value: s }))}
                onSuggestionSelect={(x) => {
                    // Abandon pending and in-flight searches so their results cannot
                    // re-populate the suggestions after the pick.
                    debouncedResults.cancel();
                    latestSearchId.current += 1;
                    setSearchResults([]);
                    const suggested = x.suggestion.value as AdminUsersSearchQuery['adminUserSearch'][number];
                    setSearchValue(suggested.email);
                    props.onUserSelection(suggested.id);
                }}
                reverse
                icon={loading ? <Spinner /> : undefined}
            />
        </Box>
    );
}

/**
 * Dropdown of the organizations returned by `GET_ORGS`.
 *
 * Renders a spinner while the query is loading and nothing when no data is
 * available; a failed request is reported via `postErrorMessage` and logged.
 * Calls `props.onSelect` with the chosen organization's `id`.
 */
export function OrganizationSelect(props: { onSelect: (orgId: string) => void }): JSX.Element | null {
    const { postErrorMessage } = useUserStateContext();
    const { logger } = useTelemetryContext();

    const orgsQuery = useQuery(GET_ORGS, {
        onError: ({ message }) => {
            postErrorMessage({ title: 'Error', shortDesc: 'Failed to get organizations' });
            logger.error(message);
        },
    });

    if (orgsQuery.loading) {
        return <Spinner />;
    }

    const orgs = orgsQuery.data;
    return orgs ? (
        <Select
            options={orgs.usersOrganizations}
            valueKey="displayName"
            placeholder="Select organization"
            onChange={({ value }) => props.onSelect(value.id)}
        />
    ) : null;
}

/**
 * Admin panel for inspecting one user's LLM token usage and estimated cost.
 *
 * Flow: pick an organization, search for a user in it, then choose a date range
 * (a preset button or a custom range bounded to the past year). The chart is
 * rendered only once a user and both ends of the range are selected.
 */
export function UserTokenChart(): JSX.Element {
    const [userId, setUserId] = useState<string>();
    const [orgId, setOrgId] = useState<string>();
    const [dates, setDates] = useState<string[]>([DateTime.now().minus({ month: 1 }).toISO(), DateTime.now().toISO()]);
    // Both ends are checked individually so an incomplete range reported by the
    // DateInput never reaches the query's required `fromDateIso`/`toDateIso` variables.
    const [fromDateIso, toDateIso] = dates;

    return (
        <Box gap="xsmall" width="xlarge">
            <Box gap="xsmall">
                <Box gap="xsmall" width="medium">
                    <OrganizationSelect
                        onSelect={(selectedOrgId) => {
                            if (selectedOrgId === orgId) {
                                return;
                            }
                            setOrgId(selectedOrgId);
                            // The previously chosen user was found by searching another organization.
                            setUserId(undefined);
                        }}
                    />
                    {orgId && (
                        // Keyed by organization so the search text and suggestions start fresh on change.
                        <UserSelection key={orgId} organizationId={orgId} onUserSelection={(u) => setUserId(u)} />
                    )}
                </Box>
                <Box gap="xxsmall" direction="row" align="center">
                    <Button
                        label="This Month"
                        onClick={() => setDates([DateTime.now().startOf('month').toISO(), DateTime.now().toISO()])}
                    />
                    <Button
                        label="1 month"
                        onClick={() => setDates([DateTime.now().minus({ month: 1 }).toISO(), DateTime.now().toISO()])}
                    />
                    <Button
                        label="3 months"
                        onClick={() => setDates([DateTime.now().minus({ month: 3 }).toISO(), DateTime.now().toISO()])}
                    />
                    <DateInput
                        value={dates}
                        calendarProps={{
                            bounds: [DateTime.now().minus({ year: 1 }).toISO(), DateTime.now().toISO()],
                            firstDayOfWeek: 1,
                        }}
                        buttonProps={{
                            label: `Set custom range`,
                            icon: undefined,
                        }}
                        onChange={(event) => {
                            if (!Array.isArray(event.value)) {
                                throw new Error('Invalid date input, expected an array');
                            }
                            setDates(event.value);
                        }}
                    />
                </Box>
            </Box>
            {userId && dates.length === 2 && fromDateIso && toDateIso && (
                <TokenChart userId={userId} fromDateIso={fromDateIso} toDateIso={toDateIso} />
            )}
        </Box>
    );
}
