import {OperatorBase} from "@app/textures/texture-editor/operator-stack/operators/abstract-base/operator-base"
import * as TextureEditNodes from "@app/textures/texture-editor/texture-edit-nodes"
import {
    Operator,
    OperatorFlags,
    OperatorInput,
    OperatorOutput,
    OperatorPanelComponentType,
    OperatorParameterValue,
} from "@app/textures/texture-editor/operator-stack/operators/abstract-base/operator"
import {OperatorCallback} from "@app/textures/texture-editor/operator-stack/operators/abstract-base/operator-callback"
import {deepCopy} from "@cm/lib/utils/utils"
import {ImageOpNodeGraphEvaluator} from "@app/textures/texture-editor/operator-stack/image-op-system/image-op-node-graph-evaluator"
import {TextureType} from "@api"
import {TilingPanelComponent} from "@app/textures/texture-editor/operator-stack/operators/tiling/panel/tiling-panel.component"
import {TilingToolbox} from "@app/textures/texture-editor/operator-stack/operators/tiling/toolbox/tiling-toolbox"
import {Hotkeys} from "@common/services/hotkeys/hotkeys.service"
import {BehaviorSubject, filter, merge, Observable, Subject, takeUntil} from "rxjs"
import {Vector2, Vector2Like} from "@cm/lib/math/vector2"

import {ImagePtr} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ref"
import {DebugImage} from "@app/textures/texture-editor/operator-stack/image-op-system/util/debug-image"
import {blend} from "@app/textures/texture-editor/operator-stack/image-op-system/nodes/image-op-nodes/blend-node"
import {copyRegion} from "@app/textures/texture-editor/operator-stack/image-op-system/nodes/image-op-nodes/copy-region-node"
import {lambda} from "@app/textures/texture-editor/operator-stack/image-op-system/nodes/basic-nodes/lambda-node"
import {gridMapping} from "@app/textures/texture-editor/operator-stack/image-op-system/nodes/image-op-nodes/grid-mapping-node"
import {ParameterValue} from "@cm/lib/graph-system/node-graph"
import {Context} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/context"
import {
    BoundaryDirection,
    BoundarySide,
    ControlPointType,
    ViewMode,
} from "@app/textures/texture-editor/operator-stack/operators/tiling/toolbox/tiling-area-toolbox-item"
import {hierarchicalCrossCorrelation} from "@app/textures/texture-editor/operator-stack/operators/tiling/helpers/hierarchical-cross-correlation"
import {crossCorrelate} from "@app/textures/texture-editor/operator-stack/operators/tiling/helpers/cross-correlation"
import {Color} from "@cm/lib/math/color"
import {TextureEditorSettings} from "@app/textures/texture-editor/texture-editor-settings"
import {colorGradient} from "@app/textures/texture-editor/operator-stack/image-op-system/nodes/image-op-nodes/color-gradient-node"
import {GridPoint} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/image-op-grid-mapping"

/**
 * Texture-editor operator that remaps a user-defined tiling area onto a rectangular result image.
 *
 * The tiling area is edited through {@link TilingToolbox}. It is tessellated into a grid of points which drives a
 * grid mapping of the operator input. Optionally, strips along the tiling boundaries are extracted and blended into
 * the opposite edges of the result. Boundary control points can be snapped / fine-adjusted via cross-correlation.
 */
export class OperatorTiling extends OperatorBase<TextureEditNodes.OperatorTiling> {
    // OperatorBase
    override readonly flags = new Set<OperatorFlags>(["no-clone", "no-disable", "apply-to-all-texture-types"])

    readonly panelComponentType: OperatorPanelComponentType = TilingPanelComponent
    readonly canvasToolbox: TilingToolbox

    readonly type = "operator-tiling" as const

