import {ChannelLayout, DataType, ImageRef} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-ref"
import {ImageOpType, runImageOp} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op"
import {assertNever} from "ts-lib/dist/browser/utils/utils"
import {HalPainterParameterValueType} from "@common/models/hal/hal-painter/types"
import {ImageOpCommandQueue} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op-command-queue"
import {PainterRef} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/painter-ref"
import {convert} from "@app/textures/texture-editor/operator-stack/image-op-system/image-ops/primitive/image-op-convert"

/**
 * Parameters of the reduce image operation.
 *
 * The WebGL2 implementation reduces every batch patch of `sourceImage` down to a single pixel, so its result
 * image has the size of `sourceImage.descriptor.batchSize`. The ImgProc implementation produces a 1x1 image.
 */
export type ParameterType = {
    // image to reduce; for WebGL2 its width/height must be divisible by its batch size
    sourceImage: ImageRef
    // "sum-square" / "mean-square" / "root-mean-square" square every value before reducing,
    // "root-mean-square" additionally takes the square root of the reduced value
    operator: "sum" | "min" | "max" | "mean" | "sum-square" | "mean-square" | "root-mean-square" | "custom"
    resultDataType?: DataType // default: sourceImage.descriptor.dataType
    // only required for operator: "custom"; NOT SUPPORTED IN ImgProc !!!
    custom?: {
        reduce: ShaderDef // 2-tuple reduction function with signature: vec4 reduce(vec4 a, vec4 b)
        preProcess?: ShaderDef // pre-process function with signature: vec4 preProcess(vec4 value, ivec2 index)
        postProcess?: ShaderDef // post-process function with signature: vec4 postProcess(vec4 value, ivec2 index)
    }
}

// Result of the reduce operation (the reduced image).
// Note: inside this module this alias shadows TypeScript's built-in `ReturnType<T>` utility type.
export type ReturnType = ImageRef

/**
 * A GLSL snippet used by the WebGL2 implementation of the reduce operation (reduce / preProcess / postProcess).
 * The snippet is wrapped into a full painter that defines `computeColor` and is created via `cmdQueue.createPainter`.
 */
export type ShaderDef = {
    // painter name passed to cmdQueue.createPainter (overridden with "reduce" for the reduction function)
    name: string
    // GLSL source that defines the function expected by the field referencing this ShaderDef
    code: string
    // extra painter parameters passed to cmdQueue.paint together with the processed image
    parameters?: {[key: string]: HalPainterParameterValueType}
    // additional textures keyed by texture slot; slot 0 is always the image being processed, so keys must be > 0.
    // NOTE(review): gaps between keys leave undefined entries in the texture list - confirm paint tolerates that
    textures?: {[key: number]: ImageRef | undefined}
    resultChannelLayout?: ChannelLayout // default: sourceImage.descriptor.channelLayout
}

// A ShaderDef after its painter has been created and its textures/layout have been resolved.
type CompiledShaderDef = {
    // painter returned by cmdQueue.createPainter
    painter: PainterRef
    // painter parameters copied from the ShaderDef
    parameters?: {[key: string]: HalPainterParameterValueType}
    // textures bound after the processed image, i.e. ShaderDef.textures[1..] in slot order
    additionalTextures: ImageRef[]
    // channel layout of the image this painter writes
    resultChannelLayout: ChannelLayout
}

/**
 * Reduce image operation: collapses every batch patch of `sourceImage` into a single pixel.
 *
 * WebGL2: the image is optionally pre-processed, then successively down-sampled by a factor of two per axis.
 * Each output pixel reduces up to 2x2 source pixels that belong to the same patch. This repeats until every
 * patch is 1x1, after which the image is optionally post-processed and finally converted to `resultDataType`.
 *
 * ImgProc: delegates to the backend's native "reduce" node; custom operators and batching are not supported.
 */
