import {DeclareMaterialNode, materialSlots} from "@src/materials/declare-material-node"
import {z} from "zod"
import {vec3} from "@src/materials/types"
import {threeConvert, threeVec3Node} from "@src/materials/three-utils"
import {getAll} from "@src/graph-system/utils"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"

/**
 * GLSL implementation of Reoriented Normal Mapping (RNM) blending.
 *
 * `color1` (base) and `color2` (detail) are tangent-space normals encoded as colors in [0, 1].
 * They are decoded into RNM's `t`/`u` vectors, and the detail normal is reoriented to follow the base.
 * The result is normalized before being re-encoded into [0, 1]. Without normalization, `r` is scaled
 * by `t.z` etc.; e.g. two flat normals give r = (0, 0, 2), which would encode to an out-of-range z of 1.5.
 *
 * Comments are kept outside the GLSL string so the FunctionNode declaration parsing sees the
 * function signature as the first token.
 */
const blendNormals = new THREENodes.FunctionNode(`
vec3 blendNormals(vec3 color1, vec3 color2) {
    vec3 t = color1*vec3( 2.,  2., 2.) + vec3(-1., -1.,  0.);
    vec3 u = color2*vec3(-2., -2., 2.) + vec3( 1.,  1., -1.);
    vec3 r = normalize(t*dot(t, u) - u*t.z);
    return (r + 1.0) * 0.5;
}
`)

/**
 * Shader node that evaluates the `blendNormals` GLSL function on two vec3-valued nodes.
 * Its node type is declared as `vec3` via the TempNode constructor.
 */
class BlendNormalsNode extends THREENodes.TempNode {
    constructor(public color1: THREENodes.Node, public color2: THREENodes.Node) {
        super("vec3")
    }

    /**
     * Emits a call to `blendNormals(color1, color2)` and builds it as this node's output type.
     */
    override generate(builder: THREENodes.NodeBuilder) {
        const outputType = this.getNodeType(builder)
        const callArgs = {color1: this.color1, color2: this.color2}
        const callNode = THREENodes.call(blendNormals, callArgs)
        return callNode.build(builder, outputType)
    }
}

/**
 * Material graph node that blends two tangent-space normal-map colors with the
 * `blendNormals` (Reoriented Normal Mapping) shader function.
 *
 * For each side, the value comes from the connected input if present. Otherwise it comes from
 * the matching `vec3` parameter, and otherwise from a flat normal. Inputs, parameters and the
 * returned `color` are normal-map colors in [0, 1] space: the shader decodes them with `*2 - 1`.
 */
export class BlendNormals extends DeclareMaterialNode(
    {
        returns: z.object({color: materialSlots}),
        inputs: z.object({
            color1: materialSlots.optional(),
            color2: materialSlots.optional(),
        }),
        parameters: z.object({
            color1: vec3.optional(),
            color2: vec3.optional(),
        }),
    },
    {
        toThree: async ({get, inputs, parameters}) => {
            const {color1, color2} = await getAll(inputs, get)
            // The flat tangent-space normal (0, 0, 1) encodes to the color (0.5, 0.5, 1).
            // The color (0, 0, 1) would decode to the tilted normal (-1, -1, 1). Using the flat
            // color as the fallback means an unset side does not change the other normal's direction.
            const flatNormalColor = () => threeVec3Node({x: 0.5, y: 0.5, z: 1})
            const color1Value = color1 ?? threeConvert(parameters.color1, threeVec3Node) ?? flatNormalColor()
            const color2Value = color2 ?? threeConvert(parameters.color2, threeVec3Node) ?? flatNormalColor()

            return {color: new BlendNormalsNode(color1Value, color2Value)}
        },
    },
) {}
