import {ImageProcessingNodes as Nodes} from "@cm/lib/image-processing/image-processing-nodes"
import {ImageOpType} from "app/textures/texture-editor/operator-stack/image-op-system/detail/types"
import {ImagePtr, isImagePtr} from "app/textures/texture-editor/operator-stack/image-op-system/image-ref"
import {toImgProcResultImage} from "app/textures/texture-editor/operator-stack/image-op-system/detail/utils-img-proc"
import {assertNever} from "@cm/lib/utils/utils"
import {HalPainterImageCompositor} from "@common/models/hal/hal-painter-image-compositor"
import {Size2Like} from "@cm/lib/math/size2"

/** An operand of the Math op: an image, or a plain number (the WebGL2 backend inlines it as a constant `vec4`). */
type OperandType = number | ImagePtr

/** Operators that take a single operand. */
type UnaryOperator = "sqrt" | "square" | "cos" | "sin" | "clip01" | "abs"

/** Operators that take two operands. */
type BinaryOperator = "+" | "-" | "*" | "/" | ">" | "<" | ">=" | "<=" | "==" | "max" | "min" | "atan2" | "pow" | "mod"

/** Parameters of the single-operand form. */
type Unary = {
    operator: UnaryOperator
    operand: OperandType
    /**
     * Batch layout of `operand`; the image-processing backend only accepts 1x1.
     * @defaultValue `{width: 1, height: 1}`
     */
    batchSize?: Size2Like
}

/** Parameters of the two-operand form. */
type Binary = {
    operator: BinaryOperator
    operandA: OperandType
    operandB: OperandType
    /**
     * Batch layout of `operandA`; the image-processing backend only accepts 1x1.
     * @defaultValue `{width: 1, height: 1}`
     */
    batchSizeA?: Size2Like
    /**
     * Batch layout of `operandB`; the image-processing backend only accepts 1x1.
     * @defaultValue `{width: 1, height: 1}`
     */
    batchSizeB?: Size2Like
}

/**
 * Parameters of the Math op: the unary or the binary form, plus an optional destination image.
 * Without `resultImage`, the WebGL2 backend creates the result from the first image operand
 * (and throws if neither operand is an image).
 */
export type ParameterType = {resultImage?: ImagePtr} & (Unary | Binary)

/**
 * Result of the Math op: the image holding the computed values.
 * Note: within this module this alias shadows TypeScript's global `ReturnType<T>` utility type.
 */
export type ReturnType = ImagePtr

/** Union of every operator the Math op accepts. */
type MathOperator = Unary["operator"] | Binary["operator"]

/** Operands and batch sizes of a Math op, normalized from either the unary or the binary parameter shape. */
type ResolvedMathParameters = {
    operator: MathOperator
    operandA: OperandType
    /** `undefined` selects the unary operation. */
    operandB: OperandType | undefined
    batchSizeA: Size2Like
    batchSizeB: Size2Like
}

/**
 * Normalizes the unary (`operand`, `batchSize`) and binary (`operandA`, `operandB`, `batchSizeA`, `batchSizeB`)
 * parameter shapes into one form shared by both backends. A unary `batchSize` fills both batch-size slots;
 * missing batch sizes default to `{width: 1, height: 1}`.
 */
function resolveMathParameters(parameters: ParameterType): ResolvedMathParameters {
    const batchSizeA = ("batchSize" in parameters ? parameters.batchSize : "batchSizeA" in parameters ? parameters.batchSizeA : undefined) ?? {
        width: 1,
        height: 1,
    }
    const batchSizeB = ("batchSize" in parameters ? parameters.batchSize : "batchSizeB" in parameters ? parameters.batchSizeB : undefined) ?? {
        width: 1,
        height: 1,
    }
    return {
        operator: parameters.operator,
        operandA: "operandA" in parameters ? parameters.operandA : parameters.operand,
        operandB: "operandB" in parameters ? parameters.operandB : undefined,
        batchSizeA,
        batchSizeB,
    }
}

