import {DataType, ImageRef} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-ref"
import {extractPatches} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-extract-patches"
import {copyRegion, ExtRegion} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-copy-region"
import {math} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-math"
import {mean} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/composite/mean"
import {reduce} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-reduce"
import {ImageOpCommandQueue} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op-command-queue"
import {toGrayscale} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-to-grayscale"
import {DebugImage} from "@app/textures/texture-editor/operator-stack/image-op-system/util/debug-image"
import {reshape} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/utils/reshape"

/** Label of the command-queue scope opened/closed by `normalizedCrossCorrelation`. */
const SCOPE_NAME = "NormalizedCrossCorrelation"

/** Parameters of `normalizedCrossCorrelation`. */
export type ParameterType = {
    /** Image that is searched for the template. */
    sourceImage: ImageRef
    /** Optional per-pixel weights of `sourceImage`; patches of it are extracted alongside the source patches. */
    sourceWeightImage?: ImageRef
    /**
     * Region of `sourceImage` over which template-sized patches are extracted.
     * Default: `{x: 0, y: 0, width: sourceImage.width, height: sourceImage.height}`.
     * `offsetImage` can be a 1x1 image with at least 2 channels; the offset is added to `offsetImage`.
     */
    sourceRegion?: ExtRegion
    /** Template that is correlated against every patch of the source region. */
    templateImage: ImageRef
    /** Optional per-pixel weights of `templateImage`; multiplied with the source patch weights when both exist. */
    templateWeightImage?: ImageRef
    /**
     * Region of `templateImage` used as the template. When given, the template (and its weights) are cropped to it
     * with `addressMode: "border"`. Must not be larger than the source region in either dimension.
     * Default: `{x: 0, y: 0, width: templateImage.width, height: templateImage.height}`.
     * `offsetImage` can be a 1x1 image with at least 2 channels; the offset is added to `offsetImage`.
     */
    templateRegion?: ExtRegion
    options?: {
        /**
         * Default: `false`. If `true`, the images are assumed to be premultiplied by their weights, and they are
         * divided by their weights (safe division) before the correlation is computed.
         */
        premultipliedImages?: boolean
    }
    /** Optional sink that receives intermediate images (template, source patches, correlation). */
    debugImage?: DebugImage
}

/**
 * Correlation map produced by `normalizedCrossCorrelation`: one NCC score per placement of the template inside
 * the source region, sized `sourceRegion - templateRegion + 1` in each dimension.
 *
 * NOTE: this module-local alias shadows TypeScript's global `ReturnType<T>` utility type within this file.
 */
export type ReturnType = ImageRef

/** Returns `data` minus its (optionally `weights`-weighted) mean, computed in `dataType`. */
const makeMeanFree = (cmdQueue: ImageOpCommandQueue, dataType: DataType, data: ImageRef, weights?: ImageRef) => {
    const dataMean = mean(cmdQueue, {
        sourceImage: data,
        sourceWeightImage: weights,
        resultDataType: dataType,
    })
    return math(cmdQueue, {
        operator: "-",
        operandA: data,
        operandB: dataMean.mean,
        resultImageOrDataType: dataType,
    })
}

/** Multiplies `data` by `weights` (when given) and sum-reduces the result into `dataType`. */
const factorInWeightsAndReduce = (cmdQueue: ImageOpCommandQueue, dataType: DataType, data: ImageRef, weights?: ImageRef) => {
    const weightedData = weights
        ? math(cmdQueue, {
              operator: "*",
              operandA: data,
              operandB: weights,
          })
        : data
    return reduce(cmdQueue, {
        sourceImage: weightedData,
        operator: "sum",
        resultDataType: dataType,
    })
}

/** Combines the optional source-patch and template weights by multiplication; `undefined` if neither exists. */
const combineWeights = (cmdQueue: ImageOpCommandQueue, sourceWeightPatches?: ImageRef, templateWeightImage?: ImageRef) => {
    if (sourceWeightPatches && templateWeightImage) {
        return math(cmdQueue, {
            operator: "*",
            operandA: sourceWeightPatches,
            operandB: templateWeightImage,
        })
    }
    return sourceWeightPatches ?? templateWeightImage
}

/**
 * Computes the (optionally weighted) normalized cross-correlation between a template and every template-sized
 * patch of the source region.
 *
 * With `s`/`t` the mean-free source patch/template and `w` the product of the available weights, each output pixel is
 * `sum(w * s * t) / sqrt(sum(w * s^2) * sum(w * t^2))`; the final division uses the `/safe` operator.
 * Intermediates use `float32` when the source is `float32`, otherwise `float16`.
 * If the source is not single-channel (`"R"`), the per-channel scores are combined via luminance-weighted grayscale.
 *
 * @param cmdQueue - queue the image ops are recorded into (inside the `NormalizedCrossCorrelation` scope)
 * @returns correlation map of size `sourceRegion - templateRegion + 1` with a 1x1 batch size
 * @throws Error if the template region is larger than the source region in either dimension
 */
