import { produce } from 'immer'
import { findLastIndex, sumBy, last, uniq, isEqual } from 'lodash/fp'

import { ASRUpdateStreamMessage } from 'src/network/StreamClient'

import { TranscriptWord, ASRTranscriptWord, ASRChunkWordJSON, isPunctuation } from './TranscriptWord'

/** A timed run of words attributed to a single author. */
export interface TranscriptParagraph {
    /** Paragraph identifier; also the key of this paragraph's words in `Transcript.words`. */
    id: string
    /** Key of the paragraph's author in `Transcript.authors`. */
    authorId: string
    /** Start offset of the paragraph, in the same unit as word `time` values (presumably seconds since recording start — TODO confirm). */
    time: number
    /** Absolute start time, taken from `epoch_time` (transcript JSON) or `start_epoch` (ASR chunk). */
    timeAbsolute: number
    /** Length of the paragraph in the same unit as `time`; recomputed as "last word end - time" when words are trimmed. */
    duration: number
    /** The words that make up this paragraph. */
    words: TranscriptWord[]
    /** `Date.now()` timestamp taken when the paragraph object was created or last changed by a merge. */
    updatedAt: number
}

/** A paragraph built from ASR chunks; distinguished from a plain paragraph by the presence of `uttIds`. */
export interface ASRTranscriptParagraph extends TranscriptParagraph {
    /** Ids of the utterances (`ASRChunkJSON.utt_id`) that contributed words to this paragraph. */
    uttIds: number[]
}

/** A speaker referenced by paragraphs through `authorId`. */
export interface TranscriptAuthor {
    /** Author identifier; the key of this author in `Transcript.authors`. */
    id: string
    /** Display name; `Transcript.fromJSON` falls back to `'>>'` when the JSON author text is empty. */
    name: string
}

/** In-memory transcript state: revision paragraphs, trailing ASR paragraphs, and the lookup tables they share. */
export interface Transcript {
    /** Paragraphs that came from transcript revisions (see `Transcript.fromJSON` / `Transcript.merge`). */
    paragraphs: TranscriptParagraph[]
    /** Paragraphs built from ASR chunks; index-based helpers treat them as following `paragraphs`. */
    asrParagraphs: ASRTranscriptParagraph[]
    /** Words grouped by paragraph id, for both revision and ASR paragraphs. */
    words: Record<string, Array<TranscriptWord | ASRTranscriptWord>>
    /** Authors keyed by author id. */
    authors: Record<string, TranscriptAuthor>
    /** True when at least one author in the source JSON has a non-empty `text`. */
    isDiarizationSupported: boolean
    /** Moment the recording started. */
    recordingStartedAt: Date
    /** Offset, in seconds from `recordingStartedAt`, up to which revisions have been processed. */
    processedTime: number
    /** `end` value of the most recently merged ASR chunk. */
    asrProcessedTime?: number
    /** Paragraphs replaced (`removed`) and produced (`added`) by the latest `Transcript.merge` call(s). */
    lastRevisionUpdateDiff?: {
        removed: Array<TranscriptParagraph | ASRTranscriptParagraph>
        added: Array<TranscriptParagraph | ASRTranscriptParagraph>
    }
    /** Paragraphs replaced (`removed`) and produced (`added`) by the latest `Transcript.mergeAsrChunk` call(s). */
    lastAsrUpdateDiff?: {
        removed: Array<TranscriptParagraph | ASRTranscriptParagraph>
        added: Array<TranscriptParagraph | ASRTranscriptParagraph>
    }
}

/** Serialized transcript payload consumed by `Transcript.fromJSON`. */
export interface TranscriptJSON {
    data: {
        /** Paragraph headers; their words are listed separately in `words` and linked by `paragraph_id`. */
        paragraphs: Array<{
            id: string
            author_id: string
            time: number
            /** Becomes `TranscriptParagraph.timeAbsolute`. */
            epoch_time: number
            /** Paragraphs with a falsy duration are dropped by `Transcript.fromJSON`. */
            duration: number
        }>
        /** Flat list of words; each one points at its paragraph through `paragraph_id`. */
        words: Array<{
            text: string
            paragraph_id: string
            time: number
            duration: number
        }>
        /** Speakers; `text` is the display name and may be empty. */
        authors: Array<{
            id: string
            text: string
        }>
    }
    /** Becomes `Transcript.processedTime`. */
    processed_time: number
    /** Date string parsed with `new Date(...)` into `Transcript.recordingStartedAt`. */
    recording_started_at: string
}