/** GLSL expression over the `vec4` named `a` that applies a unary operator component-wise. */
function unaryShaderExpression(operator: Unary["operator"]) {
    switch (operator) {
        case "sqrt":
            return "sqrt(a)"
        case "square":
            return "a * a"
        case "cos":
            return "cos(a)"
        case "sin":
            return "sin(a)"
        case "clip01":
            return "clamp(a, 0.0, 1.0)"
        case "abs":
            return "abs(a)"
        default:
            assertNever(operator)
    }
}

/**
 * GLSL expression over the `vec4`s named `a` and `b` that applies a binary operator component-wise.
 * Comparisons use the vector relational built-ins: GLSL's `<`, `>`, `<=` and `>=` only accept scalars, and
 * `vec4(bvec4)` turns each component into 1.0 (true) or 0.0 (false).
 */
function binaryShaderExpression(operator: Binary["operator"]) {
    switch (operator) {
        case "+":
            return "a + b"
        case "-":
            return "a - b"
        case "*":
            return "a * b"
        case "/":
            return "a / b"
        case ">":
            return "vec4(greaterThan(a, b))"
        case "<":
            return "vec4(lessThan(a, b))"
        case ">=":
            return "vec4(greaterThanEqual(a, b))"
        case "<=":
            return "vec4(lessThanEqual(a, b))"
        case "==":
            return "vec4(equal(a, b))"
        case "max":
            return "max(a, b)"
        case "min":
            return "min(a, b)"
        case "atan2":
            return "atan(a, b)"
        case "pow":
            return "pow(a, b)"
        case "mod":
            return "mod(a, b)"
        default:
            assertNever(operator)
    }
}

/**
 * Formats a number as a GLSL float literal. Integral values get a ".0" suffix so GLSL does not parse them as
 * integer literals, which fail to compile once they no longer fit in 32 bits.
 * @throws Error if `value` is NaN or infinite (GLSL has no literal for those).
 */
function toGlslFloatLiteral(value: number): string {
    if (!Number.isFinite(value)) {
        throw new Error(`Math operand must be a finite number, got ${value}`)
    }
    const text = String(value)
    return /[.e]/.test(text) ? text : `${text}.0`
}

/** GLSL `vec4` expression for an operand: a texel fetch from input slot `inputIndex`, or an inlined constant. */
function glslOperandExpression(operand: OperandType, inputIndex: 0 | 1, texelIndexVariable: string): string {
    return typeof operand === "number" ? `vec4(${toGlslFloatLiteral(operand)})` : `texelFetch${inputIndex}(${texelIndexVariable})`
}

/**
 * Builds the compositor source for the WebGL2 backend. Operand A reads input slot 0 and operand B reads input
 * slot 1, each indexed through `wrapBatchedTexelIndex` with its own batch-size uniform.
 * @throws Error if a numeric operand is not finite.
 */
function buildMathShaderSource(operator: MathOperator, operandA: OperandType, operandB: OperandType | undefined): string {
    if (operandB === undefined) {
        return `
            uniform ivec2 u_batchSize;

            vec4 computeColor(ivec2 targetPixel) {
                ivec2 texelIndex = wrapBatchedTexelIndex(0, targetPixel, u_batchSize);
                vec4 a = ${glslOperandExpression(operandA, 0, "texelIndex")};
                return ${unaryShaderExpression(operator as Unary["operator"])};
            }
        `
    }
    return `
        uniform ivec2 u_batchSize0;
        uniform ivec2 u_batchSize1;

        vec4 computeColor(ivec2 targetPixel) {
            ivec2 texelIndex0 = wrapBatchedTexelIndex(0, targetPixel, u_batchSize0);
            ivec2 texelIndex1 = wrapBatchedTexelIndex(1, targetPixel, u_batchSize1);
            vec4 a = ${glslOperandExpression(operandA, 0, "texelIndex0")};
            vec4 b = ${glslOperandExpression(operandB, 1, "texelIndex1")};
            return ${binaryShaderExpression(operator as Binary["operator"])};
        }
    `
}

/**
 * Element-wise math on images and/or numbers.
 * Unary form: `{operand, operator, batchSize?}`. Binary form: `{operandA, operandB, operator, batchSizeA?, batchSizeB?}`.
 * Both forms accept an optional `resultImage`.
 */
