import {ImageOpType, runImageOp} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op"
import {Batching, DataType, ImageDescriptor, ImageRef, isImageRef} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-ref"
import {assertNever} from "@cm/utils"
import {ColorLike, isColorLike, Vector4Like} from "@cm/math"
import {ImageOpCommandQueue} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/image-op-command-queue"
import {getMostPreciseDataType} from "@app/textures/texture-editor/operator-stack/image-op-system/detail/utils"
import {HalPainterParameterValueType} from "@common/models/hal/hal-painter/types"

/** A value the math operation can consume: an image reference, a scalar number, or a color. */
type OperandType = ImageRef | number | ColorLike

/** Single-operand operation: `operator` is applied to `operand`. */
type Unary = {
    operand: OperandType
    // In the WebGL2 implementation in this file: "square" is a * a, "clip01" clamps to [0, 1],
    // and "isNaN" yields 1 for NaN *and* infinite values (0 otherwise), per channel.
    operator: "sqrt" | "square" | "cos" | "sin" | "clip01" | "abs" | "isNaN"
}

/** Two-operand operation: `operator` is applied to `operandA` (left) and `operandB` (right). */
type Binary = {
    operandA: OperandType
    operandB: OperandType
    // In the WebGL2 implementation in this file: "/safe" yields 0 wherever the divisor is 0,
    // the comparison operators yield 1 or 0 per channel, and "atan2" maps to atan(a, b).
    operator: "+" | "-" | "*" | "/" | "/safe" | ">" | "<" | ">=" | "<=" | "==" | "max" | "min" | "atan2" | "pow" | "mod"
}

/**
 * Parameters of the math operation: either a unary or a binary operation, plus an optional result target.
 *
 * `resultImageOrDataType` may be an existing image to use as the result, or a data type for the result.
 * When it is not an image, the WebGL2 implementation creates the result image itself; when it is omitted
 * entirely, the data type is derived from the image operand(s).
 */
export type ParameterType = (Unary | Binary) & {
    resultImageOrDataType?: ImageRef | DataType
}

/**
 * Result of the math operation: a reference to the image holding the computed values.
 * Note that within this module the name shadows TypeScript's global `ReturnType<F>` utility type.
 */
export type ReturnType = ImageRef

/** Width/height pair, used here for image sizes, patch sizes and batch counts alike. */
type Extent2D = {width: number; height: number}

/** Component selectors of a GLSL `vec4`, in constructor-argument order. */
const vec4Components = ["x", "y", "z", "w"] as const

/** Builds a GLSL `vec4(...)` constructor from one scalar expression per component. */
const componentWise = (scalarExpression: (component: string) => string) => `vec4(${vec4Components.map(scalarExpression).join(", ")})`

/** Returns the operand itself when it is an image reference; `undefined` for constants and absent operands. */
const imageOperandOrUndefined = (operand: ImageRef | number | ColorLike | undefined) => (isImageRef(operand) ? operand : undefined)

/** Number of batches per axis of an operand; anything that is not a batched image counts as a single 1x1 batch. */
const batchCountOf = (operand: ImageRef | number | ColorLike | undefined) =>
    imageOperandOrUndefined(operand)?.descriptor.batching?.batchSize ?? {width: 1, height: 1}

/** Size of one patch of an image split into `batchCount` batches (rounded up); 1x1 when there is no image. */
const patchExtentOf = (imageExtent: Extent2D | undefined, batchCount: Extent2D): Extent2D =>
    imageExtent
        ? {width: Math.ceil(imageExtent.width / batchCount.width), height: Math.ceil(imageExtent.height / batchCount.height)}
        : {width: 1, height: 1}

/** Per-axis maximum of two extents. */
const largerExtent = (first: Extent2D, second: Extent2D): Extent2D => ({
    width: Math.max(first.width, second.width),
    height: Math.max(first.height, second.height),
})

