import {DeclareMaterialNode, materialSlots} from "@src/materials/declare-material-node"
import {z} from "zod"
import {color} from "@src/materials/types"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"
import {getAll} from "@src/graph-system/utils"
import {
    threeColorBurnNode,
    threeColorDodgeNode,
    threeConvert,
    threeHsvToRgbNode,
    threeOverlayNode,
    threeRGBColorNode,
    threeRgbToHsvNode,
    threeValueNode,
} from "@src/materials/three-utils"

/**
 * Wraps `node` so that its value is clamped to the closed range [0, 1].
 */
function clampUnit(node) {
    return THREENodes.clamp(node, threeValueNode(0), threeValueNode(1))
}

/**
 * Divides one channel of `colorA` by the same channel of `colorB`.
 *
 * When the divisor channel is exactly zero, the channel of `colorA` is passed through unchanged
 * instead of producing Inf/NaN.
 *
 * @param colorA - dividend color node.
 * @param colorB - divisor color node.
 * @param channel - swizzle component to operate on ("r", "g" or "b").
 * @returns a scalar node holding the guarded quotient for that channel.
 */
function divideChannel(colorA, colorB, channel) {
    const dividend = new THREENodes.SplitNode(colorA, channel)
    const divisor = new THREENodes.SplitNode(colorB, channel)
    // `abs(divisor) > 0` is the node-graph spelling of `divisor != 0`. A plain `divisor > 0`
    // test would also skip the division for negative divisors (e.g. colors produced by SUBTRACT).
    const hasDivisor = THREENodes.greaterThan(THREENodes.abs(divisor), threeValueNode(0))
    return THREENodes.cond(hasDivisor, THREENodes.div(dividend, divisor), dividend)
}

/**
 * Builds the node graph for a single MixRGB blend mode.
 *
 * In the formulas below A = `inputA`, B = `inputB`, f = `facNode`.
 * HSV nodes store hue in `x`, saturation in `y` and value in `z`.
 *
 * @param operation - one of the `blendType` enum values declared on MixRGB.
 * @param inputA - base color node (A).
 * @param inputB - blend color node (B).
 * @param facNode - blend factor node; the caller clamps it to [0, 1].
 * @returns a node computing the blended color.
 * @throws Error when `operation` has no implementation.
 */