export const imageOpMath: ImageOpType<ParameterType, ReturnType> = {
    name: "Math",

    WebGL2: async ({context, parameters}) => {
        const {operator, operandA, operandB, batchSizeA, batchSizeB} = resolveMathParameters(parameters)

        // Build the shader source (which validates numeric operands) and fetch the compositor before any image is
        // created or referenced, so failures here happen before a result image is allocated.
        const halMath: HalPainterImageCompositor = await context.getOrCreateImageCompositor(buildMathShaderSource(operator, operandA, operandB))

        let resultImage = parameters.resultImage
        if (resultImage) {
            // NOTE(review): presumably takes an additional reference so the returned pointer is owned independently
            // of `parameters.resultImage` — confirm against ImagePtr.
            resultImage = new ImagePtr(resultImage)
        } else {
            const imageOperand = isImagePtr(operandA) ? operandA : isImagePtr(operandB) ? operandB : undefined
            if (!imageOperand) {
                throw new Error("If no resultImage is specified, at least one of the operands must be an image")
            }
            resultImage = await context.createImage(imageOperand)
        }
        using resultImageWebGL2 = await context.getImage(resultImage)
        using operandAWebGL2 = isImagePtr(operandA) ? await context.getImage(operandA) : undefined
        using operandBWebGL2 = isImagePtr(operandB) ? await context.getImage(operandB) : undefined
        // if (operandAWebGL2) {
        //     assertSameSize(resultImageWebGL2.ref.descriptor, operandAWebGL2.ref.descriptor)
        // }
        // if (operandBWebGL2) {
        //     assertSameSize(resultImageWebGL2.ref.descriptor, operandBWebGL2.ref.descriptor)
        // }

        // Set the batch-size uniforms immediately before painting so that no await separates them from the paint call.
        if (operandB === undefined) {
            halMath.setParameter("u_batchSize", {type: "int2", value: {x: batchSizeA.width, y: batchSizeA.height}})
        } else {
            halMath.setParameter("u_batchSize0", {type: "int2", value: {x: batchSizeA.width, y: batchSizeA.height}})
            halMath.setParameter("u_batchSize1", {type: "int2", value: {x: batchSizeB.width, y: batchSizeB.height}})
        }
        await halMath.paint(resultImageWebGL2.ref.halImage, [operandAWebGL2?.ref.halImage, operandBWebGL2?.ref.halImage])
        return resultImage
    },

    ImgProc: async ({context, parameters}) => {
        const {operator, operandA, operandB, batchSizeA, batchSizeB} = resolveMathParameters(parameters)
        if (batchSizeA.width !== 1 || batchSizeA.height !== 1 || batchSizeB.width !== 1 || batchSizeB.height !== 1) {
            throw new Error("Batch size not supported in image-processing")
        }

        // Builds the math node once the first operand is guaranteed to be an image.
        const computeResult = async (imageOperandA: ImagePtr) => {
            let secondInput: Nodes.ImageNode | number | undefined
            if (typeof operandB === "number") {
                secondInput = operandB
            } else if (operandB !== undefined) {
                using operandBImgProc = await context.getImage(operandB)
                secondInput = operandBImgProc.ref.node
            }
            using imageOperandAImgProc = await context.getImage(imageOperandA)
            const resultNode: Nodes.Math = {type: "math", operation: operator, firstInput: imageOperandAImgProc.ref.node, secondInput}
            using result = await context.createImage(imageOperandAImgProc.ref.descriptor, resultNode)
            return await toImgProcResultImage(context, result, parameters.resultImage)
        }

        if (typeof operandA !== "number") {
            return await computeResult(operandA)
        }

        // image-processing does not allow the first operand to be a number; turn it into a constant image shaped like operandB.
        if (!isImagePtr(operandB)) {
            throw new Error("Invalid operand")
        }
        using operandBImgProc = await context.getImage(operandB)
        const constantNode = Nodes.math("constLike", operandBImgProc.ref.node, operandA)
        // The temporary constant image is owned by this function; `using` disposes it after the result has been
        // produced (`return await` keeps it alive until then).
        using constantOperandA = await context.createImage(operandBImgProc.ref.descriptor, constantNode)
        return await computeResult(constantOperandA)
    },
}
