import {ImagePtr, ImagePtrReassignable} from "app/textures/texture-editor/operator-stack/image-op-system/image-ref"
import {ImageProcessingNodes as Nodes} from "@cm/lib/image-processing/image-processing-nodes"
import {ImageOpType} from "app/textures/texture-editor/operator-stack/image-op-system/detail/types"
import {assertNever} from "@cm/lib/utils/utils"
import {HalPainterImageCompositor} from "@common/models/hal/hal-painter-image-compositor"
import {Size2Like} from "@cm/lib/math/size2"

/** Reduction operators supported by the reduce image operation; each is applied component-wise to a texel's vec4. */
type ReduceOperator = "sum" | "min" | "max" | "mean" | "sum-square" | "mean-square" | "root-mean-square"

/** Parameters of the reduce image operation. */
export type ParameterType = {
    /** Image whose texels are reduced. */
    sourceImage: ImagePtr
    /** How texels are combined; the `*-square` variants square every texel before reducing. */
    operator: ReduceOperator
    /** Number of batches (one result texel each) along x and y; defaults to {width: 1, height: 1}. */
    batchSize?: Size2Like
}

/**
 * Result of the reduce operation: an image holding the reduced value(s), one texel per batch.
 *
 * TODO: this should become a `Color`, but ImgProc currently does not allow for this.
 * NOTE: this alias shadows TypeScript's built-in `ReturnType<F>` utility type inside this module.
 */
export type ReturnType = ImagePtr

/**
 * Reduce image operation: collapses an image into one value per batch using `operator`.
 *
 * The reduction is component-wise on each texel's vec4.
 *
 * WebGL2 path:
 * - The `*-square` operators first square every texel into a float32 image.
 * - The image is then repeatedly down-sampled by 2 (float32 intermediates) until it is no larger than
 *   `batchSize`. Every level combines up to 2x2 texels of the same batch.
 * - The `mean*` operators are accumulated as sums. Afterwards they are divided by the number of source
 *   texels per batch; `root-mean-square` then also takes the square root.
 * - The result has `batchSize` extent. The exception is an input that already was no larger than
 *   `batchSize`, which is not down-sampled at all.
 *
 * ImgProc path: delegates to `Nodes.reduce` and always produces a 1x1 image.
 * NOTE(review): ImgProc ignores `batchSize`, so both backends only agree for the default 1x1 batch.
 */
