import {ImageRef} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-ref"
import {Vector2, Vector2Like} from "@cm/math"
import {ImageOpContextWebGL2} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op-context-webgl2"
import {rasterizeGeometry} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-rasterize-geometry"
import {normalizedCrossCorrelation} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/composite/normalized-cross-correlation"

/**
 * A segment thickened into a strip: it runs from `from` to `to` and extends `width / 2` to either side of the line,
 * perpendicular to its direction. `width` is measured in the same units as the endpoint coordinates.
 */
export type SegmentType = {from: Vector2Like; to: Vector2Like; width: number}

/** Inputs of {@link findBestMatchAlongSegment}. */
export type ParameterType = {
    /** Image that both the reference strip and the search strip are sampled from. */
    sourceImage: ImageRef
    /** Segment whose surrounding strip serves as the correlation template. */
    referenceSegment: SegmentType
    /** Segment whose surrounding strip is searched for the template. */
    searchSegment: SegmentType
    /** Optional cost subtracted from the correlation score at each candidate position (same coordinate space as the segments). */
    penaltyFn?: (position: Vector2) => number
}

/**
 * Result of {@link findBestMatchAlongSegment}.
 *
 * Note: inside this module this alias shadows TypeScript's global `ReturnType<T>` utility type.
 */
export type ReturnType = {
    /** The reference segment that was passed in, returned unchanged. */
    referenceSegment: SegmentType
    /** The search segment that was passed in, returned unchanged. */
    searchSegment: SegmentType
    /** Correlation scores in row-major order (one row per vertical offset), with any penalty already subtracted. */
    correlationData: Float32Array
    /** Offset (x along the search segment, y across it) of the highest score in `correlationData`. */
    peakIndex: Vector2Like
    /** The highest (penalized) score; stays `-Infinity` if no score compares greater. */
    peakValue: number
    /** Where the reference segment's start point lands under the best match, in the segments' coordinate space. */
    bestMatchPosition: Vector2
}

/** Start point, normalized direction, perpendicular direction (`dir.perp()`), and length of a segment. */
type SegmentFrame = {
    origin: Vector2
    dir: Vector2
    perp: Vector2
    length: number
}

/** Decomposes a segment into the frame used to resample the strip around it. */
const toSegmentFrame = (segment: SegmentType): SegmentFrame => {
    const origin = Vector2.fromVector2Like(segment.from)
    const delta = Vector2.fromVector2Like(segment.to).sub(origin)
    const dir = delta.normalized()
    return {origin, dir, perp: dir.perp(), length: delta.norm()}
}

/**
 * Finds where the strip of `sourceImage` around `referenceSegment` best matches inside the strip around
 * `searchSegment`, using normalized cross-correlation.
 *
 * Each strip is resampled into its own image. Image x runs along the segment, with `round(length)` columns. Image y
 * runs across it, with `round(width)` rows centered on the segment line. The reference image is then correlated
 * against every placement that lies fully inside the search image.
 *
 * @param imageOpContext - Context used to create the command queue that runs the image operations.
 * @param params - `sourceImage` to sample from, the `referenceSegment` (template) and `searchSegment` (search area),
 *   and an optional `penaltyFn`. For each candidate, `penaltyFn` receives the candidate position; its result is
 *   subtracted from that candidate's score.
 * @returns The input segments, the penalized correlation scores, the offset and value of the best score, and
 *   `bestMatchPosition`: where `referenceSegment.from` lands under that match.
 * @throws Error if the reference segment rounds to an empty template, if the search strip is smaller than the
 *   template in either dimension, or if the correlation result does not have the expected size.
 */