    // UI state. Changes are forwarded to the canvas toolbox and/or trigger a re-evaluation (see constructor).
    readonly showGuides$ = new BehaviorSubject(true)
    readonly viewMode$ = new BehaviorSubject<ViewMode>(ViewMode.Source)
    readonly snapEnabled$ = new BehaviorSubject(true)
    readonly snapDistanceInPixels$ = new BehaviorSubject(64)
    readonly fineAdjustSpacingInPixels$ = new BehaviorSubject(64)
    readonly fineAdjustSearchSizeRatio$ = new BehaviorSubject(0.25)
    readonly fineAdjustMinCorrelation$ = new BehaviorSubject(0.25)
    readonly blendDistanceInPixels$ = new BehaviorSubject(0)
    readonly debugDrawEnabled$ = new BehaviorSubject(false)

    /**
     * @param callback operator host callbacks (injector, image-op context, selection state, ...)
     * @param node existing node to restore from (deep-copied), or null to create a fresh enabled tiling node
     */
    constructor(callback: OperatorCallback, node: TextureEditNodes.OperatorTiling | null) {
        super(
            callback,
            deepCopy(node) ?? {
                type: "operator-tiling",
                enabled: true,
            },
        )

        this.canvasToolbox = new TilingToolbox(this)
        this.debugImage = new DebugImage(this.callback.imageOpContextWebGL2)

        // All subscriptions end when the operator is disposed and only react while this operator is selected.
        const applyPipe = <T>(obs: Observable<T>) =>
            obs.pipe(
                takeUntil(this.destroyed),
                filter(() => this.callback.selectedOperator === this),
            )

        const hotkeys = this.callback.injector.get(Hotkeys)
        applyPipe(hotkeys.addShortcut(["k"])).subscribe(() => this.viewMode$.next(ViewMode.Source))
        applyPipe(hotkeys.addShortcut(["l"])).subscribe(() => this.viewMode$.next(ViewMode.Result))
        applyPipe(hotkeys.addShortcut(["s"])).subscribe(() => this.snapEnabled$.next(!this.snapEnabled$.value))
        applyPipe(hotkeys.addShortcut(["v"])).subscribe(() => this.showGuides$.next(!this.showGuides$.value))

        applyPipe(this.showGuides$).subscribe((value) => (this.canvasToolbox.tilingArea.visible = value))
        applyPipe(this.viewMode$).subscribe((viewMode) => {
            this.canvasToolbox.tilingArea.viewMode$.next(viewMode)
            this.requestEval()
        })
        // a snap distance of 0 is pushed to the tiling area while snapping is disabled
        applyPipe(merge(this.snapDistanceInPixels$, this.snapEnabled$)).subscribe(() =>
            this.canvasToolbox.tilingArea.snapDistance$.next(this.snapEnabled$.value ? this.snapDistanceInPixels$.value : 0),
        )
        applyPipe(this.debugDrawEnabled$).subscribe(() => this.requestEval())
        applyPipe(this.blendDistanceInPixels$).subscribe(() => this.onBlendDistanceInPixelsChanged())
        applyPipe(this.canvasToolbox.tilingArea.changed$).subscribe(() => this.onTilingAreaChanged())
    }

    // OperatorBase
    override dispose(): void {
        // terminate all applyPipe() subscriptions before tearing down the base and the canvas toolbox
        this.destroyed.next()
        this.destroyed.complete()
        super.dispose()
        this.canvasToolbox.remove()
    }

    // OperatorBase
    async clone(): Promise<Operator> {
        return new OperatorTiling(this.callback, deepCopy(this.node))
    }

