import merge from "lodash/merge";
import { PartialDeep } from "type-fest";
import {
  StructChatResponse,
  StructThread,
  Read,
} from "@/app/types/Thread.type";
import { StateCreator } from "zustand";
import { omit } from "lodash";

/**
 * Maximum number of chats kept per thread inside `threadsById`.
 * `setThreads` and `insertThreadChatById` keep only the last
 * `CHATS_PER_THREAD` entries of a thread's `chats` array.
 */
const CHATS_PER_THREAD = 4;

/**
 * State and actions of the threads store slice built by `createThreadsStore`.
 */
export interface ThreadsStoreState {
  /** Threads keyed by thread id. */
  threadsById: Record<StructThread["id"], StructThread>;
  /** Threads registered through `setOpenThreadById`, keyed by thread id. */
  openThreadsById: Record<StructThread["id"], StructThread>;
  /**
   * Stores `thread` under `threadId` in `openThreadsById`; passing `null`
   * removes the entry for `threadId` instead.
   */
  setOpenThreadById: (
    threadId: StructThread["id"],
    thread: StructThread | null,
  ) => void;
  /** Ids of the currently selected threads. */
  selectedThreadIds: StructThread["id"][];
  /** Replaces the whole selection with `threadIds`. */
  setSelectedThreadIds: (threadIds: StructThread["id"][]) => void;
  /**
   * For each id, in order: removes it from the selection when present,
   * otherwise appends it.
   */
  toggleSelectedThreadIds: (threadIds: StructThread["id"][]) => void;
  /** Returns the current `threadsById` map. */
  getThreadsById: () => Record<StructThread["id"], StructThread>;
  /**
   * Replaces `threadsById` with the given threads, keeping at most the last
   * `CHATS_PER_THREAD` chats of each thread.
   */
  setThreads: (threads: StructThread[]) => void;
  /**
   * Shallow-merges `thread` over any existing entry with the same id,
   * inserting it when no such entry exists.
   */
  setThreadById: (thread: StructThread) => void;
  /**
   * Rebuilds `threadsById` with the given thread inserted first.
   * Does nothing when the id is not in the store.
   */
  moveThreadToTop: (threadId: StructThread["id"]) => void;
  /** Removes the given thread ids from `threadsById`. */
  deleteThreads: (threadIds: StructThread["id"][]) => void;
  /** Replaces (or inserts) each given thread wholesale under its id. */
  updateThreads: (threads: StructThread[]) => void;
  /** Returns the values of `threadsById` as an array. */
  getThreads: () => StructThread[] | [];
  /**
   * Appends `chatMessage` to the thread's chats and trims them to the last
   * `CHATS_PER_THREAD` entries. Does nothing when the thread is unknown or a
   * chat with the same id is already present.
   */
  insertThreadChatById: (
    threadId: StructThread["id"],
    chatMessage: StructChatResponse,
  ) => void;
  /** Lists of threads stored by `setSimilarThreadsById`, keyed by thread id. */
  similarThreadsById: Record<string, StructThread[]>;
  /** Stores `threads` under `threadId` in `similarThreadsById`. */
  setSimilarThreadsById: (threadId: string, threads: StructThread[]) => void;
  /**
   * Deep-merges `updatedFields` into the existing thread (lodash `merge`);
   * a truthy `read.bits` in the update replaces the merged `read.bits`.
   * Does nothing when the thread id is not in the store.
   */
  updateThreadById: (
    threadId: StructThread["id"],
    updatedFields: PartialDeep<StructThread>,
  ) => void;
  /**
   * Deep-merges partial updates into existing threads, keyed by thread id,
   * with the same `read.bits` handling as `updateThreadById`. Ids that are
   * not present in the store are never inserted.
   */
  updateThreadsById: (
    threads: Record<StructThread["id"], PartialDeep<StructThread>>,
  ) => void;
  /**
   * Chat id recorded per thread id — presumably the last chat that has been
   * read in that thread (TODO confirm with callers).
   */
  readUntilByThreadId: Record<StructThread["id"], StructChatResponse["id"]>;
  /** Records `chatId` for `threadId` in `readUntilByThreadId`. */
  setReadUntilByThreadId: (
    threadId: StructThread["id"],
    chatId: StructChatResponse["id"],
  ) => void;
  /** Shallow-merges several thread-id → chat-id entries into `readUntilByThreadId`. */
  setBulkReadUntilByThreadId: (
    threads: Record<StructThread["id"], StructChatResponse["id"]>,
  ) => void;
  /**
   * Removes the chat(s) with id `chatId` from the thread's chats.
   * Does nothing when the thread is unknown or has no chats.
   */
  removeThreadChatById: (
    threadId: StructThread["id"],
    chatId: StructChatResponse["id"],
  ) => void;
  /**
   * Shallow-merges `updatedChat` into the first chat of the thread whose id
   * is `chatId`. Does nothing when the thread or the chat is not found.
   */
  updateThreadChatById: (
    threadId: StructThread["id"],
    chatId: StructChatResponse["id"],
    updatedChat: Partial<StructChatResponse>,
  ) => void;
  /** Resets `threadsById` to the empty initial map; other fields are left untouched. */
  clearThreads: () => void;
}