export const imageOpReduce: ImageOpType<ParameterType, ReturnType> = {
    name: "Reduce",

    WebGL2: ({cmdQueue, parameters: {sourceImage, operator, resultDataType, custom}}) => {
        if (operator === "custom" && !custom) {
            throw new Error("Custom reduce operator requires custom parameters")
        }
        resultDataType ??= sourceImage.descriptor.dataType
        // all passes write float images: float32 only when float32 is the requested result type, float16 otherwise
        const intermediateDataType = resultDataType === "float32" ? "float32" : "float16"
        const sourceImageChannelLayout = sourceImage.descriptor.channelLayout

        // creates the painter for a shader definition and resolves its extra textures and result layout
        const createShader = (shaderDef: ShaderDef): CompiledShaderDef => {
            const shader = cmdQueue.createPainter("compositor", shaderDef.name, shaderDef.code)
            const additionalTextures: ImageRef[] = []
            if (shaderDef.textures) {
                for (const [index, imageRef] of Object.entries(shaderDef.textures)) {
                    if (imageRef) {
                        const indexInt = parseInt(index)
                        // slot 0 is reserved for the image being processed; NaN would otherwise crash below
                        if (Number.isNaN(indexInt) || indexInt <= 0) {
                            throw new Error("Texture index must be greater than 0")
                        }
                        additionalTextures.length = Math.max(additionalTextures.length, indexInt + 1)
                        additionalTextures[indexInt] = imageRef
                    }
                }
                if (additionalTextures.length > 0) {
                    // remove first entry as it is not "additional"
                    additionalTextures.shift()
                }
            }
            return {
                painter: shader,
                parameters: shaderDef.parameters,
                additionalTextures,
                resultChannelLayout: shaderDef.resultChannelLayout ?? sourceImageChannelLayout,
            }
        }

        // preprocess
        let preProcessFn: ShaderDef | undefined = undefined
        switch (operator) {
            case "sum-square":
            case "mean-square":
            case "root-mean-square":
                preProcessFn = {
                    name: "preProcess(square)",
                    code: `
                        vec4 preProcess(vec4 value, ivec2 index) {
                            return value * value;
                        }`,
                }
                break
            case "custom":
                preProcessFn = custom!.preProcess
                break
            default:
                break
        }
        const preprocess = preProcessFn
            ? createShader({
                  ...preProcessFn,
                  code: `
                    ${preProcessFn.code}
                    
                    vec4 computeColor(ivec2 targetPixel) {
                        vec4 value = texelFetch0(targetPixel);
                        return preProcess(value, targetPixel);
                    }
                `,
              })
            : undefined

        // postprocess
        let postProcessFn: ShaderDef | undefined = undefined
        switch (operator) {
            case "root-mean-square":
                postProcessFn = {
                    name: "postProcess(sqrt)",
                    code: `
                    vec4 postProcess(vec4 value, ivec2 index) {
                        return sqrt(value);
                    }`,
                }
                break
            case "custom":
                postProcessFn = custom!.postProcess
                break
            default:
                break
        }
        const postprocess = postProcessFn
            ? createShader({
                  ...postProcessFn,
                  code: `
                    ${postProcessFn.code}
                    
                    vec4 computeColor(ivec2 targetPixel) {
                        vec4 value = texelFetch0(targetPixel);
                        return postProcess(value, targetPixel);
                    }
                `,
              })
            : undefined

        // we successively down-sample to 1x1
        let reduceFn: ShaderDef
        const wrapInReduceFn = (accuFn: string) => ({
            name: `reduce(${accuFn})`,
            code: `
            vec4 reduce(vec4 a, vec4 b) {
                return ${accuFn};
            }`,
        })
        switch (operator) {
            case "sum":
            case "sum-square":
                reduceFn = wrapInReduceFn("a + b")
                break
            case "mean":
            case "mean-square":
            case "root-mean-square":
                // NOTE(review): pairwise averaging weights all pixels equally only if patch width and height are
                // powers of two; for odd patch sizes the unpaired pixel of a level gets a larger weight.
                reduceFn = wrapInReduceFn("(a + b) * 0.5")
                break
            case "min":
                reduceFn = wrapInReduceFn("min(a, b)")
                break
            case "max":
                reduceFn = wrapInReduceFn("max(a, b)")
                break
            case "custom":
                reduceFn = custom!.reduce
                break
            default:
                assertNever(operator)
        }

        // One down-sampling step. A source patch of size P maps to an output patch of size ceil(P / 2).
        // For output pixel t in patch p, the first source pixel is p * P + 2 * (t - p * ceil(P / 2)):
        //  - P even: this is simply 2 * t
        //  - P odd:  this is 2 * t - p, with p = t / ((P + 1) / 2) = (2 * t) / (P + 1)
        const reduce = createShader({
            ...reduceFn,
            name: "reduce",
            code: `
            uniform ivec2 u_patchSize;
        
            ${reduceFn.code}
            
            vec4 computeColor(ivec2 targetPixel) {
                ivec2 sourcePixel = targetPixel * 2;
                // if the patch size is odd, every preceding patch is one source pixel short of 2 * ceil(P / 2),
                // so subtract the patch index (2 * t) / (P + 1)
                if ((u_patchSize.x & 1) != 0) {
                    sourcePixel.x -= sourcePixel.x / (u_patchSize.x + 1);
                }
                if ((u_patchSize.y & 1) != 0) {
                    sourcePixel.y -= sourcePixel.y / (u_patchSize.y + 1);
                }
                // check if neighbors are still in the same patch
                bool samePatchX = (sourcePixel.x + 1) % u_patchSize.x != 0;
                bool samePatchY = (sourcePixel.y + 1) % u_patchSize.y != 0;
                vec4 c00 = texelFetch0(sourcePixel + ivec2(0, 0), ADDRESS_MODE_BORDER);
                if (samePatchX && samePatchY) {
                    vec4 c01 = texelFetch0(sourcePixel + ivec2(0, 1), ADDRESS_MODE_BORDER);
                    vec4 c10 = texelFetch0(sourcePixel + ivec2(1, 0), ADDRESS_MODE_BORDER);
                    vec4 c11 = texelFetch0(sourcePixel + ivec2(1, 1), ADDRESS_MODE_BORDER);
                    vec4 c0 = reduce(c00, c01);
                    vec4 c1 = reduce(c10, c11);
                    return reduce(c0, c1);
                }
                else if (samePatchX) {
                    vec4 c10 = texelFetch0(sourcePixel + ivec2(1, 0), ADDRESS_MODE_BORDER);
                    return reduce(c00, c10);
                }
                else if (samePatchY) {
                    vec4 c01 = texelFetch0(sourcePixel + ivec2(0, 1), ADDRESS_MODE_BORDER);
                    return reduce(c00, c01);
                }
                else {
                    return c00;
                }
            }
        `,
        })
        // base descriptor for all intermediate images; width/height/dataType are always overridden below
        const resultImageDescriptor = {
            ...sourceImage.descriptor,
            width: undefined as number | undefined, // will be set later
            height: undefined as number | undefined, // will be set later
            dataType: undefined as DataType | undefined, // will be set later
        }
        if (preprocess) {
            const preprocessedImage = cmdQueue.createImage({
                ...resultImageDescriptor,
                width: sourceImage.descriptor.width,
                height: sourceImage.descriptor.height,
                channelLayout: preprocess.resultChannelLayout,
                dataType: intermediateDataType,
            })
            cmdQueue.paint(preprocess.painter, {
                parameters: preprocess.parameters,
                sourceImages: [sourceImage, ...preprocess.additionalTextures],
                resultImage: preprocessedImage,
            })
            sourceImage = preprocessedImage
        }
        for (;;) {
            const patchSize = {
                width: sourceImage.descriptor.width / sourceImage.descriptor.batchSize.width,
                height: sourceImage.descriptor.height / sourceImage.descriptor.batchSize.height,
            }
            if (!Number.isInteger(patchSize.width) || !Number.isInteger(patchSize.height)) {
                throw new Error("Batch size must be a divisor of the image size")
            }
            // a zero-sized patch would never shrink to 1x1 and loop forever
            if (patchSize.width < 1 || patchSize.height < 1) {
                throw new Error("Image to reduce must not be empty")
            }
            if (patchSize.width === 1 && patchSize.height === 1) {
                break
            }
            const downSampledSize = {
                width: Math.ceil(patchSize.width / 2) * sourceImage.descriptor.batchSize.width,
                height: Math.ceil(patchSize.height / 2) * sourceImage.descriptor.batchSize.height,
            }
            const nextSourceImage = cmdQueue.createImage({
                ...resultImageDescriptor,
                ...downSampledSize,
                channelLayout: reduce.resultChannelLayout,
                dataType: intermediateDataType,
            })
            cmdQueue.paint(reduce.painter, {
                parameters: {
                    ...reduce.parameters,
                    u_patchSize: {type: "int2", value: {x: patchSize.width, y: patchSize.height}},
                },
                sourceImages: [sourceImage, ...reduce.additionalTextures],
                resultImage: nextSourceImage,
            })
            sourceImage = nextSourceImage
        }
        if (postprocess) {
            // after the loop every patch is 1x1, i.e. the image has the size of the batch grid
            const resultImage = cmdQueue.createImage({
                ...resultImageDescriptor,
                ...sourceImage.descriptor.batchSize,
                channelLayout: postprocess.resultChannelLayout,
                dataType: intermediateDataType,
            })
            cmdQueue.paint(postprocess.painter, {
                parameters: postprocess.parameters,
                sourceImages: [sourceImage, ...postprocess.additionalTextures],
                resultImage,
            })
            sourceImage = resultImage
        }
        // convert back to original data type
        sourceImage = convert(cmdQueue, {sourceImage, dataType: resultDataType})
        return sourceImage
    },

    ImgProc: ({cmdQueue, parameters: {sourceImage, operator, resultDataType}}) => {
        if (operator === "custom") {
            throw new Error("Custom reduce operator is not supported in ImgProc")
        }
        // any batching (in either dimension) is unsupported
        if (sourceImage.descriptor.batchSize.width !== 1 || sourceImage.descriptor.batchSize.height !== 1) {
            throw new Error("Batch size is not supported in ImgProc")
        }
        return cmdQueue.createImage(
            {
                width: 1,
                height: 1,
                channelLayout: sourceImage.descriptor.channelLayout,
                dataType: resultDataType ?? sourceImage.descriptor.dataType,
                options: sourceImage.descriptor.options,
                batchSize: {width: 1, height: 1},
            },
            {type: "reduce", operation: operator, input: sourceImage},
        )
    },
}

/**
 * Convenience entry point that runs the reduce image operation (`imageOpReduce`) through `runImageOp`.
 *
 * @param cmdQueue - command queue handed to `runImageOp`
 * @param parameters - reduce parameters, see `ParameterType`
 * @returns the value produced by `runImageOp` for this operation
 */
export function reduce(cmdQueue: ImageOpCommandQueue, parameters: ParameterType) {
    const operation = imageOpReduce
    const result = runImageOp(cmdQueue, operation, parameters)
    return result
}