    /**
     * Builds the image-op node graph for this operator.
     *
     * Modes:
     * - Preview mode, operator selected, "Source" view active: the input is passed through unchanged.
     * - Otherwise: the input is remapped through the tessellated tiling grid. If the rounded blend distance is
     *   positive, boundary strips are blended into the opposite edges of the result.
     *
     * When debug drawing is enabled, the debug image is copied onto the result.
     */
    // OperatorBase
    async getImageOpNodeGraph(evaluator: ImageOpNodeGraphEvaluator, textureType: TextureType, input: OperatorInput): Promise<OperatorOutput> {
        const tilingArea = this.canvasToolbox.tilingArea
        let resultImage: ParameterValue<ImagePtr, Context>
        if (evaluator.mode === "preview" && this.selected && this.viewMode$.value === ViewMode.Source) {
            resultImage = lambda({sourceImage: input}, async ({parameters: {sourceImage}}) => new ImagePtr(sourceImage))
        } else {
            // mapping
            // const pixelsPerSample = 100
            // const boundaryMappedLengthH = this.canvasToolbox.tilingArea.getBoundaryMappedLength(BoundaryDirection.Horizontal)
            // const boundaryMappedLengthV = this.canvasToolbox.tilingArea.getBoundaryMappedLength(BoundaryDirection.Vertical)
            // const numStepsU = Math.ceil(boundaryMappedLengthH / pixelsPerSample)
            // const numStepsV = Math.ceil(boundaryMappedLengthV / pixelsPerSample)
            const subdivisionsPerSegment = 16
            const numStepsU = Math.ceil((tilingArea.getBoundaryNumControlPoints(BoundaryDirection.Horizontal) - 1) * subdivisionsPerSegment)
            const numStepsV = Math.ceil((tilingArea.getBoundaryNumControlPoints(BoundaryDirection.Vertical) - 1) * subdivisionsPerSegment)
            // the tessellated grid is cached; onTilingAreaChanged() invalidates it
            if (!this.tessellatedGridPoints) {
                this.tessellatedGridPoints = tilingArea.computeGridPoints(
                    {numSteps: numStepsU, tMin: 0, tMax: 1},
                    {numSteps: numStepsV, tMin: 0, tMax: 1},
                )
            }
            resultImage = gridMapping({
                sourceImage: input,
                gridPoints: this.tessellatedGridPoints,
            })

            // border blending
            const blendDistanceInPixels = Math.round(this.blendDistanceInPixels$.value)
            if (blendDistanceInPixels > 0) {
                // Extracts a strip along the given boundary as its own image (grid-mapped and shifted to the origin).
                const extractBorder = (sourceImage: OperatorParameterValue<ImagePtr>, direction: BoundaryDirection, side: BoundarySide) => {
                    // Compute the border grid points. The rounded distance is used so the strip is derived from the
                    // same value as the blend target region below.
                    const tValues = tilingArea.computeBoundaryTValues(direction, side, blendDistanceInPixels)
                    const borderGridPoints = tilingArea.computeGridPoints(
                        {numSteps: direction === BoundaryDirection.Horizontal ? numStepsU : 2, tMin: tValues.tMinU, tMax: tValues.tMaxU},
                        {numSteps: direction === BoundaryDirection.Horizontal ? 2 : numStepsV, tMin: tValues.tMinV, tMax: tValues.tMaxV},
                    )
                    // shift border grid points so their minimum target pixel lies at the origin
                    const offset = new Vector2(Number.POSITIVE_INFINITY, Number.POSITIVE_INFINITY)
                    borderGridPoints.forEach((row) =>
                        row.forEach((point) => offset.set(Math.min(offset.x, point.targetPixel.x), Math.min(offset.y, point.targetPixel.y))),
                    )
                    borderGridPoints.forEach((row) =>
                        row.forEach((point) => (point.targetPixel = Vector2.fromVector2Like(point.targetPixel).subInPlace(offset))),
                    )
                    // map border
                    return gridMapping({
                        sourceImage,
                        gridPoints: borderGridPoints,
                    })
                }
                const topBorder = extractBorder(input, BoundaryDirection.Horizontal, BoundarySide.Low)
                const bottomBorder = extractBorder(input, BoundaryDirection.Horizontal, BoundarySide.High)
                const leftBorder = extractBorder(input, BoundaryDirection.Vertical, BoundarySide.Low)
                const rightBorder = extractBorder(input, BoundaryDirection.Vertical, BoundarySide.High)

                // Blends `borderImage` into the `side` edge region (blendDistanceInPixels wide) of `resultImage`.
                const blendBorder = (
                    resultImage: OperatorParameterValue<ImagePtr>,
                    borderImage: OperatorParameterValue<ImagePtr>,
                    direction: BoundaryDirection,
                    side: BoundarySide,
                ) => {
                    // Linear alpha gradient across the strip. It is oriented so that the alpha is 0.5 at the outer edge
                    // of the result and 0 on the inner side of the blend region, for both the low and the high side.
                    const gradientImage = colorGradient({
                        descriptor: lambda({sourceImage: borderImage}, async ({context, parameters: {sourceImage}}) => {
                            const descriptor = await context.getImageDescriptor(sourceImage)
                            return {
                                width: descriptor.width,
                                height: descriptor.height,
                                channelLayout: "R",
                                format: TextureEditorSettings.PreviewProcessingImageFormat,
                                isSRGB: false,
                            }
                        }),
                        type: "linear",
                        startPos: new Vector2(0, 0),
                        endPos: lambda({sourceImage: borderImage}, async ({context, parameters: {sourceImage}}) => {
                            const descriptor = await context.getImageDescriptor(sourceImage)
                            return direction === BoundaryDirection.Vertical ? {x: descriptor.width, y: 0} : {x: 0, y: descriptor.height}
                        }),
                        stops: [
                            {t: 0, color: side === BoundarySide.High ? new Color(0) : new Color(0.5)},
                            {t: 1, color: side === BoundarySide.High ? new Color(0.5) : new Color(0)},
                        ],
                    })
                    // the edge region of the result that the border gets blended into
                    const targetRegion = lambda({resultImage}, async ({context, parameters: {resultImage}}) => {
                        const descriptor = await context.getImageDescriptor(resultImage)
                        const isVertical = direction === BoundaryDirection.Vertical
                        if (side === BoundarySide.High) {
                            return isVertical
                                ? {x: descriptor.width - blendDistanceInPixels, y: 0, width: blendDistanceInPixels, height: descriptor.height}
                                : {x: 0, y: descriptor.height - blendDistanceInPixels, width: descriptor.width, height: blendDistanceInPixels}
                        }
                        return isVertical
                            ? {x: 0, y: 0, width: blendDistanceInPixels, height: descriptor.height}
                            : {x: 0, y: 0, width: descriptor.width, height: blendDistanceInPixels}
                    })
                    // cut out target region from result
                    const cutOutResult = copyRegion({
                        sourceImage: resultImage,
                        sourceRegion: targetRegion,
                    })
                    // blend border over the cut-out using the gradient as alpha
                    const blendedCutOutResult = blend({
                        backgroundImage: cutOutResult,
                        foregroundImage: borderImage,
                        alpha: gradientImage,
                        premultipliedAlpha: false,
                        blendMode: "normal",
                    })
                    // copy back to result
                    return copyRegion({
                        sourceImage: blendedCutOutResult,
                        targetOffset: targetRegion,
                        resultImage,
                    })
                }
                // each strip is blended into the edge opposite to the boundary it was extracted from
                resultImage = blendBorder(resultImage, bottomBorder, BoundaryDirection.Horizontal, BoundarySide.Low)
                resultImage = blendBorder(resultImage, topBorder, BoundaryDirection.Horizontal, BoundarySide.High)
                resultImage = blendBorder(resultImage, rightBorder, BoundaryDirection.Vertical, BoundarySide.Low)
                resultImage = blendBorder(resultImage, leftBorder, BoundaryDirection.Vertical, BoundarySide.High)
            }
        }
        if (this.debugImage && this.debugDrawEnabled$.value) {
            const blendDebugImage = false
            if (blendDebugImage) {
                // alpha blend the debug image on top
                const cutOut = copyRegion({
                    sourceImage: resultImage,
                    sourceRegion: lambda({debugImage: this.debugImage.imageRef}, async ({context, parameters: {debugImage}}) => {
                        const descriptor = await context.getImageDescriptor(debugImage)
                        return {x: 0, y: 0, width: descriptor.width, height: descriptor.height}
                    }),
                })
                const blendedCutOut = blend({
                    backgroundImage: cutOut,
                    foregroundImage: this.debugImage.imageRef,
                    premultipliedAlpha: false,
                    blendMode: "normal",
                })
                return {
                    resultImage: copyRegion({
                        sourceImage: blendedCutOut,
                        resultImage,
                    }),
                }
            } else {
                // copy the debug image on top of a copy of the result
                const sourceCopy = copyRegion({
                    sourceImage: resultImage,
                })
                const debugCopy = copyRegion({
                    sourceImage: this.debugImage.imageRef,
                    resultImage: sourceCopy,
                })
                return {resultImage: debugCopy}
            }
        } else {
            return {resultImage}
        }
    }