/** Expands a constant operand to four components: a number is replicated, a color maps r/g/b/a with alpha defaulting to 1. */
const toVec4 = (constant: ColorLike | number): Vector4Like =>
    typeof constant === "number"
        ? {x: constant, y: constant, z: constant, w: constant}
        : {x: constant.r, y: constant.g, z: constant.b, w: constant.a ?? 1}

/** Painter parameter carrying an extent as an `int2`. */
const int2Uniform = (extent: Extent2D): HalPainterParameterValueType => ({type: "int2", value: {x: extent.width, y: extent.height}})

/** Painter parameter carrying a constant operand as a `float4`. */
const float4Uniform = (constant: ColorLike | number): HalPainterParameterValueType => ({type: "float4", value: toVec4(constant)})

/** Shader expression computing a unary operator on the vec4 `a`. */
const unaryShaderExpression = (operator: Unary["operator"]) => {
    switch (operator) {
        case "sqrt":
        case "cos":
        case "sin":
        case "abs":
            // these operator names coincide with the shader built-ins
            return `${operator}(a)`
        case "square":
            return "a * a"
        case "clip01":
            return "clamp(a, 0.0, 1.0)"
        case "isNaN":
            // flags infinities as well as NaNs
            return componentWise((c) => `isnan(a.${c}) || isinf(a.${c}) ? 1.0 : 0.0`)
        default:
            assertNever(operator)
    }
}

/** Shader expression computing a binary operator on the vec4s `a` and `b`. */
const binaryShaderExpression = (operator: Binary["operator"]) => {
    switch (operator) {
        case "+":
        case "-":
        case "*":
        case "/":
            return `a ${operator} b`
        case "/safe":
            // division that yields 0 wherever the divisor is 0
            return componentWise((c) => `b.${c} == 0.0 ? 0.0 : a.${c} / b.${c}`)
        case ">":
        case "<":
        case ">=":
        case "<=":
        case "==":
            // comparisons produce 1 or 0 per component
            return componentWise((c) => `a.${c} ${operator} b.${c} ? 1 : 0`)
        case "max":
        case "min":
        case "pow":
        case "mod":
            // these operator names coincide with the shader built-ins
            return `${operator}(a, b)`
        case "atan2":
            return "atan(a, b)"
        default:
            assertNever(operator)
    }
}

/**
 * The "Math" image operation: applies a unary or binary operator to image and/or constant operands.
 *
 * - `WebGL2` paints the result with a generated shader. Image operands are sampled with batch-aware,
 *   wrapped texel indices; constant operands are passed as `float4` parameters. If no result image is
 *   supplied, one is created whose size, batching, channel layout and data type are derived from the
 *   image operand(s); at least one operand must then be an image.
 * - `ImgProc` expresses the operation as "math" image nodes. It rejects batched images and color
 *   operands, and turns a numeric first operand into an image via a "constLike" node.
 */
