import {DeclareMaterialNode, materialSlots} from "@src/materials/declare-material-node"
import {z} from "zod"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"
import {threeRGBColorNode} from "@src/materials/three-utils"

/**
 * Gradient texture material node.
 *
 * Inputs:
 *   - vector (optional): coordinate to evaluate the gradient at. When it is not
 *     connected, the first UV set (`uv(0)`) is used instead.
 * Parameters:
 *   - type (optional): gradient shape, defaults to "Linear".
 * Returns:
 *   - fac: the gradient factor, clamped to [0, 1].
 *   - color: black-to-white blend driven by `fac`.
 *
 * The per-type formulas follow Blender's Gradient Texture node (Cycles `svm_gradient`),
 * which also saturates the final factor.
 */
export class TexGradient extends DeclareMaterialNode(
    {
        returns: z.object({color: materialSlots, fac: materialSlots}),
        inputs: z.object({vector: materialSlots.optional()}),
        parameters: z.object({
            type: z.enum(["Linear", "Quadratic", "Easing", "Diagonal", "Spherical", "Quadratic sphere", "Radial"]).optional(),
        }),
    },
    {
        /**
         * Builds the three.js node graph for the selected gradient type.
         *
         * @returns {Promise<{color: object, fac: object}>} nodes for the color and factor outputs
         * @throws {Error} if `parameters.type` is not one of the handled gradient types
         */
        toThree: async ({get, inputs, parameters}) => {
            const type = parameters.type ?? "Linear"
            const vectorValue = (await get(inputs.vector)) ?? THREENodes.uv(0)

            const xComponent = new THREENodes.SplitNode(vectorValue, "x")
            const yComponent = new THREENodes.SplitNode(vectorValue, "y")

            let rawFac
            switch (type) {
                case "Linear":
                    rawFac = xComponent
                    break
                case "Quadratic": {
                    // Square via multiplication: shader pow() is undefined for a negative base,
                    // and the reference clamps the negative side to zero before squaring.
                    const positiveX = THREENodes.max(xComponent, 0)
                    rawFac = THREENodes.mul(positiveX, positiveX)
                    break
                }
                case "Easing":
                    rawFac = THREENodes.smoothstep(0.0, 1.0, xComponent)
                    break
                case "Diagonal":
                    rawFac = THREENodes.div(THREENodes.add(xComponent, yComponent), 2)
                    break
                case "Spherical":
                    // Falls off linearly from 1 at the origin to 0 at unit distance.
                    rawFac = THREENodes.max(THREENodes.oneMinus(THREENodes.length(vectorValue)), 0)
                    break
                case "Quadratic sphere": {
                    // Square of the spherical falloff, i.e. max(1 - length, 0)^2 (not 1 - length^2).
                    const falloff = THREENodes.max(THREENodes.oneMinus(THREENodes.length(vectorValue)), 0)
                    rawFac = THREENodes.mul(falloff, falloff)
                    break
                }
                case "Radial": {
                    const angle = THREENodes.atan2(yComponent, xComponent)
                    // Normalize the angle from [-PI, PI] to [0, 1]
                    rawFac = THREENodes.mul(THREENodes.add(angle, Math.PI), 1 / (2 * Math.PI))
                    break
                }
                default:
                    throw new Error(`Unhandled gradient type: ${type}`)
            }

            // Keep the factor inside [0, 1] so that mix() never extrapolates beyond black/white.
            const fac = THREENodes.clamp(rawFac, 0, 1)

            const black = threeRGBColorNode({r: 0.0, g: 0.0, b: 0.0})
            const white = threeRGBColorNode({r: 1.0, g: 1.0, b: 1.0})
            const color = THREENodes.mix(black, white, fac)

            return {color, fac}
        },
    },
) {}