    /**
     * Replaces all fine-adjustment control points.
     *
     * Existing fine-adjustment points are removed. Then, for each segment between neighbouring boundary control
     * points, evenly spaced positions (about `fineAdjustSpacingInPixels$` apart in source space) are taken on the low
     * side. Their counterpart on the high side is located via cross-correlation. A point pair is inserted only if the
     * correlation reaches `fineAdjustMinCorrelation$`.
     *
     * Settings are read once at the start of the run. A non-positive spacing inserts nothing.
     */
    async executeGridFineAdjustment() {
        const tilingArea = this.canvasToolbox.tilingArea
        const spacing = this.fineAdjustSpacingInPixels$.value
        const minCorrelation = this.fineAdjustMinCorrelation$.value
        const correlationWindowSize = 64
        // the search radius is a fraction of half the spacing between inserted points
        const searchRadius = (spacing / 2) * this.fineAdjustSearchSizeRatio$.value
        const searchSize = Math.ceil(searchRadius * 2)

        const adjustBoundary = async (boundaryDirection: BoundaryDirection) => {
            const controlPointsLow = tilingArea.getBoundaryPoints(boundaryDirection, BoundarySide.Low)
            const controlPointsHigh = tilingArea.getBoundaryPoints(boundaryDirection, BoundarySide.High)
            // both sides are expected to have matching points; only pair up what exists on both
            const numPoints = Math.min(controlPointsLow.length, controlPointsHigh.length)
            for (let i = 1; i < numPoints; i++) {
                const posLowPrev = controlPointsLow[i - 1]
                const posLowNext = controlPointsLow[i]
                const posHighPrev = controlPointsHigh[i - 1]
                const posHighNext = controlPointsHigh[i]
                const tPrev = (posLowPrev.t + posHighPrev.t) / 2
                const tNext = (posLowNext.t + posHighNext.t) / 2
                const sourcePositionLowDelta = posLowNext.sourcePosition.sub(posLowPrev.sourcePosition)
                const sourcePositionHighDelta = posHighNext.sourcePosition.sub(posHighPrev.sourcePosition)
                const sourcePositionAvgLength = (sourcePositionLowDelta.norm() + sourcePositionHighDelta.norm()) / 2
                const numPointsToInsert = Math.floor(sourcePositionAvgLength / spacing)
                for (let j = 1; j <= numPointsToInsert; j++) {
                    const interpolator = j / (numPointsToInsert + 1)
                    const t = tPrev * (1 - interpolator) + tNext * interpolator
                    const posMin = posLowPrev.sourcePosition.add(sourcePositionLowDelta.mul(interpolator))
                    const posMax = posHighPrev.sourcePosition.add(sourcePositionHighDelta.mul(interpolator))
                    // search around the high-side position for the patch around the low-side position
                    const correlationResult = await this.correlatePoints(posMax, posMin, correlationWindowSize, searchSize)
                    if (correlationResult && correlationResult.correlation >= minCorrelation) {
                        tilingArea.insertBoundaryPoints(boundaryDirection, ControlPointType.FineAdjustment, t, posMin, correlationResult.position)
                    }
                }
            }
        }

        this.removeGridFineAdjustment()
        // a non-positive (or NaN) spacing would make numPointsToInsert infinite and never terminate
        if (!(spacing > 0)) {
            return
        }
        await Promise.all([adjustBoundary(BoundaryDirection.Horizontal), adjustBoundary(BoundaryDirection.Vertical)])
    }