export const imageOpReduce: ImageOpType<ParameterType, ReturnType> = {
    name: "Reduce",

    WebGL2: async ({context, parameters: {sourceImage, operator, batchSize}}) => {
        batchSize ??= {width: 1, height: 1}
        // A non-positive batch size can never be reached by the halving loop below (it would spin forever),
        // and the accumulation shader receives the batch size as an integer vector.
        if (!Number.isInteger(batchSize.width) || batchSize.width < 1 || !Number.isInteger(batchSize.height) || batchSize.height < 1) {
            throw new Error(`Reduce: batchSize must consist of positive integers (got ${batchSize.width}x${batchSize.height})`)
        }

        // preprocess: square every texel for the *-square operators
        let halPreprocess: HalPainterImageCompositor | undefined
        switch (operator) {
            case "sum-square":
            case "mean-square":
            case "root-mean-square":
                halPreprocess = await context.getOrCreateImageCompositor(`
                    vec4 computeColor(ivec2 targetPixel) {
                        vec4 c = texelFetch0(targetPixel);
                        return c * c;
                    }
                `)
                break
            default:
                break
        }

        // Accumulation functions used while successively down-sampling towards batchSize.
        // The mean operators are accumulated as plain sums and normalized in the postprocess step:
        // averaging pairwise at every level would weight texels unevenly whenever a level has an odd
        // extent, because a lone border texel is passed through unaveraged (a 3x1 image would yield
        // 0.25*c0 + 0.25*c1 + 0.5*c2 instead of the true mean).
        let accuFn2: string
        let accuFn4: string
        switch (operator) {
            case "sum":
            case "sum-square":
            case "mean":
            case "mean-square":
            case "root-mean-square":
                accuFn2 = "a + b"
                accuFn4 = "a + b + c + d"
                break
            case "min":
                accuFn2 = "min(a, b)"
                accuFn4 = "min(min(min(a, b), c), d)"
                break
            case "max":
                accuFn2 = "max(a, b)"
                accuFn4 = "max(max(max(a, b), c), d)"
                break
            default:
                assertNever(operator)
        }

        // postprocess: turn per-batch sums into means (and take the root for root-mean-square)
        let halPostprocess: HalPainterImageCompositor | undefined
        switch (operator) {
            case "mean":
            case "mean-square":
                halPostprocess = await context.getOrCreateImageCompositor(`
                    uniform uint u_numElements;

                    vec4 computeColor(ivec2 targetPixel) {
                        return texelFetch0(targetPixel) / float(u_numElements);
                    }
                `)
                break
            case "root-mean-square":
                halPostprocess = await context.getOrCreateImageCompositor(`
                    uniform uint u_numElements;

                    vec4 computeColor(ivec2 targetPixel) {
                        return sqrt(texelFetch0(targetPixel) / float(u_numElements));
                    }
                `)
                break
            default:
                break
        }

        // Each target texel combines the 2x2 source block at targetPixel * 2. Neighbours that fall outside
        // the image or into another batch (patch = imageSize / batchSize, recomputed per level) are skipped.
        // NOTE(review): because the patch size is recomputed from each level's image size, texels appear to
        // be dropped or mixed across batches once a per-batch patch extent becomes odd. For example, with
        // width 12 and batch width 2, the texel at x=3 of the 6-wide level is never read. This looks correct
        // for batchSize 1x1, and for batch patches whose extent is a power of two. Confirm / redesign.
        const halAccu = await context.getOrCreateImageCompositor(`
            uniform ivec2 u_batchSize;
        
            vec4 reduce2(vec4 a, vec4 b) {
                return ${accuFn2};
            }

            vec4 reduce4(vec4 a, vec4 b, vec4 c, vec4 d) {
                return ${accuFn4};
            }
            
            bool isSameBatch(ivec2 sourcePixel, ivec2 offset) {
                ivec2 texelIndex = sourcePixel + offset;
                ivec2 imageSize = ivec2(u_imageSize[0]);
                ivec2 patchSize = imageSize / u_batchSize; 
                return texelIndex.x < imageSize.x && texelIndex.y < imageSize.y
                    && texelIndex.x / patchSize.x == sourcePixel.x / patchSize.x
                    && texelIndex.y / patchSize.y == sourcePixel.y / patchSize.y;
            }

            vec4 computeColor(ivec2 targetPixel) {
                ivec2 sourcePixel = targetPixel * 2;
                bool borderX = !isSameBatch(sourcePixel, ivec2(1, 0));
                bool borderY = !isSameBatch(sourcePixel, ivec2(0, 1));
                vec4 c00 = texelFetch0(sourcePixel + ivec2(0, 0), ADDRESS_MODE_CLAMP_TO_EDGE);
                if (borderX && borderY) {
                    return c00;
                }
                if (borderX) {
                    vec4 c01 = texelFetch0(sourcePixel + ivec2(0, 1), ADDRESS_MODE_CLAMP_TO_EDGE);
                    return reduce2(c00, c01);
                }
                else if (borderY) {
                    vec4 c10 = texelFetch0(sourcePixel + ivec2(1, 0), ADDRESS_MODE_CLAMP_TO_EDGE);
                    return reduce2(c00, c10);
                }
                else {
                    vec4 c01 = texelFetch0(sourcePixel + ivec2(0, 1), ADDRESS_MODE_CLAMP_TO_EDGE);
                    vec4 c10 = texelFetch0(sourcePixel + ivec2(1, 0), ADDRESS_MODE_CLAMP_TO_EDGE);
                    vec4 c11 = texelFetch0(sourcePixel + ivec2(1, 1), ADDRESS_MODE_CLAMP_TO_EDGE);
                    return reduce4(c00, c10, c01, c11);
                }
            }
        `)
        halAccu.setParameter("u_batchSize", {type: "int2", value: {x: batchSize.width, y: batchSize.height}})
        const sourceImagePtr = new ImagePtrReassignable(sourceImage)
        using originalSourceImageWebGL2 = await context.getImage(sourceImagePtr)
        const originalDescriptor = originalSourceImageWebGL2.ref.descriptor
        // Source texels summed into each batch texel, matching the shader's patch size floor(size / batchSize).
        // The count is clamped to 1 so an image smaller than the batch grid (each texel is its own batch)
        // does not divide by zero.
        const numElementsPerBatch =
            Math.max(1, Math.floor(originalDescriptor.width / batchSize.width)) *
            Math.max(1, Math.floor(originalDescriptor.height / batchSize.height))
        if (halPreprocess) {
            // Store squares as float32 so low-precision source formats neither quantize nor clip them.
            using preprocessedImage = await context.createImage({...originalDescriptor, format: "float32"})
            using preprocessedImageWebGL2 = await context.getImage(preprocessedImage)
            await halPreprocess.paint(preprocessedImageWebGL2.ref.halImage, originalSourceImageWebGL2.ref.halImage)
            sourceImagePtr.set(preprocessedImage)
        }
        for (;;) {
            using sourceImageWebGL2 = await context.getImage(sourceImagePtr)
            const {width, height, channelLayout, isSRGB} = sourceImageWebGL2.ref.descriptor
            if (width <= batchSize.width && height <= batchSize.height) {
                break
            }
            using nextSourceImage = await context.createImage({
                width: Math.max(batchSize.width, Math.ceil(width / 2)),
                height: Math.max(batchSize.height, Math.ceil(height / 2)),
                channelLayout,
                format: "float32", // force float32 to reduce precision issues (sums may also exceed the source format's range)
                isSRGB,
            })
            using nextSourceImageWebGL2 = await context.getImage(nextSourceImage)
            await halAccu.paint(nextSourceImageWebGL2.ref.halImage, sourceImageWebGL2.ref.halImage)
            sourceImagePtr.set(nextSourceImage)
        }
        if (halPostprocess) {
            using sourceImageWebGL2 = await context.getImage(sourceImagePtr)
            const {width, height, channelLayout, format, isSRGB} = sourceImageWebGL2.ref.descriptor
            // The postprocess is a per-texel map, so the result has exactly the size of the reduced image.
            // That size can be smaller than batchSize when the input already was.
            using resultImage = await context.createImage({width, height, channelLayout, format, isSRGB})
            using resultImageWebGL2 = await context.getImage(resultImage)
            halPostprocess.setParameter("u_numElements", {type: "uint", value: numElementsPerBatch}, true)
            await halPostprocess.paint(resultImageWebGL2.ref.halImage, sourceImageWebGL2.ref.halImage)
            sourceImagePtr.set(resultImage)
        }
        return sourceImagePtr
    },

    ImgProc: async ({context, parameters: {sourceImage, operator}}) => {
        // NOTE(review): batchSize is not forwarded; this backend always reduces the whole image to 1x1.
        using sourceImageImgProc = await context.getImage(sourceImage)
        const resultNode = Nodes.reduce(operator, sourceImageImgProc.ref.node)
        return await context.createImage(
            {
                width: 1,
                height: 1,
                channelLayout: sourceImageImgProc.ref.descriptor.channelLayout,
                format: sourceImageImgProc.ref.descriptor.format,
                isSRGB: sourceImageImgProc.ref.descriptor.isSRGB,
            },
            resultNode,
        )
    },
}
