import {DeclareMaterialNode, materialSlots} from "#material-nodes/declare-material-node"
import {z} from "zod"
import {ApplyLUTNode, lutSize, threeConvert, threeRGBColorNode, threeValueNode} from "#material-nodes/three-utils"
import {color, vec2, Vec2} from "#material-nodes/types"
import {CubicBezierSpline, Knot, InterpolationMode, OutOfBoundsMode, SampledFunction} from "@cm/math"
import * as THREE from "three"
import {getAll} from "@cm/graph"
import * as THREENodes from "three/examples/jsm/nodes/Nodes.js"

/**
 * Converts one curve's control points into a sampled, piecewise-linear function.
 *
 * Each control point becomes a knot of a cubic Bézier spline. The spline is sampled with `evaluatePoints(256)`.
 * The resulting samples are wrapped in a `SampledFunction` that interpolates linearly between them and
 * extrapolates outside their range.
 *
 * @param controlPoints the curve's control points, in the order they are passed to the spline
 * @returns the sampled representation of the curve
 */
function computeRgbCurveSampledFunction(controlPoints: Vec2[]): SampledFunction {
    const sampleCount = 256
    const spline = new CubicBezierSpline(controlPoints.map((controlPoint) => new Knot(controlPoint)))
    return new SampledFunction(spline.evaluatePoints(sampleCount), InterpolationMode.Linear, OutOfBoundsMode.Extrapolate)
}

/**
 * Builds a per-channel lookup table for an RGB curves node.
 *
 * `controlPoints[0..2]` are the red, green and blue curves. `controlPoints[3]` is applied to every channel first,
 * and the channel's own curve is then applied to that result: `out_c = curve_c(curve_3(input))`.
 *
 * @param lutSize number of LUT entries. Entry `i` corresponds to the input `i / (lutSize - 1)`, so inputs span
 *   0..1 inclusive. A single-entry LUT samples input 0 rather than dividing by zero.
 * @param controlPoints control points for the R, G, B curves and the curve applied to all channels
 * @returns a `[lutSize][3]` array holding the mapped R, G and B output for each sampled input
 */
function computeRgbCurveLUT(lutSize: number, controlPoints: [Vec2[], Vec2[], Vec2[], Vec2[]]): number[][] {
    const [redPoints, greenPoints, bluePoints, combinedPoints] = controlPoints
    const channelCurves = [redPoints, greenPoints, bluePoints].map((points) => computeRgbCurveSampledFunction(points))
    const combinedCurve = computeRgbCurveSampledFunction(combinedPoints)

    const lut: number[][] = []
    for (let i = 0; i < lutSize; i++) {
        // Guard the lutSize === 1 case, where i / (lutSize - 1) would be 0 / 0 = NaN.
        const input = lutSize > 1 ? i / (lutSize - 1) : 0
        // Both curves use "extrapolate", so evaluate() can not return undefined here.
        // The shared curve is evaluated once per entry and reused for all three channels.
        const combinedOutput = combinedCurve.evaluate(input)!
        lut.push(channelCurves.map((curve) => curve.evaluate(combinedOutput)!))
    }
    return lut
}

// Output of the node: the curve-adjusted color.
const ReturnTypeSchema = z.object({color: materialSlots})
// Linkable inputs; when unlinked, the node falls back to the matching entries in ParametersTypeSchema.
const InputTypeSchema = z.object({
    color: materialSlots.optional(),
    fac: materialSlots.optional(),
})
// Control points of a single curve. controlPoints holds four of them:
// indices 0-2 are the R, G and B curves; index 3 is applied to every channel before those.
const CurveControlPointsSchema = z.array(vec2)
const ParametersTypeSchema = z.object({
    color: color.optional(),
    fac: z.number().optional(),
    cyclesMappingTable: z.array(color).optional(),
    controlPoints: z
        .tuple([CurveControlPointsSchema, CurveControlPointsSchema, CurveControlPointsSchema, CurveControlPointsSchema])
        .optional(),
})
/**
 * Uploads an RGB lookup table as a `lut.length × 1` float RGBA texture.
 *
 * @param lut one `[r, g, b]` row per entry; alpha is fixed to 1
 * @returns a linear-color-space DataTexture flagged for upload
 */
function createRgbCurveLutTexture(lut: number[][]): THREE.DataTexture {
    const rgbaData = new Float32Array(lut.flatMap((rgb) => [...rgb, 1]))
    const texture = new THREE.DataTexture(rgbaData, lut.length, 1, THREE.RGBAFormat, THREE.FloatType)
    // Sample the table without texture filtering, and clamp rather than wrap at the edges.
    texture.minFilter = THREE.NearestFilter
    texture.magFilter = THREE.NearestFilter
    texture.wrapS = THREE.ClampToEdgeWrapping
    texture.wrapT = THREE.ClampToEdgeWrapping
    texture.anisotropy = 1
    texture.colorSpace = THREE.LinearSRGBColorSpace
    texture.needsUpdate = true
    return texture
}

/**
 * RGB Curves material node. It remaps the input color channel-wise through the curves in `parameters.controlPoints`,
 * blended by `fac`. Without control points, the input color is passed through unchanged.
 */
export class RGBCurve extends DeclareMaterialNode(
    {
        returns: ReturnTypeSchema,
        inputs: InputTypeSchema,
        parameters: ParametersTypeSchema,
    },
    {
        toThree: async ({get, inputs, parameters, context}) => {
            const resolvedInputs = await getAll(inputs, get)
            // Priority: linked input, then the node parameter, then a hard-coded default.
            const inputColor =
                resolvedInputs.color ?? threeConvert(parameters.color, threeRGBColorNode) ?? threeRGBColorNode({r: 0, g: 0, b: 0})
            const inputFac = resolvedInputs.fac ?? threeConvert(parameters.fac, threeValueNode) ?? threeValueNode(1)

            const curvePoints = parameters.controlPoints
            if (curvePoints == null) {
                return {color: inputColor}
            }

            const lutTexture = createRgbCurveLutTexture(computeRgbCurveLUT(lutSize, curvePoints))
            context.onThreeCreatedTexture?.(lutTexture)

            const curvedColor = new ApplyLUTNode(THREENodes.vec4(inputColor), lutTexture, inputFac)
            return {color: THREENodes.color(curvedColor)}
        },
    },
) {}