    /** Removes all control points that were inserted by {@link executeGridFineAdjustment}. */
    removeGridFineAdjustment() {
        this.canvasToolbox.tilingArea.removeControlPointsOfType(ControlPointType.FineAdjustment)
    }

    private onBlendDistanceInPixelsChanged() {
        this.canvasToolbox.tilingArea.displayBorderWidthInPixels$.next(this.blendDistanceInPixels$.value)
        this.requestEval()
    }

    private onTilingAreaChanged() {
        // invalidate the cached tessellation; it is recomputed on the next evaluation
        this.tessellatedGridPoints = undefined
        this.requestEval()
    }

    /**
     * Cross-correlates a template centred on `referencePosition` against a search region centred on `position`.
     *
     * The template is `correlationWindowSize` pixels square. The search region is
     * `correlationWindowSize + searchSize - 1` pixels square. Both are taken from the currently selected operator input.
     *
     * @returns the image position of the template centre at the correlation peak plus the peak value, or undefined
     *          if there is no selected operator input
     */
    private async correlatePoints(
        position: Vector2Like,
        referencePosition: Vector2Like,
        correlationWindowSize: number,
        searchSize: number,
        debugImage?: DebugImage,
    ) {
        const imageOpContextWebGL2 = this.callback.imageOpContextWebGL2
        // Check the input *before* wrapping it: a freshly constructed ImagePtr is always truthy.
        const sourceImageRef = this.callback.selectedOperatorInput
        if (!sourceImageRef) {
            return undefined
        }
        using sourceImage = new ImagePtr(sourceImageRef)
        const sourceRegion = {
            x: Math.round(position.x - (correlationWindowSize + searchSize) / 2),
            y: Math.round(position.y - (correlationWindowSize + searchSize) / 2),
            width: correlationWindowSize + searchSize - 1,
            height: correlationWindowSize + searchSize - 1,
        }
        const templateRegion = {
            x: Math.round(referencePosition.x - correlationWindowSize / 2),
            y: Math.round(referencePosition.y - correlationWindowSize / 2),
            width: correlationWindowSize,
            height: correlationWindowSize,
        }
        const {peakOffset, peakValue} = await crossCorrelate(imageOpContextWebGL2, sourceImage, sourceRegion, sourceImage, templateRegion, debugImage)
        // convert the region-local peak offset into image coordinates of the template centre
        peakOffset.addInPlace(sourceRegion)
        peakOffset.addInPlace({x: correlationWindowSize / 2, y: correlationWindowSize / 2})
        return {position: peakOffset, correlation: peakValue}
    }

