import { Resolver, Variables } from '@urql/exchange-graphcache';
import { stringifyVariables } from 'urql';

/**
 * Shape of the merged page that the pagination resolver in this file
 * accumulates and returns.
 */
interface PageResult {
  /** `nextToken` value of a page, or `null` when there is none. */
  nextToken: string | null;
  /** `prevToken` value of a page, or `null` when there is none. */
  prevToken: string | null;
  /** Value of the cached page's `total` field. */
  total: number;
  /**
   * Concatenated `items` of the merged pages.
   * NOTE(review): presumably cache keys of the linked entities — confirm
   * against the schema.
   */
  items: string[];
}

/**
 * The parts of a `cache.inspectFields` entry that the resolver relies on
 * (the filtered entries are cast to this shape).
 */
interface FieldInfo {
  /** Key handed to `cache.resolve` to look up the cached page for this entry. */
  fieldKey: string;
  /** Field name, compared against the name of the field being resolved. */
  fieldName: string;
  /** Arguments the entry was cached with; `null` when there were none. */
  arguments: Variables | null;
}

/**
 * Pagination bookkeeping kept per list (see `getStateKey` for how a list is
 * identified).
 */
interface CursorState {
  /** `nextToken` recorded from the most recently processed forward page. */
  latestNextToken: string | null;
  /** `prevToken` recorded from the most recently processed backward page. */
  oldestPrevToken: string | null;
  /** Set once a forward page (after the initial load) carries no `nextToken`. */
  hasReachedEnd: boolean;
  /** Set once a backward page (after the initial load) carries no `prevToken`. */
  hasReachedStart: boolean;
}

// Pagination arguments that do not identify a list: they are left out when
// page arguments are compared in non-strict mode and when cursor-state keys
// are built.
const MERGE_IGNORE_KEYS = ['nextToken', 'prevToken'];

// In-memory cursor state storage, keyed by the string built in getStateKey.
// Declared at module scope, so every resolver created from this module shares it.
const cursorStates = new Map<string, CursorState>();

/**
 * Decides whether two argument sets address the same list.
 *
 * A missing (`null`/`undefined`) argument object is treated as empty, so two
 * empty sets always match. Unless `strict` is set, the keys listed in
 * `MERGE_IGNORE_KEYS` are left out of the comparison on both sides.
 *
 * @param fieldArgs - Arguments of the field currently being resolved.
 * @param connectionArgs - Arguments a cached field entry was stored with.
 * @param strict - When `true`, no key is ignored.
 * @returns `true` when both sides have the same number of compared keys and
 *   every compared key of `fieldArgs` has a value of the same `typeof` and
 *   the same `stringifyVariables` output on the other side.
 */
const compareArgs = (
  fieldArgs: Variables,
  connectionArgs: Variables | null,
  strict = false,
): boolean => {
  // Normalise both sides so the empty / null / undefined cases need no
  // special handling: they simply yield zero keys and compare as equal.
  const requested: Variables = fieldArgs ?? {};
  const cached: Variables = connectionArgs ?? {};

  const skipped: string[] = strict ? [] : MERGE_IGNORE_KEYS;
  const isCompared = (name: string): boolean => !skipped.includes(name);

  const requestedNames = Object.keys(requested).filter(isCompared);
  const cachedNames = Object.keys(cached).filter(isCompared);

  if (requestedNames.length !== cachedNames.length) {
    return false;
  }

  for (const name of requestedNames) {
    const wanted = requested[name];
    const stored = cached[name];

    if (typeof wanted !== typeof stored) {
      return false;
    }
    if (stringifyVariables(wanted) !== stringifyVariables(stored)) {
      return false;
    }
  }

  return true;
};

/**
 * Builds the map key under which the cursor state of a list is stored.
 *
 * The keys in `MERGE_IGNORE_KEYS` are stripped from a copy of the arguments
 * first, so requests that differ only in those arguments produce the same key.
 *
 * @param entityKey - Cache key of the parent entity.
 * @param fieldName - Name of the paginated field.
 * @param fieldArgs - Arguments of the current request (not mutated).
 * @returns `entityKey:fieldName:<stringified remaining arguments>`.
 */
const getStateKey = (
  entityKey: string,
  fieldName: string,
  fieldArgs: Variables,
): string => {
  const identityArgs = { ...fieldArgs };
  for (const ignoredName of MERGE_IGNORE_KEYS) {
    delete identityArgs[ignoredName];
  }
  return `${entityKey}:${fieldName}:${stringifyVariables(identityArgs)}`;
};

/**
 * Looks up the cursor state of a list.
 *
 * @param entityKey - Cache key of the parent entity.
 * @param fieldName - Name of the paginated field.
 * @param fieldArgs - Arguments of the current request.
 * @returns The stored state object itself (not a copy) when one exists;
 *   otherwise a fresh default state that has not been added to the store.
 */
const getCursorState = (
  entityKey: string,
  fieldName: string,
  fieldArgs: Variables,
): CursorState => {
  const stored = cursorStates.get(
    getStateKey(entityKey, fieldName, fieldArgs),
  );
  if (stored) {
    return stored;
  }

  return {
    latestNextToken: null,
    oldestPrevToken: null,
    hasReachedEnd: false,
    hasReachedStart: false,
  };
};

/**
 * Stores the cursor state of a list, replacing any previous entry.
 *
 * @param entityKey - Cache key of the parent entity.
 * @param fieldName - Name of the paginated field.
 * @param fieldArgs - Arguments of the current request.
 * @param state - State object to store under the list's key.
 */