function blendColors(operation, inputA, inputB, facNode) {
    switch (operation) {
        case "MIX":
            return THREENodes.mix(inputA, inputB, facNode)

        case "ADD":
            // (1-f)A + f(A+B)
            return THREENodes.mix(inputA, THREENodes.add(inputA, inputB), facNode)

        case "SUBTRACT":
            // (1-f)A + f(A-B)
            return THREENodes.mix(inputA, THREENodes.sub(inputA, inputB), facNode)

        case "MULTIPLY":
            // (1-f)A + f(A*B)
            return THREENodes.mix(inputA, THREENodes.mul(inputA, inputB), facNode)

        case "SCREEN": {
            // 1 - [(1-f)(1-A) + f(1-A)(1-B)]
            const invA = THREENodes.sub(threeValueNode(1), inputA)
            const invB = THREENodes.sub(threeValueNode(1), inputB)
            const invFac = THREENodes.sub(threeValueNode(1), facNode)
            const keptPart = THREENodes.mul(invFac, invA)
            const screenedPart = THREENodes.mul(THREENodes.mul(facNode, invA), invB)
            return THREENodes.sub(threeValueNode(1), THREENodes.add(keptPart, screenedPart))
        }

        case "DIVIDE": {
            // Per channel: (1-f)A + f(A/B), leaving A untouched where B == 0.
            const quotient = new THREENodes.JoinNode([
                divideChannel(inputA, inputB, "r"),
                divideChannel(inputA, inputB, "g"),
                divideChannel(inputA, inputB, "b"),
            ])
            return THREENodes.mix(inputA, quotient, facNode)
        }

        case "DIFFERENCE":
            // (1-f)A + f|A-B|
            return THREENodes.mix(inputA, THREENodes.abs(THREENodes.sub(inputA, inputB)), facNode)

        case "EXCLUSION": {
            // (1-f)A + f(A + B - 2AB)
            const sum = THREENodes.add(inputA, inputB)
            const twiceProduct = THREENodes.mul(threeValueNode(2), THREENodes.mul(inputA, inputB))
            return THREENodes.mix(inputA, THREENodes.sub(sum, twiceProduct), facNode)
        }

        case "DARKEN":
            // (1-f)A + f*min(A, B)
            return THREENodes.mix(inputA, THREENodes.min(inputA, inputB), facNode)

        case "LIGHTEN":
            // max(A, f*B)
            // NOTE(review): unlike DARKEN this does not interpolate with mix(); recent Blender
            // versions appear to use mix(A, max(A, B), f) instead — confirm which is intended.
            return THREENodes.max(inputA, THREENodes.mul(facNode, inputB))

        case "OVERLAY":
            return threeOverlayNode(inputA, inputB, facNode)

        case "COLOR_DODGE":
            return threeColorDodgeNode(inputA, inputB, facNode)

        case "COLOR_BURN":
            return threeColorBurnNode(inputA, inputB, facNode)

        case "HUE": {
            // Where B has saturation: mix A toward (hue of B, saturation of A, value of A).
            const hsvA = threeRgbToHsvNode(inputA)
            const hsvB = threeRgbToHsvNode(inputB)
            const satB = new THREENodes.SplitNode(hsvB, "y")
            const recolored = threeHsvToRgbNode(
                new THREENodes.JoinNode([
                    new THREENodes.SplitNode(hsvB, "x"),
                    new THREENodes.SplitNode(hsvA, "y"),
                    new THREENodes.SplitNode(hsvA, "z"),
                ]),
            )
            return THREENodes.cond(
                THREENodes.greaterThan(satB, threeValueNode(0)),
                THREENodes.mix(inputA, recolored, facNode),
                inputA,
            )
        }

        case "SATURATION": {
            // Where A has saturation: (hue of A, mix(sat A, sat B, f), value of A).
            const hsvA = threeRgbToHsvNode(inputA)
            const hsvB = threeRgbToHsvNode(inputB)
            const satA = new THREENodes.SplitNode(hsvA, "y")
            const blendedSat = THREENodes.mix(satA, new THREENodes.SplitNode(hsvB, "y"), facNode)
            const resaturated = threeHsvToRgbNode(
                new THREENodes.JoinNode([
                    new THREENodes.SplitNode(hsvA, "x"),
                    blendedSat,
                    new THREENodes.SplitNode(hsvA, "z"),
                ]),
            )
            return THREENodes.cond(THREENodes.greaterThan(satA, threeValueNode(0)), resaturated, inputA)
        }

        case "VALUE": {
            // (hue of A, saturation of A, mix(value A, value B, f))
            const hsvA = threeRgbToHsvNode(inputA)
            const hsvB = threeRgbToHsvNode(inputB)
            const blendedValue = THREENodes.mix(
                new THREENodes.SplitNode(hsvA, "z"),
                new THREENodes.SplitNode(hsvB, "z"),
                facNode,
            )
            return threeHsvToRgbNode(
                new THREENodes.JoinNode([
                    new THREENodes.SplitNode(hsvA, "x"),
                    new THREENodes.SplitNode(hsvA, "y"),
                    blendedValue,
                ]),
            )
        }

        case "COLOR": {
            // Where B has saturation: mix A toward (hue of B, saturation of B, value of A).
            const hsvA = threeRgbToHsvNode(inputA)
            const hsvB = threeRgbToHsvNode(inputB)
            const satB = new THREENodes.SplitNode(hsvB, "y")
            const recolored = threeHsvToRgbNode(
                new THREENodes.JoinNode([
                    new THREENodes.SplitNode(hsvB, "x"),
                    satB,
                    new THREENodes.SplitNode(hsvA, "z"),
                ]),
            )
            return THREENodes.cond(
                THREENodes.greaterThan(satB, threeValueNode(0)),
                THREENodes.mix(inputA, recolored, facNode),
                inputA,
            )
        }

        case "SOFT_LIGHT": {
            // (1-f)A + f[(1-A)BA + A*scr], where scr = 1 - (1-A)(1-B)
            const invA = THREENodes.sub(threeValueNode(1), inputA)
            const invB = THREENodes.sub(threeValueNode(1), inputB)
            const screen = THREENodes.sub(threeValueNode(1), THREENodes.mul(invA, invB))
            const darkTerm = THREENodes.mul(THREENodes.mul(invA, inputB), inputA)
            const lightTerm = THREENodes.mul(inputA, screen)
            return THREENodes.mix(inputA, THREENodes.add(darkTerm, lightTerm), facNode)
        }

        case "LINEAR_LIGHT": {
            // A + f(2B - 1)
            const shiftedB = THREENodes.sub(THREENodes.mul(inputB, threeValueNode(2.0)), threeValueNode(1))
            return THREENodes.add(inputA, THREENodes.mul(facNode, shiftedB))
        }

        default:
            throw new Error(`Unsupported operation: ${operation}`)
    }
}