/** A single ASR update consumed by `Transcript.mergeAsrChunk`. */
export interface ASRChunkJSON {
    /** Presumably the id of the recording this chunk belongs to — not read in this file, TODO confirm. */
    rec_id: number
    /** Id of the utterance this chunk describes; a chunk replaces everything previously merged for the same utterance. */
    utt_id: number
    /** Presumably true once the recording has ended — not read in this file, TODO confirm. */
    rec_ended: boolean
    /** Only chunks with this flag set are replayed by `Transcript.mergeAsrTranscriptionHistory`. */
    utt_ended: boolean
    /** Presumably the start offset of the chunk — not read in this file, TODO confirm. */
    start: number
    /** Stored as `Transcript.asrProcessedTime` after the chunk is merged. */
    end: number
    /** The words of the utterance. */
    word_list: ASRChunkWordJSON[]
    /** Used as `timeAbsolute` for paragraphs created from this chunk. */
    start_epoch: number
}

/**
 * Helpers for creating, querying and merging `Transcript` values.
 *
 * A transcript holds two consecutive paragraph lists: `paragraphs` (revisions) followed by
 * `asrParagraphs` (ASR chunks). Every helper that takes a paragraph *index* addresses the
 * concatenation of the two lists. Words of both lists live in the shared `words` map, keyed
 * by paragraph id.
 */