export const normalizedCrossCorrelation = (
    cmdQueue: ImageOpCommandQueue,
    {sourceImage, sourceWeightImage, sourceRegion, templateImage, templateWeightImage, templateRegion, options, debugImage}: ParameterType,
): ReturnType => {
    // Resolve regions and validate sizes before opening the scope: these steps only read descriptors, and
    // throwing here guarantees an invalid call never leaves the command-queue scope unbalanced.
    const effectiveSourceRegion: ExtRegion = sourceRegion ?? {
        x: 0,
        y: 0,
        width: sourceImage.descriptor.width,
        height: sourceImage.descriptor.height,
    }
    const effectiveTemplateRegion: ExtRegion = templateRegion ?? {
        x: 0,
        y: 0,
        width: templateImage.descriptor.width,
        height: templateImage.descriptor.height,
    }
    const resultSize = {
        width: effectiveSourceRegion.width - effectiveTemplateRegion.width + 1,
        height: effectiveSourceRegion.height - effectiveTemplateRegion.height + 1,
    }
    if (resultSize.width <= 0 || resultSize.height <= 0) {
        throw new Error(
            `Template region (${effectiveTemplateRegion.width}x${effectiveTemplateRegion.height}) must not be larger than source region (${effectiveSourceRegion.width}x${effectiveSourceRegion.height})`,
        )
    }

    cmdQueue.beginScope(SCOPE_NAME)

    const premultipliedImages = options?.premultipliedImages ?? false
    const intermediateDataType: DataType = sourceImage.descriptor.dataType === "float32" ? "float32" : "float16"

    // crop the template (and its weights) only when an explicit region was requested
    if (templateRegion) {
        templateImage = copyRegion(cmdQueue, {
            sourceImage: templateImage,
            sourceRegion: templateRegion,
            addressMode: "border",
        })
        if (templateWeightImage) {
            templateWeightImage = copyRegion(cmdQueue, {
                sourceImage: templateWeightImage,
                sourceRegion: templateRegion,
                addressMode: "border",
            })
        }
    }

    // TODO can we avoid this conversion ?
    if (premultipliedImages) {
        // remove pre-multiplication
        if (sourceWeightImage) {
            sourceImage = math(cmdQueue, {
                operator: "/safe",
                operandA: sourceImage,
                operandB: sourceWeightImage,
            })
        }
        if (templateWeightImage) {
            templateImage = math(cmdQueue, {
                operator: "/safe",
                operandA: templateImage,
                operandB: templateWeightImage,
            })
        }
    }

    debugImage?.addImage(templateImage)

    // extract the patches of the source image (and of its weights, if any)
    let sourcePatches = extractPatches(cmdQueue, {
        sourceImage,
        sourceRegion: effectiveSourceRegion,
        patchSize: effectiveTemplateRegion,
        addressMode: "border",
    })
    const sourceWeightPatches = sourceWeightImage
        ? extractPatches(cmdQueue, {
              sourceImage: sourceWeightImage,
              sourceRegion: effectiveSourceRegion,
              patchSize: effectiveTemplateRegion,
              addressMode: "border",
          })
        : undefined

    debugImage?.addImage(sourcePatches)

    const weightProduct = combineWeights(cmdQueue, sourceWeightPatches, templateWeightImage)

    // remove the (weighted) mean from template and source patches
    templateImage = makeMeanFree(cmdQueue, intermediateDataType, templateImage, weightProduct)
    sourcePatches = makeMeanFree(cmdQueue, intermediateDataType, sourcePatches, weightProduct) // TODO we're computing the weightProductSum a 2nd time here

    // numerator: sum(w * s * t)
    const numerator = factorInWeightsAndReduce(
        cmdQueue,
        intermediateDataType,
        math(cmdQueue, {
            operator: "*",
            operandA: sourcePatches,
            operandB: templateImage,
        }),
        weightProduct,
    )

    // denominator: sqrt(sum(w * s^2) * sum(w * t^2))
    const denominatorSourcePatchTerm = factorInWeightsAndReduce(
        cmdQueue,
        intermediateDataType,
        math(cmdQueue, {
            operator: "square",
            operand: sourcePatches,
        }),
        weightProduct,
    )
    const denominatorTemplateTerm = factorInWeightsAndReduce(
        cmdQueue,
        intermediateDataType,
        math(cmdQueue, {
            operator: "square",
            operand: templateImage,
        }),
        weightProduct,
    )
    const denominator = math(cmdQueue, {
        operator: "sqrt",
        operand: math(cmdQueue, {
            operator: "*",
            operandA: denominatorSourcePatchTerm,
            operandB: denominatorTemplateTerm,
        }),
    })

    // final NCC with safe division (guards zero-variance patches/templates)
    let crossCorrelation = math(cmdQueue, {
        operator: "/safe",
        operandA: numerator,
        operandB: denominator,
    })

    debugImage?.addImage(crossCorrelation)

    // multi-channel source: combine the per-channel scores, weighted by perceived brightness (luminance)
    if (sourceImage.descriptor.channelLayout !== "R") {
        crossCorrelation = toGrayscale(cmdQueue, {sourceImage: crossCorrelation, mode: "luminance"})
        debugImage?.addImage(crossCorrelation)
    }

    // reshape the result to remove the batch dimension
    crossCorrelation = reshape(crossCorrelation, {
        ...resultSize,
        batchSize: {width: 1, height: 1},
    })

    cmdQueue.endScope(SCOPE_NAME)
    return crossCorrelation
}