const setCursorState = (
  entityKey: string,
  fieldName: string,
  fieldArgs: Variables,
  state: CursorState,
): void => {
  cursorStates.set(getStateKey(entityKey, fieldName, fieldArgs), state);
};

/**
 * Creates a Graphcache resolver that stitches every cached page of a
 * token-paginated field into a single list.
 *
 * On each read the resolver:
 * 1. marks the result as partial when the requested page is not cached yet
 *    and the request carries a `nextToken` or `prevToken`;
 * 2. folds the tokens of the requested page into the module-level cursor
 *    state of this list;
 * 3. concatenates, in `cache.inspectFields` order, the `items` of every
 *    cached page whose arguments match the request (see `compareArgs`;
 *    pagination tokens are ignored unless the operation variable `skipMerge`
 *    is truthy) and removes duplicate items.
 *
 * A cached page entry that cannot be resolved is skipped; pages merged before
 * it are kept.
 *
 * @param fieldTypename - `__typename` stamped on the returned page object.
 * @returns A resolver that yields the merged page, `undefined` when the cache
 *   holds no entry for the field at all, or an empty page when no cached page
 *   could be merged.
 */
const cursorPagination = (fieldTypename = ''): Resolver => {
  return (_, fieldArgs, cache, info) => {
    const { parentKey: entityKey, fieldName } = info;
    const fieldInfos = cache
      .inspectFields(entityKey)
      .filter((field) => field.fieldName === fieldName) as FieldInfo[];

    // Resolve the requested page once; the same link is reused further down
    // when the cursor state is updated.
    const requestedKey = cache.resolve(entityKey, fieldName, fieldArgs) as
      | string
      | null;
    info.partial =
      !requestedKey && MERGE_IGNORE_KEYS.some((key) => !!fieldArgs[key]);

    if (fieldInfos.length === 0) {
      return undefined;
    }

    // Get current cursor state. When a state is already stored this is the
    // stored object itself, so the updates below are applied in place.
    const cursorState = getCursorState(entityKey, fieldName, fieldArgs);

    // A request counts as backward only when it carries prevToken alone.
    const isPaginatingBackward = !fieldArgs.nextToken && !!fieldArgs.prevToken;

    // First, process the requested page to update cursor state
    const currentKey =
      requestedKey ||
      // If args are empty, query might be cached with args as {} or null
      (cache.resolve(entityKey, fieldName, null) as string);

    if (currentKey) {
      const pageNextToken = cache.resolve(currentKey, 'nextToken') as
        | string
        | null;
      const pagePrevToken = cache.resolve(currentKey, 'prevToken') as
        | string
        | null;

      // No token recorded yet: seed the state with both tokens of this page.
      const isInitialLoad =
        !cursorState.latestNextToken && !cursorState.oldestPrevToken;

      if (isInitialLoad) {
        cursorState.latestNextToken = pageNextToken;
        cursorState.oldestPrevToken = pagePrevToken;
      } else if (!isPaginatingBackward) {
        // Forward: a missing nextToken means the end of the list was reached.
        if (!pageNextToken) {
          cursorState.hasReachedEnd = true;
        } else {
          cursorState.latestNextToken = pageNextToken;
        }
      } else if (!pagePrevToken) {
        // Backward: a missing prevToken means the start of the list was reached.
        cursorState.hasReachedStart = true;
      } else {
        cursorState.oldestPrevToken = pagePrevToken;
      }
    }

    // Now merge all matching pages
    const strictMatch = !!info.variables.skipMerge;
    let merged: PageResult | null = null;

    for (const { fieldKey, arguments: args } of fieldInfos) {
      if (!compareArgs(fieldArgs, args, strictMatch)) {
        continue;
      }

      const pageKey = cache.resolve(entityKey, fieldKey) as string;
      if (!pageKey) {
        // Skip only this entry. Resetting the accumulator here would silently
        // drop every page merged so far while later pages were still appended.
        continue;
      }

      const data = cache.resolve(pageKey, 'items') as string[];
      const total = cache.resolve(pageKey, 'total') as number;
      const pageNextToken = cache.resolve(pageKey, 'nextToken') as
        | string
        | null;
      const pagePrevToken = cache.resolve(pageKey, 'prevToken') as
        | string
        | null;

      if (!data?.length) {
        // An empty page only matters when nothing has been merged yet.
        if (!merged) {
          merged = {
            nextToken: pageNextToken,
            prevToken: pagePrevToken,
            total: total ?? 0,
            items: [],
          };
        }
        continue;
      }

      merged = {
        nextToken: pageNextToken,
        prevToken: pagePrevToken,
        total,
        items: [...(merged?.items ?? []), ...data],
      };
    }

    if (!merged) {
      return {
        __typename: fieldTypename,
        nextToken: null,
        prevToken: null,
        total: 0,
        items: [],
      };
    }

    // Save updated cursor state (this is what registers a freshly created
    // default state in the store on the first successful read).
    setCursorState(entityKey, fieldName, fieldArgs, cursorState);

    // Remove duplicates, keeping the first occurrence of each item
    const uniqueItems = Array.from(new Set(merged.items));

    return {
      __typename: fieldTypename,
      // Return edge tokens based on cursor state
      nextToken: cursorState.hasReachedEnd ? null : cursorState.latestNextToken,
      prevToken: cursorState.hasReachedStart
        ? null
        : cursorState.oldestPrevToken,
      total: merged.total,
      items: uniqueItems,
    } as PageResult;
  };
};

export default cursorPagination;