export const Transcript = {
    /** Blank transcript. Note that `recordingStartedAt` is captured once, when this module is evaluated. */
    EMPTY: {
        words: {},
        paragraphs: [],
        authors: {},
        isDiarizationSupported: false,
        asrParagraphs: [],
        recordingStartedAt: new Date(),
        processedTime: 0,
    } as Transcript,

    /**
     * Shallow shape check of a transcript payload.
     * @returns true when `data.words`, `data.paragraphs` and `data.authors` are all arrays
     */
    isValidJSON(json: TranscriptJSON): boolean {
        const data = json?.data

        return Array.isArray(data?.words) && Array.isArray(data?.paragraphs) && Array.isArray(data?.authors)
    },

    /**
     * Builds a `Transcript` from its serialized form.
     *
     * - words are grouped by `paragraph_id` (each word also gets a camel-cased `paragraphId`)
     * - paragraphs with a falsy `duration` are dropped
     * - a paragraph without any word gets an empty `words` array
     * - authors with an empty `text` are named `'>>'`
     */
    fromJSON(json: TranscriptJSON): Transcript {
        const words: Transcript['words'] = {}

        // Push into one array per paragraph; re-creating the array for every word is quadratic.
        for (const w of json.data.words) {
            const paragraphWords = words[w.paragraph_id] ?? (words[w.paragraph_id] = [])
            paragraphWords.push({ ...w, paragraphId: w.paragraph_id })
        }

        const authors: Transcript['authors'] = {}

        for (const a of json.data.authors) {
            authors[a.id] = { ...a, name: a.text || '>>' }
        }

        return {
            words,
            paragraphs: json.data.paragraphs
                .filter((p) => p.duration)
                .map(({ author_id, ...p }) => ({
                    ...p,
                    authorId: author_id,
                    // a paragraph may have no words in the payload; never expose `undefined` here
                    words: words[p.id] ?? [],
                    updatedAt: Date.now(),
                    timeAbsolute: p.epoch_time,
                })),
            asrParagraphs: [],
            authors,
            isDiarizationSupported: json.data.authors.some(({ text }) => !!text),
            processedTime: json.processed_time,
            recordingStartedAt: new Date(json.recording_started_at),
        }
    },

    /** Returns the words stored for `paragraphId`, or an empty array when there are none. */
    getParagraphWords(transcript: Transcript, paragraphId: string) {
        return transcript.words[paragraphId] ?? []
    },

    /** Number of revision paragraphs plus ASR paragraphs. */
    getTotalParagraphCount: ({ paragraphs, asrParagraphs }: Transcript) => paragraphs.length + asrParagraphs.length,

    /**
     * Rough token estimate for the whole transcript (revision + ASR paragraphs).
     * The words of each paragraph are joined with single spaces and the character total is divided by 4.
     */
    getTokensAmount: (transcript: Transcript) => {
        const getCharactersAmount = (paragraphs: (TranscriptParagraph | ASRTranscriptParagraph)[]) =>
            paragraphs.reduce((totalCount, paragraph) => {
                const words = Transcript.getParagraphWords(transcript, paragraph.id)
                const wordsAsText = words.map(({ text }) => text).join(' ')

                return totalCount + wordsAsText.length
            }, 0)
        const charactersAmount = getCharactersAmount(transcript.paragraphs) + getCharactersAmount(transcript.asrParagraphs)

        // 1 token ~= 4 chars in English
        // https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
        return Math.ceil(charactersAmount / 4)
    },

    /**
     * Paragraph at `index` in the combined list (revision paragraphs first, then ASR paragraphs).
     * Evaluates to `undefined` at runtime when `index` is out of range.
     */
    getParagraphByIndex: ({ paragraphs, asrParagraphs }: Transcript, index: number): TranscriptParagraph | ASRTranscriptParagraph =>
        paragraphs[index] || asrParagraphs[index - paragraphs.length],

    /** First paragraph with the given id, searching revision paragraphs before ASR paragraphs. */
    getParagraphById: ({ paragraphs, asrParagraphs }: Transcript, id: string) =>
        paragraphs.find((p) => p.id === id) || asrParagraphs.find((p) => p.id === id),

    /** Index of the paragraph with the given id in the combined list, or -1 when it is not found. */
    getParagraphIndexById: ({ paragraphs, asrParagraphs }: Transcript, id: string) => {
        const revisionIndex = paragraphs.findIndex((p) => p.id === id)

        if (revisionIndex !== -1) {
            return revisionIndex
        }

        const asrIndex = asrParagraphs.findIndex((p) => p.id === id)

        return asrIndex === -1 ? -1 : paragraphs.length + asrIndex
    },

    /**
     * Removes, IN PLACE, the paragraph at `index` of the combined list and returns it.
     * Intended for mutable transcripts (e.g. immer drafts).
     */
    removeParagraphAtIndex({ paragraphs, asrParagraphs }: Transcript, index: number): TranscriptParagraph | ASRTranscriptParagraph {
        if (paragraphs[index]) {
            return paragraphs.splice(index, 1)[0]
        }

        return asrParagraphs.splice(index - paragraphs.length, 1)[0]
    },

    /** Paragraphs of the combined list in the half-open index range `[startIndex, endIndex)`. */
    sliceParagraphs(transcript: Transcript, startIndex: number, endIndex: number) {
        const slice: Array<TranscriptParagraph | ASRTranscriptParagraph> = []

        for (let i = startIndex; i < endIndex; i++) {
            slice.push(Transcript.getParagraphByIndex(transcript, i))
        }

        return slice
    },

    /** Looks up an author by id. */
    getAuthor(transcript: Transcript, authorId: string): TranscriptAuthor {
        return transcript.authors[authorId]
    },

    /** Type guard: a paragraph is an ASR paragraph when it carries `uttIds`. */
    isAsrParagraph: (p: TranscriptParagraph | ASRTranscriptParagraph): p is ASRTranscriptParagraph =>
        !!(p as ASRTranscriptParagraph).uttIds,

    /**
     * Merges `revision` into `base` and returns the resulting transcript (`base` is not modified;
     * the work happens on an immer draft).
     *
     * For every paragraph of the revision:
     * - its words are appended to (or become) the entry in the `words` map,
     * - the paragraph replaces the existing paragraph with the same id, or is inserted by `time`,
     * - following paragraphs that overlap it are trimmed or removed.
     *
     * The paragraphs replaced/produced are recorded in `lastRevisionUpdateDiff`. Authors of the
     * revision are copied over and `processedTime` only ever moves forward.
     *
     * @param preserveDiff when true and a diff already exists, the new diff is appended to it instead of replacing it
     */
    merge: (base: Transcript, revision: Transcript, preserveDiff?: boolean) =>
        produce(base, (draft) => {
            let diffStartIndex: number | undefined
            let diffEndIndexOld = 0
            let diffEndIndexNew = 0
            let isLastResolvedParagraphMutated = false

            for (const revisionParagraph of revision.paragraphs) {
                // A revision paragraph may come without words; `concat(undefined)` would append an `undefined` entry.
                const revisionWords = revision.words[revisionParagraph.id] ?? []
                const existingWords = draft.words[revisionParagraph.id]

                // NOTE(review): words are appended even when the paragraph below turns out to be deep-equal — confirm intended.
                draft.words[revisionParagraph.id] = existingWords ? existingWords.concat(revisionWords) : revisionWords

                let insertAtIndex = findLastIndex({ id: revisionParagraph.id }, draft.paragraphs)

                // if non-existent (yet)
                if (insertAtIndex === -1) {
                    const lastParagraph = last(draft.paragraphs)

                    // if this paragraph overlaps an existing paragraph, walk back to its chronological position
                    if (lastParagraph && revisionParagraph.time <= lastParagraph.time) {
                        insertAtIndex = draft.paragraphs.length
                        while (draft.paragraphs[insertAtIndex - 1] && revisionParagraph.time <= draft.paragraphs[insertAtIndex - 1].time) {
                            insertAtIndex--
                        }
                        draft.paragraphs.splice(insertAtIndex, 0, revisionParagraph)
                    }
                    // if this paragraph is new chronologically
                    else {
                        draft.paragraphs.push(revisionParagraph)
                        insertAtIndex = draft.paragraphs.length - 1
                    }
                }
                // if existent
                else if (!isEqual(draft.paragraphs[insertAtIndex], revisionParagraph)) {
                    // Build a fresh object: `revisionParagraph` belongs to the caller and must not be mutated.
                    draft.paragraphs[insertAtIndex] = {
                        ...revisionParagraph,
                        words: [...(draft.paragraphs[insertAtIndex].words ?? []), ...(revisionParagraph.words ?? [])],
                    }
                }

                const { isLastModified, removeCount } = Transcript.resolveOverlappedParagraphs(draft, insertAtIndex)

                if (diffStartIndex === undefined) {
                    diffStartIndex = insertAtIndex
                    diffEndIndexOld = diffStartIndex
                    diffEndIndexNew = diffStartIndex
                }

                // TODO: check what happens when p1 and p2 are not sequential (p1, ...[pA, pB, ...], p2)
                diffEndIndexOld += removeCount
                diffEndIndexNew += 1
                isLastResolvedParagraphMutated = isLastModified
            }

            /** calculate the diff between the content before and after merging the revision */
            if (diffStartIndex !== undefined) {
                // the paragraph right after the merged range was trimmed, so it is part of the diff too
                if (isLastResolvedParagraphMutated) {
                    diffEndIndexOld++
                    diffEndIndexNew++
                }

                const removed = Transcript.sliceParagraphs(
                    base,
                    diffStartIndex,
                    Math.min(Transcript.getTotalParagraphCount(base), diffEndIndexOld),
                )

                const added = Transcript.sliceParagraphs(
                    draft,
                    diffStartIndex,
                    Math.min(Transcript.getTotalParagraphCount(draft), diffEndIndexNew),
                )

                if (preserveDiff && draft.lastRevisionUpdateDiff) {
                    draft.lastRevisionUpdateDiff.removed.push(...removed)
                    draft.lastRevisionUpdateDiff.added.push(...added)
                } else {
                    draft.lastRevisionUpdateDiff = {
                        removed,
                        added,
                    }
                }
            }

            for (const [authorId, author] of Object.entries(revision.authors)) {
                if (!isEqual(draft.authors[authorId], author)) {
                    draft.authors[authorId] = author
                }
            }

            draft.processedTime = Math.max(draft.processedTime, revision.processedTime)
        }),

    /**
     * Calculates the exact timestamp for the processed_time.
     * Falls back to "four minutes ago" when the transcript has no recording start or no numeric processed time.
     * @param transcript Transcript
     * @returns a unix timestamp (milliseconds) for the processed_time
     */
    getProcessedTime(transcript: Transcript): number {
        if (transcript.recordingStartedAt && typeof transcript.processedTime === 'number') {
            // processedTime is an offset in seconds, Date works in milliseconds
            return transcript.recordingStartedAt.getTime() + transcript.processedTime * 1000
        }

        return new Date().getTime() - 4 * 60000
    },

    /**
     * Trims, IN PLACE, the paragraphs that follow the one at `startIndex` so that they no longer
     * overlap it: leading words of the next paragraph that start at or before the last word of the
     * paragraph at `startIndex` are dropped, and a paragraph left with no words (or only
     * zero-duration words) is removed together with its `words` entry.
     *
     * Intended for mutable transcripts (e.g. immer drafts).
     *
     * @returns `removeCount` — how many paragraphs were removed; `isLastModified` — whether the
     *          paragraph that now follows `startIndex` lost words. Both are zero/false when
     *          `startIndex` does not address a paragraph.
     */
    resolveOverlappedParagraphs(transcript: Transcript, startIndex: number): { isLastModified: boolean; removeCount: number } {
        let isLastModified = false
        let removeCount = 0

        const currentParagraph = Transcript.getParagraphByIndex(transcript, startIndex)

        // out-of-range index: nothing to resolve
        if (!currentParagraph) {
            return { isLastModified, removeCount }
        }

        const currentParagraphLastWord = last(transcript.words[currentParagraph.id])
        let nextParagraph = Transcript.getParagraphByIndex(transcript, startIndex + 1)
        let nextParagraphWords = transcript.words[nextParagraph?.id]

        /** iterate word by word (crossing paragraphs) and omit every word in the way that overlaps with the paragraph at startIndex */
        while (currentParagraphLastWord && nextParagraphWords?.[0] && nextParagraphWords[0].time <= currentParagraphLastWord.time) {
            nextParagraphWords.shift()
            isLastModified = true

            /** if the next paragraph became empty, omit it and move on to the one after it */
            if (!nextParagraphWords.length || !sumBy('duration', nextParagraphWords)) {
                const removedP = Transcript.removeParagraphAtIndex(transcript, startIndex + 1)
                delete transcript.words[removedP.id]

                // after the removal, the following paragraph has shifted into `startIndex + 1`
                nextParagraph = Transcript.getParagraphByIndex(transcript, startIndex + 1)
                nextParagraphWords = transcript.words[nextParagraph?.id]
                removeCount++
                isLastModified = false
            }
        }

        if (nextParagraphWords) {
            const firstWord = nextParagraphWords[0]
            const lastWord = last(nextParagraphWords)

            if (lastWord) {
                /** re-sync the following paragraph with its (possibly trimmed) words; runs even when no word was dropped */
                nextParagraph.words = nextParagraphWords
                nextParagraph.updatedAt = Date.now()
                nextParagraph.time = firstWord.time
                nextParagraph.duration = lastWord.time + lastWord.duration - firstWord.time
            }
        }

        return { isLastModified, removeCount }
    },

    /**
     * Merges one ASR chunk into the transcript and returns the resulting transcript
     * (`currentTranscript` is not modified; the work happens on an immer draft).
     *
     * 1. Everything previously merged for the chunk's utterance is removed from the trailing ASR paragraphs.
     * 2. The chunk's words are appended, skipping words that start before the end of the content
     *    already present (revision or ASR).
     * 3. `asrProcessedTime` is set to the chunk's `end` and the change is recorded in `lastAsrUpdateDiff`.
     *
     * @param preserveDiff when true and a diff already exists, the new diff is appended to it instead of replacing it
     */
    mergeAsrChunk: (currentTranscript: Transcript, asrChunkJSON: ASRChunkJSON, preserveDiff?: boolean): Transcript =>
        produce(currentTranscript, (nextTranscript) => {
            let firstModifiedIndex = nextTranscript.asrParagraphs.length - 1
            let p = nextTranscript.asrParagraphs[firstModifiedIndex]

            /** clean all the data (paragraphs & words) that are related to the utterance of asrChunkJSON
             *  because each chunk contains all the most updated content for that utterance */
            while (p?.uttIds.includes(asrChunkJSON.utt_id)) {
                const remainingWords = (nextTranscript.words[p.id] ?? []).filter(
                    (w) => (w as ASRTranscriptWord).uttId !== asrChunkJSON.utt_id,
                )
                const lastRemainingWord = last(remainingWords)

                if (!lastRemainingWord) {
                    /** all the words of this paragraph are from that utterance: remove the paragraph along with its words.
                     *  Remove the paragraph being examined, which is not necessarily the last one. */
                    nextTranscript.asrParagraphs.splice(firstModifiedIndex, 1)
                    delete nextTranscript.words[p.id]
                } else {
                    /** only some of the words of this paragraph are from that utterance: keep the others */
                    nextTranscript.words[p.id] = remainingWords

                    p.uttIds = uniq(remainingWords.map((w) => (w as ASRTranscriptWord).uttId))
                    p.duration = lastRemainingWord.time + lastRemainingWord.duration - p.time
                    p.words = remainingWords
                    p.updatedAt = Date.now()
                }

                firstModifiedIndex--
                p = nextTranscript.asrParagraphs[firstModifiedIndex]
            }

            // The diff starts at the paragraph preceding the cleaned ones, because new words may be appended to it.
            firstModifiedIndex = Math.max(0, firstModifiedIndex)

            /** process asrChunkJSON words */
            for (const w of asrChunkJSON.word_list) {
                const lastRevisionParagraph = last(nextTranscript.paragraphs)
                const lastRevisionWord = last(nextTranscript.words[lastRevisionParagraph?.id ?? ''])
                const lastAsrParagraph = last(nextTranscript.asrParagraphs)
                const lastAsrWord = last(nextTranscript.words[lastAsrParagraph?.id ?? ''])
                const lastRevisionWordEndTime = (lastRevisionWord?.time ?? 0) + (lastRevisionWord?.duration ?? 0)
                const revisionSectionEndTime = Math.max(nextTranscript.processedTime, lastRevisionWordEndTime)
                const asrSectionEndTime = (lastAsrWord?.time ?? 0) + (lastAsrWord?.duration ?? 0)

                // skip words already covered by existing content, and punctuation glued to the end of the revision section
                if (
                    w.start < Math.max(revisionSectionEndTime, asrSectionEndTime) ||
                    (w.start === revisionSectionEndTime && isPunctuation(w.text))
                ) {
                    continue
                }

                const word = TranscriptWord.fromAsrChunkWordJSON(w, asrChunkJSON.utt_id)

                if (lastAsrParagraph && w.paragraph_id === lastAsrParagraph.id) {
                    /** this word belongs to the last paragraph: add it to its words array */
                    nextTranscript.words[w.paragraph_id].push(word)

                    lastAsrParagraph.authorId = w.speaker_id
                    lastAsrParagraph.uttIds = uniq([...lastAsrParagraph.uttIds, asrChunkJSON.utt_id])
                    // duration is measured from this paragraph's own start, not from the first ASR paragraph
                    lastAsrParagraph.duration = w.end - lastAsrParagraph.time
                    lastAsrParagraph.words = nextTranscript.words[w.paragraph_id]
                    lastAsrParagraph.updatedAt = Date.now()
                } else {
                    /** this word doesn't belong to the last paragraph: start a new paragraph with it */
                    nextTranscript.words[w.paragraph_id] = [word]

                    nextTranscript.asrParagraphs.push({
                        id: w.paragraph_id,
                        uttIds: [asrChunkJSON.utt_id],
                        authorId: w.speaker_id,
                        time: w.start,
                        timeAbsolute: asrChunkJSON.start_epoch,
                        duration: w.end - w.start,
                        words: nextTranscript.words[w.paragraph_id],
                        updatedAt: Date.now(),
                    })
                }
            }

            nextTranscript.asrProcessedTime = asrChunkJSON.end

            /** calculate the diff between the content before and after merging the chunk */
            const removed = currentTranscript.asrParagraphs.slice(firstModifiedIndex, currentTranscript.asrParagraphs.length)
            const added = nextTranscript.asrParagraphs.slice(firstModifiedIndex, nextTranscript.asrParagraphs.length)

            if (preserveDiff && nextTranscript.lastAsrUpdateDiff) {
                nextTranscript.lastAsrUpdateDiff.removed.push(...removed)
                nextTranscript.lastAsrUpdateDiff.added.push(...added)
            } else {
                nextTranscript.lastAsrUpdateDiff = {
                    removed,
                    added,
                }
            }
        }),

    /**
     * Replays a list of ASR update messages onto `transcript`.
     * Only messages whose utterance has ended are used, and they are merged in reverse list order
     * (the filtered copy is reversed, the input array is left untouched).
     */
    mergeAsrTranscriptionHistory: (transcript: Transcript, asrUpdateStreamMessages: ASRUpdateStreamMessage[]) =>
        asrUpdateStreamMessages
            .filter((item) => item.data.utt_ended)
            .reverse()
            .reduce((accTranscript, { data: asrChunkJSON }) => Transcript.mergeAsrChunk(accTranscript, asrChunkJSON), transcript),
}
