import {DeclareMaterialNode, materialSlots} from "@src/materials/declare-material-node"
import {z} from "zod"
import {Vec3, vec3} from "@src/materials/types"
import * as THREE from "three"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"

/**
 * Shader node that transforms a vec3 by a mat4 as a point: `(M * vec4(v, 1.0)).xyz`.
 *
 * Because the input is extended with w = 1.0, any translation stored in the matrix
 * is applied to the input vector.
 */
class Vec3TransformNode extends THREENodes.TempNode {
    constructor(
        public input: THREENodes.Node,
        public matrix: THREENodes.Node,
    ) {
        super("vec3")
    }

    override generate(builder: THREENodes.NodeBuilder, output?: string | null) {
        const nodeType = this.getNodeType(builder) ?? "vec3"

        // Operands are built in a fixed order: the vector first, then the matrix.
        const vectorCode = this.input.build(builder, "vec3")
        const matrixCode = this.matrix.build(builder, "mat4")

        // w = 1.0 -> point semantics: the matrix's translation column takes effect.
        const snippet = `((${matrixCode}) * vec4((${vectorCode}), 1.0)).xyz`
        return builder.format(snippet, nodeType, output as THREENodes.NodeTypeOption)
    }
}

/**
 * Builds the 4x4 matrix implementing the Mapping node for the given parameters, following
 * Cycles' `svm_mapping`. The matrix is meant to be applied as `(M * vec4(v, 1.0)).xyz`.
 *
 * - POINT:   T * R * S          (scale, then rotate, then translate)
 * - VECTOR:  R * S              (like POINT, without translation)
 * - NORMAL:  R * S^-1           (normals transform by the inverse transpose of R * S, which equals
 *                                R * S^-1 because R is orthonormal; the caller normalizes the result)
 * - TEXTURE: S^-1 * R^T * T^-1  (exact inverse of POINT)
 *
 * Rotation matches Cycles' `euler_to_transform`, i.e. R = Rz * Ry * Rx (X applied first).
 * three.js builds that matrix for Euler order "ZYX"; its "XYZ" order yields Rx * Ry * Rz instead,
 * which disagrees with Cycles whenever more than one rotation angle is non-zero.
 *
 * In the inverse-scale modes (NORMAL, TEXTURE) a zero scale component follows Cycles'
 * `safe_divide`: the reciprocal is 0 rather than Infinity, so that component collapses to 0
 * instead of turning into Inf/NaN in the shader.
 *
 * @param parameters Mapping parameters; missing values default to location (0, 0, 0),
 *   rotation (0, 0, 0) (radians, as consumed by THREE.Euler), scale (1, 1, 1) and vectorType "POINT".
 * @returns A newly allocated THREE.Matrix4.
 * @throws Error when `vectorType` is not one of POINT, VECTOR, NORMAL, TEXTURE.
 */
export function getMappingMatrix(parameters: Mapping["parameters"]["parameters"]) {
    const location = parameters.location ?? {x: 0, y: 0, z: 0}
    const rotation = parameters.rotation ?? {x: 0, y: 0, z: 0}
    const scale = parameters.scale ?? {x: 1, y: 1, z: 1}
    const vectorType = parameters.vectorType ?? "POINT"

    // Cycles' euler_to_transform: Rz * Ry * Rx, which three.js calls Euler order "ZYX".
    const rotationMatrix = new THREE.Matrix4().makeRotationFromEuler(
        new THREE.Euler(rotation.x, rotation.y, rotation.z, "ZYX"),
    )
    // 1/s, but 0 for s === 0 (Cycles' safe_divide) so a zero scale cannot inject Infinity into the matrix.
    const safeReciprocal = (s: number) => (s === 0 ? 0 : 1 / s)
    const makeScale = () => new THREE.Matrix4().makeScale(scale.x, scale.y, scale.z)
    const makeInverseScale = () =>
        new THREE.Matrix4().makeScale(safeReciprocal(scale.x), safeReciprocal(scale.y), safeReciprocal(scale.z))

    // Each case composes left to right: matrix = I * A * B * C, so C is applied to the vector first.
    const matrix = new THREE.Matrix4().identity()
    switch (vectorType) {
        case "POINT":
            matrix.multiply(new THREE.Matrix4().makeTranslation(location.x, location.y, location.z))
            matrix.multiply(rotationMatrix)
            matrix.multiply(makeScale())
            break
        case "VECTOR":
            matrix.multiply(rotationMatrix)
            matrix.multiply(makeScale())
            break
        case "NORMAL":
            matrix.multiply(rotationMatrix)
            matrix.multiply(makeInverseScale())
            break
        case "TEXTURE":
            matrix.multiply(makeInverseScale())
            // A pure rotation's inverse is its transpose (Cycles: transform_direction_transposed).
            matrix.multiply(rotationMatrix.clone().transpose())
            matrix.multiply(new THREE.Matrix4().makeTranslation(-location.x, -location.y, -location.z))
            break
        default:
            // Parameters may come from untyped data, so guard at runtime as well.
            throw new Error(`Unhandled mapping type: ${vectorType}`)
    }

    return matrix
}

/**
 * Material node that transforms its `vector` input by a location/rotation/scale mapping
 * in one of four modes: POINT, VECTOR, NORMAL or TEXTURE (POINT when unset).
 *
 * If `vector` is not connected, UV channel 0 is used instead and a warning is logged.
 * NORMAL results are normalized after the transform; the other modes return the
 * transformed vector unchanged.
 */
export class Mapping extends DeclareMaterialNode(
    {
        returns: z.object({vector: materialSlots}),
        inputs: z.object({
            vector: materialSlots.optional(),
        }),
        parameters: z.object({
            location: vec3.optional(),
            rotation: vec3.optional(),
            scale: vec3.optional(),
            vectorType: z.enum(["POINT", "VECTOR", "NORMAL", "TEXTURE"]).optional(),
        }),
    },
    {
        toThree: async ({get, inputs, parameters}) => {
            const connectedVector = await get(inputs.vector)
            if (!connectedVector) {
                console.warn("Input to Mapping node not connected, defaulting to UV!")
            }
            const sourceVector = connectedVector ?? THREENodes.uv(0)

            // The whole mapping is baked into a single mat4 node; the shader side is (M * vec4(v, 1.0)).xyz.
            const transformed = new Vec3TransformNode(sourceVector, THREENodes.mat4(getMappingMatrix(parameters)))

            if (parameters.vectorType !== "NORMAL") {
                return {vector: transformed}
            }
            // NOTE(review): GLSL normalize() of a zero-length vector is undefined, unlike Cycles' safe_normalize.
            return {vector: THREENodes.normalize(transformed)}
        },
    },
) {}