export const findBestMatchAlongSegment = async (
    imageOpContext: ImageOpContextWebGL2,
    {sourceImage, referenceSegment, searchSegment, penaltyFn}: ParameterType,
): Promise<ReturnType> => {
    const reference = toSegmentFrame(referenceSegment)
    const search = toSegmentFrame(searchSegment)

    // Strip image sizes: x runs along the segment, y across it.
    const templateWidth = Math.round(reference.length)
    const templateHeight = Math.round(referenceSegment.width)
    const searchImageWidth = Math.round(search.length)
    const searchImageHeight = Math.round(searchSegment.width)

    // Validate before creating the queue or enqueueing any work.
    if (templateWidth <= 0 || templateHeight <= 0) {
        throw new Error("Reference segment must have a non-zero length and width")
    }
    // One score per placement of the template fully inside the search image.
    const resultSize = {
        x: searchImageWidth - templateWidth + 1,
        y: searchImageHeight - templateHeight + 1,
    }
    if (resultSize.x <= 0 || resultSize.y <= 0) {
        throw new Error("Search segment must be at least as long and wide as the template")
    }

    const cmdQueue = imageOpContext.createCommandQueue()

    // Resamples the strip around `frame` from `sourceImage` into a `width` x `height` image.
    // Image x maps to the distance along the segment; image y maps to the perpendicular offset -height/2 .. +height/2.
    const rasterizeStrip = ({origin, dir, perp}: SegmentFrame, width: number, height: number) => {
        const halfHeight = height * 0.5
        const alongToEnd = dir.mul(width)
        return rasterizeGeometry(cmdQueue, {
            geometry: {
                topology: "triangleList",
                vertices: {
                    positions: [new Vector2(0, 0), new Vector2(width, 0), new Vector2(width, height), new Vector2(0, height)],
                    uvs: [
                        origin.add(perp.mul(-halfHeight)),
                        origin.add(alongToEnd.add(perp.mul(-halfHeight))),
                        origin.add(alongToEnd.add(perp.mul(halfHeight))),
                        origin.add(perp.mul(halfHeight)),
                    ],
                },
                indices: [0, 1, 2, 0, 2, 3],
            },
            textureImage: sourceImage,
            resultImageOrDescriptor: {
                ...sourceImage.descriptor,
                width,
                height,
            },
        })
    }

    const templateImage = rasterizeStrip(reference, templateWidth, templateHeight)
    const searchImage = rasterizeStrip(search, searchImageWidth, searchImageHeight)
    const correlation = normalizedCrossCorrelation(cmdQueue, {
        sourceImage: searchImage,
        templateImage,
    })

    const [correlationImage] = await cmdQueue.execute([correlation], {waitForCompletion: true})
    let correlationData: Float32Array
    try {
        if (correlation.descriptor.width !== resultSize.x || correlation.descriptor.height !== resultSize.y) {
            throw new Error("Unexpected correlation image size")
        }
        correlationData = await correlationImage.ref.halImageView.resource.readRawImageData("float32")
    } finally {
        // Release the result image even if the size check or the readback fails.
        correlationImage.release()
    }

    // Maps a correlation offset to the position of `referenceSegment.from` in segment coordinates.
    // The template's start point sits at its vertical center. Therefore y = (searchH - templateH) / 2 places it
    // exactly on the search segment line.
    const verticalCenterOffset = (searchImageHeight - templateHeight) * 0.5
    const computePositionFromIndex = (index: Vector2Like) =>
        search.origin.add(search.dir.mul(index.x).add(search.perp.mul(index.y - verticalCenterOffset)))

    // NOTE(review): assumes the correlation image is single-channel, so the data is resultSize.y rows of
    // resultSize.x floats — confirm against normalizedCrossCorrelation's output format.
    let peakIndex = {x: 0, y: 0}
    let peakValue = Number.NEGATIVE_INFINITY
    for (let y = 0; y < resultSize.y; y++) {
        for (let x = 0; x < resultSize.x; x++) {
            const correlationIndex = y * resultSize.x + x
            const penalty = penaltyFn ? penaltyFn(computePositionFromIndex({x, y})) : 0
            if (penalty) {
                // Applied in place: the returned correlationData holds the penalized scores.
                correlationData[correlationIndex] -= penalty
            }
            const value = correlationData[correlationIndex]
            if (value > peakValue) {
                peakIndex = {x, y}
                peakValue = value
            }
        }
    }

    return {
        referenceSegment,
        searchSegment,
        correlationData,
        peakIndex,
        peakValue,
        bestMatchPosition: computePositionFromIndex(peakIndex),
    }
}