    /**
     * Computes a snapped position for `position` by searching, roughly within ±`snapDistanceInPixels$`, for the
     * patch around `referencePosition`.
     *
     * If debug drawing is enabled, the debug image is (re)initialized (even when snapping is disabled) and an
     * evaluation is requested after correlating.
     *
     * @returns the snapped position, or undefined if snapping is disabled, the rounded snap distance is <= 0, or
     *          there is no selected operator input
     */
    async computeSnapPosition(position: Vector2Like, referencePosition: Vector2Like): Promise<Vector2 | undefined> {
        const debugImage = this.debugDrawEnabled$.value ? this.debugImage : undefined
        await debugImage?.init({width: 6000, height: 8192})
        if (!this.snapEnabled$.value) {
            return undefined
        }
        const snapDistance = Math.round(this.snapDistanceInPixels$.value)
        if (snapDistance <= 0) {
            return undefined
        }
        const correlationWindowSize = 64
        const searchSize = Math.ceil(snapDistance * 2)
        const correlationResult = await this.correlatePoints(position, referencePosition, correlationWindowSize, searchSize, debugImage)
        if (debugImage) {
            this.requestEval()
        }
        return correlationResult?.position
    }

    /**
     * Alternative snapping implementation based on hierarchical cross-correlation.
     *
     * The template/source regions are sized from the number of levels needed to cover the snap distance. When debug
     * drawing is enabled, debug rectangles are drawn and an evaluation is requested after the correlation finishes.
     */
    async computeSnapPosition_(position: Vector2Like, referencePosition: Vector2Like): Promise<Vector2 | undefined> {
        this.canvasToolbox.clearDebugRects()
        const debugImage = this.debugDrawEnabled$.value ? this.debugImage : undefined

        if (!this.snapEnabled$.value) {
            return undefined
        }
        const snapDistance = Math.round(this.snapDistanceInPixels$.value)
        if (snapDistance <= 0) {
            return undefined
        }
        const correlationWindowSize = 8 // must be even
        const searchSize = 3

        const imageOpContextWebGL2 = this.callback.imageOpContextWebGL2
        const sourceImageRef = this.callback.selectedOperatorInput
        if (!sourceImageRef) {
            return undefined
        }

        await debugImage?.init({width: 512, height: 4096})

        const numLevels = Math.ceil(Math.log2(snapDistance))
        const maxTemplateSize = 2 ** (numLevels - 1) * correlationWindowSize
        const maxSourceImageSize = 2 ** (numLevels - 1) * (correlationWindowSize + searchSize - 1)

        // compute regions
        const templateRegion = {
            x: Math.round(referencePosition.x - maxTemplateSize / 2),
            y: Math.round(referencePosition.y - maxTemplateSize / 2),
            width: maxTemplateSize,
            height: maxTemplateSize,
        }
        if (this.debugDrawEnabled$.value) {
            this.canvasToolbox.createDebugRect(templateRegion, "green")
        }
        const sourceRegion = {
            x: Math.round(position.x - maxSourceImageSize / 2),
            y: Math.round(position.y - maxSourceImageSize / 2),
            width: maxSourceImageSize,
            height: maxSourceImageSize,
        }
        if (this.debugDrawEnabled$.value) {
            this.canvasToolbox.createDebugRect(sourceRegion, "blue")
        }

        // compute falloff image, if needed
        // const applyFalloff = false
        // if (applyFalloff) {
        //     if (!this.falloffImage) {
        //         this.falloffImage = await imageOpCreateImage.WebGL2({
        //             context: imageOpContextWebGL2,
        //             parameters: {
        //                 descriptor: {
        //                     width: maxSourceImageSize,
        //                     height: maxSourceImageSize,
        //                     channelLayout: "R",
        //                     format: TextureEditorSettings.PreviewProcessingImageFormat,
        //                     isSRGB: false,
        //                 },
        //                 fillColor: {r: 0, g: 0, b: 0, a: 1},
        //             },
        //         })
        //         const halGenerateFalloffImage = await imageOpContextWebGL2.getOrCreateImageCompositor(`
        //         uniform float u_falloff;
        //
        //         vec4 computeColor(ivec2 targetPixel) {
        //             float normDistance = max(0.0, 1.0 - length(vec2(targetPixel) / vec2(u_targetSize) - 0.5) * 2.0);
        //             float value = pow(normDistance, u_falloff * 2.0);
        //             return vec4(value);
        //         }
        //     `)
        //         halGenerateFalloffImage.setParameter("u_falloff", {type: "float", value: this.snapFalloff$.value})
        //         const halGenerateFalloffImageWebGL2 = await imageOpContextWebGL2.getImage(this.falloffImage)
        //         await halGenerateFalloffImage.paint(halGenerateFalloffImageWebGL2.ref.halImage)
        //         halGenerateFalloffImageWebGL2.release()
        //     }
        //     await debugImage?.addImage(this.falloffImage)
        // }

        // Await so that the debug image is complete before the re-evaluation below (harmless if the result is sync).
        const snappedPosition = await hierarchicalCrossCorrelation(
            imageOpContextWebGL2,
            sourceImageRef,
            sourceRegion,
            sourceImageRef,
            templateRegion,
            correlationWindowSize,
            searchSize,
            this.debugDrawEnabled$.value ? debugImage : undefined,
            this.debugDrawEnabled$.value ? (rect, color) => this.canvasToolbox.createDebugRect(rect, color) : undefined,
        )

        if (this.debugDrawEnabled$.value) {
            this.requestEval()
        }

        return snappedPosition
    }

    // emits once on dispose() to terminate all constructor subscriptions
    private readonly destroyed = new Subject<void>()
    // cached tessellation of the tiling area; reset by onTilingAreaChanged()
    private tessellatedGridPoints?: GridPoint[][]
    // private falloffImage?: ImageRef
    private readonly debugImage: DebugImage
}