const imageOpMath: ImageOpType<ParameterType, ReturnType> = {
    name: "Math",

    WebGL2: ({cmdQueue, parameters}) => {
        const {operator} = parameters
        const firstOperand = "operandA" in parameters ? parameters.operandA : parameters.operand
        const secondOperand = "operandB" in parameters ? parameters.operandB : undefined
        const batchCountA = batchCountOf(firstOperand)
        const batchCountB = batchCountOf(secondOperand)
        // shader helper mapping a (batch index, texel-in-patch) pair to a texel index, wrapping both parts into range
        const batchWrapShaderSource = `
            ivec2 wrapBatchedTexelIndex(ivec2 patchTexelIndex, ivec2 batchIndex, ivec2 patchSize, ivec2 batchSize) {
                batchIndex.x = wrapInt(batchIndex.x, batchSize.x);
                batchIndex.y = wrapInt(batchIndex.y, batchSize.y);
                patchTexelIndex.x = wrapInt(patchTexelIndex.x, patchSize.x);
                patchTexelIndex.y = wrapInt(patchTexelIndex.y, patchSize.y);
                return batchIndex * patchSize + patchTexelIndex;
            }
        `

        let target = parameters.resultImageOrDataType
        if (!isImageRef(target)) {
            // no result image supplied: derive a descriptor from the image operand(s) and create one
            const descriptorA = imageOperandOrUndefined(firstOperand)?.descriptor
            const descriptorB = imageOperandOrUndefined(secondOperand)?.descriptor
            const baseDescriptor = descriptorA ?? descriptorB
            if (!baseDescriptor) {
                throw new Error("If no resultImage is specified, at least one of the operands must be an image")
            }
            // the result holds the larger batch count of the larger patches, per axis
            const largestPatch = largerExtent(patchExtentOf(descriptorA, batchCountA), patchExtentOf(descriptorB, batchCountB))
            const resultBatchCount = largerExtent(batchCountA, batchCountB)
            const resultExtent = {
                width: resultBatchCount.width * largestPatch.width,
                height: resultBatchCount.height * largestPatch.height,
            }
            const isSingleBatch = resultBatchCount.width === 1 && resultBatchCount.height === 1
            const batching: Batching | undefined = isSingleBatch ? undefined : {patchSize: largestPatch, batchSize: resultBatchCount}
            let channelLayout = baseDescriptor.channelLayout
            if (descriptorA && descriptorB) {
                // two image operands: take the channel layout with the most channels (the second operand's on a tie) ...
                if (descriptorA.channelLayout.length > descriptorB.channelLayout.length) {
                    channelLayout = descriptorA.channelLayout
                } else {
                    channelLayout = descriptorB.channelLayout
                }
                // ... and, unless a data type was requested explicitly, the most precise of the two data types
                target ??= getMostPreciseDataType(descriptorA.dataType, descriptorB.dataType)
            } else {
                target ??= baseDescriptor.dataType
            }
            const resultDescriptor: ImageDescriptor = {
                ...baseDescriptor,
                ...resultExtent,
                batching,
                channelLayout,
                dataType: target,
            }
            target = cmdQueue.createImage(resultDescriptor)
        }

        if (secondOperand === undefined) {
            // unary operation
            const painter = cmdQueue.createPainter(
                "compositor",
                "mathUnary",
                `
                uniform ivec2 u_batchSize;
                uniform vec4 u_constantValue;
                
                ${batchWrapShaderSource}
                
                vec4 computeColor(ivec2 targetPixel) {
                    ivec2 patchSize = ivec2(u_imageSize[0]) / u_batchSize;
                    ivec2 batchIndex = targetPixel / patchSize;
                    ivec2 patchTexelIndex = targetPixel % patchSize;
                    ivec2 texelIndex = wrapBatchedTexelIndex(patchTexelIndex, batchIndex, patchSize, u_batchSize);
                    vec4 a = ${isImageRef(firstOperand) ? "texelFetch0(texelIndex)" : "u_constantValue"};
                    return ${unaryShaderExpression(operator as Unary["operator"])};
                }
            `,
            )
            const uniforms: {[key: string]: HalPainterParameterValueType} = {
                u_batchSize: int2Uniform(batchCountA),
            }
            if (!isImageRef(firstOperand)) {
                // a constant operand travels as a parameter instead of a source image
                uniforms.u_constantValue = float4Uniform(firstOperand)
            }
            cmdQueue.paint(painter, {
                parameters: uniforms,
                sourceImages: [imageOperandOrUndefined(firstOperand), imageOperandOrUndefined(secondOperand)],
                resultImage: target,
            })
        } else {
            // binary operation
            const painter = cmdQueue.createPainter(
                "compositor",
                "mathBinary",
                `
                uniform ivec2 u_batchSize0;
                uniform ivec2 u_batchSize1;
                uniform vec4 u_constantValue0;
                uniform vec4 u_constantValue1;
                
                ${batchWrapShaderSource}
            
                vec4 computeColor(ivec2 targetPixel) {
                    ivec2 patchSize0 = ivec2(u_imageSize[0]) / u_batchSize0;
                    ivec2 patchSize1 = ivec2(u_imageSize[1]) / u_batchSize1;
                    ivec2 maxPatchSize = max(patchSize0, patchSize1);
                    ivec2 maxBatchSize = max(u_batchSize0, u_batchSize1);
                    ivec2 batchIndex = targetPixel / maxPatchSize;
                    ivec2 patchTexelIndex = targetPixel % maxPatchSize;
                    ivec2 texelIndex0 = wrapBatchedTexelIndex(patchTexelIndex, batchIndex, patchSize0, u_batchSize0);
                    ivec2 texelIndex1 = wrapBatchedTexelIndex(patchTexelIndex, batchIndex, patchSize1, u_batchSize1);
                    vec4 a = ${isImageRef(firstOperand) ? "texelFetch0(texelIndex0)" : "u_constantValue0"};
                    vec4 b = ${isImageRef(secondOperand) ? "texelFetch1(texelIndex1)" : "u_constantValue1"};
                    return ${binaryShaderExpression(operator as Binary["operator"])};
                }
            `,
            )
            const uniforms: {[key: string]: HalPainterParameterValueType} = {
                u_batchSize0: int2Uniform(batchCountA),
                u_batchSize1: int2Uniform(batchCountB),
            }
            // constant operands travel as parameters instead of source images
            if (!isImageRef(firstOperand)) {
                uniforms.u_constantValue0 = float4Uniform(firstOperand)
            }
            if (!isImageRef(secondOperand)) {
                uniforms.u_constantValue1 = float4Uniform(secondOperand)
            }
            cmdQueue.paint(painter, {
                parameters: uniforms,
                sourceImages: [imageOperandOrUndefined(firstOperand), imageOperandOrUndefined(secondOperand)],
                resultImage: target,
            })
        }
        return target
    },

    ImgProc: ({cmdQueue, parameters}) => {
        const {operator} = parameters
        const firstOperand = "operandA" in parameters ? parameters.operandA : parameters.operand
        const secondOperand = "operandB" in parameters ? parameters.operandB : undefined

        const batchCounts = [batchCountOf(firstOperand), batchCountOf(secondOperand)]
        if (batchCounts.some(({width, height}) => width !== 1 || height !== 1)) {
            throw new Error("Batch size not supported in image-processing")
        }

        // image-processing needs an image as the first operand
        const resolveFirstImage = (): ImageRef => {
            if (typeof firstOperand === "number") {
                // a numeric first operand is turned into an image shaped like the second operand
                if (!isImageRef(secondOperand)) {
                    throw new Error("Invalid operand")
                }
                return cmdQueue.createImage(secondOperand.descriptor, {
                    type: "math",
                    operation: "constLike",
                    firstInput: secondOperand,
                    secondInput: firstOperand,
                })
            }
            if (isColorLike(firstOperand)) {
                throw new Error("Color operands currently not supported in image-processing")
            }
            return firstOperand
        }
        const firstImage = resolveFirstImage()

        if (isColorLike(secondOperand)) {
            throw new Error("Color operands currently not supported in image-processing")
        }
        const computed = cmdQueue.createImage(firstImage.descriptor, {
            type: "math",
            operation: operator,
            firstInput: firstImage,
            secondInput: secondOperand,
        })
        return cmdQueue.copyToResultImage(computed, parameters.resultImageOrDataType)
    },
}

/**
 * Applies a unary or binary math operator to image and/or constant operands.
 *
 * Delegates to `runImageOp` with the "Math" operation definition.
 *
 * @param cmdQueue Command queue the operation is run on.
 * @param parameters The operator, its operand(s) and an optional result image or result data type.
 * @returns Whatever `runImageOp` returns for this operation.
 */
export function math(cmdQueue: ImageOpCommandQueue, parameters: ParameterType) {
    const mathResult = runImageOp(cmdQueue, imageOpMath, parameters)
    return mathResult
}