/**
 * Empty map used both as the initial value of `threadsById` and as the value
 * `clearThreads` resets it to. The same object reference is reused each time,
 * so it must never be mutated in place.
 */
const INITIAL_STATE = {};

/**
 * Deep-merges a partial update into a copy of an existing thread.
 *
 * lodash `merge` combines nested objects and arrays recursively instead of
 * replacing them, so when the update carries a truthy `read.bits` (and the
 * merged result has one too) the merged `read.bits` is replaced wholesale
 * with the incoming value.
 */
const applyThreadUpdate = (
  existingThread: StructThread,
  updatedFields: PartialDeep<StructThread>,
): StructThread => {
  const mergedThread = merge({}, existingThread, updatedFields);
  if (!updatedFields?.read?.bits || !mergedThread?.read?.bits) {
    return mergedThread;
  }
  return {
    ...mergedThread,
    read: {
      ...(mergedThread.read as Read),
      bits: (updatedFields.read as Read).bits,
    },
  };
};

/**
 * Zustand state creator for the threads slice.
 *
 * Every action builds new objects/arrays instead of mutating the current
 * state or its arguments, so reference-equality checks done by subscribers
 * observe each change. Actions that find nothing to do return the current
 * `state` unchanged.
 */
export const createThreadsStore: StateCreator<ThreadsStoreState> = (
  set,
  get,
) => ({
  threadsById: INITIAL_STATE,
  openThreadsById: {},
  selectedThreadIds: [],
  /** Returns the current `threadsById` map. */
  getThreadsById: () => get().threadsById,
  similarThreadsById: {},
  readUntilByThreadId: {},
  /** Registers `thread` as open under `threadId`; `null` removes the entry. */
  setOpenThreadById: (
    threadId: StructThread["id"],
    thread: StructThread | null,
  ) => {
    set((state) => {
      if (!thread) {
        return { openThreadsById: omit(state.openThreadsById, [threadId]) };
      }
      return {
        openThreadsById: {
          ...state.openThreadsById,
          [threadId]: thread,
        },
      };
    });
  },
  /** Replaces the whole selection. */
  setSelectedThreadIds: (threadIds: StructThread["id"][]) => {
    set({ selectedThreadIds: threadIds });
  },
  /**
   * Toggles each id in order: removes it (first occurrence) when selected,
   * appends it otherwise.
   */
  toggleSelectedThreadIds: (threadIds: StructThread["id"][]) => {
    set((state) => {
      // Work on a copy: mutating the stored array would keep the same
      // reference (hiding the change from subscribers) and would also mutate
      // whatever array the caller handed to `setSelectedThreadIds`.
      const nextSelectedThreadIds = [...state.selectedThreadIds];
      for (const threadId of threadIds) {
        const position = nextSelectedThreadIds.indexOf(threadId);
        if (position === -1) {
          nextSelectedThreadIds.push(threadId);
        } else {
          nextSelectedThreadIds.splice(position, 1);
        }
      }
      return { selectedThreadIds: nextSelectedThreadIds };
    });
  },
  /**
   * Replaces `threadsById` with the given threads, keeping only the last
   * `CHATS_PER_THREAD` chats of each.
   */
  setThreads: (threads: StructThread[]) => {
    const threadsById: Record<string, StructThread> = {};
    for (const thread of threads) {
      // Copy the thread so the caller's object keeps its full chat list.
      threadsById[thread.id] = {
        ...thread,
        chats: thread.chats?.slice(-CHATS_PER_THREAD),
      };
    }
    set({ threadsById });
  },
  /** Shallow-merges `thread` over the existing entry with the same id, if any. */
  setThreadById: (thread: StructThread) => {
    set((state) => ({
      threadsById: {
        ...state.threadsById,
        [thread.id]: {
          ...state.threadsById[thread.id],
          ...thread,
        },
      },
    }));
  },
  /**
   * Moves an existing thread to the front of `threadsById`'s key order.
   *
   * Note: JavaScript enumerates integer-like keys in ascending numeric order
   * before all other string keys, so this reordering is only observable for
   * non-integer-like ids.
   */
  moveThreadToTop: (threadId: string) => {
    set((state) => {
      const existingThread = state.threadsById[threadId];
      if (!existingThread) {
        return state;
      }

      // The key is created first; the spread then re-assigns it without
      // changing its position and appends every other thread in its
      // previous relative order.
      return {
        threadsById: {
          [threadId]: existingThread,
          ...state.threadsById,
        } as Record<string, StructThread>,
      };
    });
  },
  /** Replaces (or inserts) each given thread wholesale under its id. */
  updateThreads: (threads: StructThread[]) => {
    set((state) => {
      const updatedThreadsById = { ...state.threadsById };
      for (const thread of threads) {
        updatedThreadsById[thread.id] = thread;
      }
      return { threadsById: updatedThreadsById };
    });
  },
  /** Removes the given thread ids from `threadsById`. */
  deleteThreads: (threadIds: string[]) => {
    set((state) => ({ threadsById: omit(state.threadsById, threadIds) }));
  },
  /** Returns all stored threads as an array. */
  getThreads: () => Object.values(get().threadsById),
  /** Stores `threads` as the similar threads of `threadId`. */
  setSimilarThreadsById: (threadId: string, threads: StructThread[]) => {
    set((state) => ({
      similarThreadsById: {
        ...state.similarThreadsById,
        [threadId]: threads,
      },
    }));
  },
  /**
   * Appends a chat to a thread and trims the list to the last
   * `CHATS_PER_THREAD` entries. No-op for unknown threads or duplicate chat ids.
   */
  insertThreadChatById: (
    threadId: StructThread["id"],
    chatMessage: StructChatResponse,
  ) => {
    set((state) => {
      const thread = state.threadsById[threadId];
      if (!thread) {
        return state;
      }

      const existingChats = thread.chats ?? [];
      if (existingChats.some((chat) => chat.id === chatMessage.id)) {
        return state;
      }

      return {
        threadsById: {
          ...state.threadsById,
          [threadId]: {
            ...thread,
            chats: [...existingChats, chatMessage].slice(-CHATS_PER_THREAD),
          },
        },
      };
    });
  },
  /** Removes every chat with id `chatId` from the thread's chats. */
  removeThreadChatById: (
    threadId: string,
    chatId: StructChatResponse["id"],
  ) => {
    set((state) => {
      const thread = state.threadsById[threadId];
      if (!thread || !thread.chats) {
        return state;
      }

      return {
        threadsById: {
          ...state.threadsById,
          [threadId]: {
            ...thread,
            chats: thread.chats.filter((chat) => chat.id !== chatId),
          },
        },
      };
    });
  },
  /**
   * Shallow-merges `updatedChat` into the first chat with id `chatId`.
   * No-op when the thread, its chats, or the chat cannot be found.
   */
  updateThreadChatById: (
    threadId: string,
    chatId: StructChatResponse["id"],
    updatedChat: Partial<StructChatResponse>,
  ) => {
    set((state) => {
      const thread = state.threadsById[threadId];
      if (!thread || !thread.chats) {
        return state;
      }

      const chatIndex = thread.chats.findIndex((chat) => chat.id === chatId);
      if (chatIndex === -1) {
        return state;
      }

      const updatedChats = [...thread.chats];
      updatedChats[chatIndex] = {
        ...thread.chats[chatIndex],
        ...updatedChat,
      };

      return {
        threadsById: {
          ...state.threadsById,
          [threadId]: {
            ...thread,
            chats: updatedChats,
          },
        },
      };
    });
  },
  /**
   * Deep-merges several partial updates, keyed by thread id, into the
   * matching stored threads. Ids that are not in the store are skipped.
   */
  updateThreadsById: (
    threads: Record<StructThread["id"], PartialDeep<StructThread>>,
  ) => {
    set((state) => {
      const updatedThreadsById = { ...state.threadsById };
      for (const [threadId, updatedFields] of Object.entries(threads)) {
        const existingThread = state.threadsById[threadId];
        if (!existingThread) {
          // Skip only this entry; the remaining updates must still apply.
          continue;
        }
        updatedThreadsById[threadId] = applyThreadUpdate(
          existingThread,
          updatedFields,
        );
      }
      return { threadsById: updatedThreadsById };
    });
  },

  /**
   * Deep-merges `updatedFields` into the stored thread with id `threadId`.
   * No-op when the thread is not in the store.
   */
  updateThreadById: (
    threadId: string,
    updatedFields: PartialDeep<StructThread>,
  ) => {
    set((state) => {
      const existingThread = state.threadsById[threadId];
      if (!existingThread) {
        return state;
      }

      return {
        threadsById: {
          ...state.threadsById,
          [threadId]: applyThreadUpdate(existingThread, updatedFields),
        },
      };
    });
  },
  /** Records `chatId` for `threadId` in `readUntilByThreadId`. */
  setReadUntilByThreadId: (
    threadId: StructThread["id"],
    chatId: StructChatResponse["id"],
  ) => {
    set((state) => ({
      readUntilByThreadId: {
        ...state.readUntilByThreadId,
        [threadId]: chatId,
      },
    }));
  },
  /** Shallow-merges several thread-id → chat-id entries into `readUntilByThreadId`. */
  setBulkReadUntilByThreadId: (
    threads: Record<StructThread["id"], StructChatResponse["id"]>,
  ) => {
    set((state) => ({
      readUntilByThreadId: {
        ...state.readUntilByThreadId,
        ...threads,
      },
    }));
  },
  /** Resets `threadsById` to the shared empty initial map. */
  clearThreads: () => {
    set({ threadsById: INITIAL_STATE });
  },
});