/**
 * Material node that blends two colors with a MixRGB-style blend mode.
 *
 * The `color1`/`color2`/`fac` inputs take precedence over the parameters of the same name.
 * Missing colors default to black and a missing factor defaults to 1.
 * `blendType` defaults to "MIX". When `useClamp` is set, the blended color is clamped to [0, 1].
 *
 * NOTE(review): `useAlpha` is accepted by the schema but is not read by `toThree`.
 */
export class MixRGB extends DeclareMaterialNode(
    {
        returns: z.object({color: materialSlots}),
        inputs: z.object({
            color1: materialSlots.optional(),
            color2: materialSlots.optional(),
            fac: materialSlots.optional(),
        }),
        parameters: z.object({
            blendType: z
                .enum([
                    "MIX",
                    "ADD",
                    "SUBTRACT",
                    "MULTIPLY",
                    "SCREEN",
                    "DIVIDE",
                    "DIFFERENCE",
                    "EXCLUSION",
                    "DARKEN",
                    "LIGHTEN",
                    "OVERLAY",
                    "COLOR_DODGE",
                    "COLOR_BURN",
                    "HUE",
                    "SATURATION",
                    "VALUE",
                    "COLOR",
                    "SOFT_LIGHT",
                    "LINEAR_LIGHT",
                ])
                .optional(),
            color1: color.optional(),
            color2: color.optional(),
            fac: z.number().optional(),
            useClamp: z.boolean().optional(),
            useAlpha: z.boolean().optional(),
        }),
    },
    {
        toThree: async ({get, inputs, parameters}) => {
            const {color1, color2, fac} = await getAll(inputs, get)
            const inputA = color1 ?? threeConvert(parameters.color1, threeRGBColorNode) ?? threeRGBColorNode({r: 0, g: 0, b: 0})
            const inputB = color2 ?? threeConvert(parameters.color2, threeRGBColorNode) ?? threeRGBColorNode({r: 0, g: 0, b: 0})
            // The factor is an interpolation weight. Clamp it so out-of-range inputs (e.g. from a
            // math node) cannot extrapolate past either color; Blender's MixRGB saturates it too.
            const facNode = clampUnit(fac ?? threeConvert(parameters.fac, threeValueNode) ?? threeValueNode(1))
            const {blendType, useClamp} = parameters

            const mixedColor = blendColors(blendType ?? "MIX", inputA, inputB, facNode)
            return {color: useClamp ? clampUnit(mixedColor) : mixedColor}
        },
    },
) {}
