import * as THREENodes from "three/examples/jsm/nodes/Nodes"

// GLSL function converting an RGB colour to HSV packed as vec3(h, s, v).
// h is wrapped into [0, 1) by adding 1.0 when negative; s = (max - min) / max
// (0.0 when max == 0.0); v = max component.
// The function header must stay the first thing in the string: the node
// function parser reads the signature from the start of the source.
// Explicit braces are required so that the hue computation (including the
// division by `delta`) only runs for non-grey colours; without them the
// `/ delta` was evaluated for greys, where delta == 0.0.
const rgb2hsvNode = new THREENodes.FunctionNode(`
vec3 rgb2hsv(vec3 rgb) {
    float rgbmax = max(rgb.x, max(rgb.y, rgb.z));
    float rgbmin = min(rgb.x, min(rgb.y, rgb.z));
    float delta = rgbmax - rgbmin;

    float h = 0.0;
    // Saturation is 0 for black (avoids dividing by rgbmax == 0.0).
    float s = (rgbmax == 0.0) ? 0.0 : delta / rgbmax;

    // Greys (s == 0.0) keep hue 0.0; only chromatic colours divide by delta.
    if (s != 0.0) {
        vec3 c = (vec3(rgbmax, rgbmax, rgbmax) - rgb) / delta;
        if (rgb.x == rgbmax) {
            h = c.z - c.y;
        } else if (rgb.y == rgbmax) {
            h = 2.0 + c.x - c.z;
        } else {
            h = 4.0 + c.y - c.x;
        }

        h /= 6.0;
        if (h < 0.0) {
            h += 1.0;
        }
    }

    return vec3(h, s, rgbmax);
}
`)

// GLSL function converting an HSV colour packed as vec3(h, s, v) back to RGB.
// A hue of exactly 1.0 is treated as 0.0 before being scaled into six sectors.
// The function header must stay the first thing in the string: the node
// function parser reads the signature from the start of the source.
// Explicit braces restore the intended structure. Previously the `else` only
// covered `if (h == 1.0) h = 0.0;`, so the sector code always ran and the
// grey (s == 0.0) branch was dead.
const hsv2rgbNode = new THREENodes.FunctionNode(`
vec3 hsv2rgb(vec3 hsv) {
    vec3 rgb;
    float h = hsv.x;
    float s = hsv.y;
    float v = hsv.z;

    if (s == 0.0) {
        // No saturation: every channel equals the value.
        rgb = vec3(v, v, v);
    } else {
        // Map h == 1.0 onto sector 0 so floor(h * 6.0) stays in 0..5.
        if (h == 1.0) {
            h = 0.0;
        }
        h *= 6.0;
        float i = floor(h);
        float f = h - i;
        float p = v * (1.0 - s);
        float q = v * (1.0 - s * f);
        float t = v * (1.0 - s * (1.0 - f));

        if (i == 0.0) {
            rgb = vec3(v, t, p);
        } else if (i == 1.0) {
            rgb = vec3(q, v, p);
        } else if (i == 2.0) {
            rgb = vec3(p, v, t);
        } else if (i == 3.0) {
            rgb = vec3(p, q, v);
        } else if (i == 4.0) {
            rgb = vec3(t, p, v);
        } else {
            rgb = vec3(v, p, q);
        }
    }

    return rgb;
}
`)

/**
 * Wraps an RGB node in a call to the module's `rgb2hsv` GLSL function.
 *
 * @param rgb - Node supplying the RGB colour; bound to the GLSL `rgb` argument.
 * @returns A call node whose output is the vec3(h, s, v) result.
 */
export function rgb2hsv(rgb: THREENodes.Node) {
    const callArgs = { rgb: rgb }
    return THREENodes.call(rgb2hsvNode, callArgs)
}

/**
 * Wraps an HSV node in a call to the module's `hsv2rgb` GLSL function.
 *
 * @param hsv - Node supplying vec3(h, s, v); bound to the GLSL `hsv` argument.
 * @returns A call node whose output is the RGB colour.
 */
export function hsv2rgb(hsv: THREENodes.Node) {
    const callArgs = { hsv: hsv }
    return THREENodes.call(hsv2rgbNode, callArgs)
}

// GLSL helper that applies hue/saturation/value adjustments to an HSV colour.
// - hue: shifted by h + 0.5 and wrapped with mod 1.0, so h == 0.5 is a no-op.
// - saturation: multiplied by s, then limited to [0, 1].
// - value: multiplied by v (not clamped here).
// The header must remain first in the string for the node function parser.
const adjustFnNode = new THREENodes.FunctionNode(`
vec3 adjust(vec3 hsvIn, float h, float s, float v) {
    float hueOut = mod(hsvIn.x + h + 0.5, 1.0);
    float satOut = max(0.0, min(1.0, hsvIn.y * s));
    float valOut = hsvIn.z * v;
    return vec3(hueOut, satOut, valOut);
}
`)

// GLSL helper that linearly blends the original and adjusted colours by `f`.
// f == 1.0 gives rgbAdj and f == 0.0 gives rgbIn. It then clamps each channel
// to be non-negative.
// The header must remain first in the string for the node function parser.
const facAndClampFnNode = new THREENodes.FunctionNode(`
vec3 facAndClamp(vec3 rgbIn, vec3 rgbAdj, float f) {
    vec3 blended = f * rgbAdj + (1.0 - f) * rgbIn;
    return max(blended, vec3(0.0, 0.0, 0.0));
}
`)

/**
 * Shader node that adjusts the hue, saturation and value of an RGB colour.
 *
 * The input is converted to HSV and adjusted by the `adjust` GLSL helper
 * (hue shifted by `hue + 0.5` and wrapped, saturation scaled and limited to
 * [0, 1], value scaled). It is then converted back to RGB and blended with
 * the original colour by `fac` (1 = fully adjusted, 0 = original). Negative
 * channels are clamped to 0. The node's output type is `vec3`.
 */
export class HSVNode extends THREENodes.TempNode {
    constructor(
        public rgbInput: THREENodes.Node,
        public hue: THREENodes.Node,
        public saturation: THREENodes.Node,
        public value: THREENodes.Node,
        public fac: THREENodes.Node,
    ) {
        super("vec3")
    }

    /**
     * Assembles the convert → adjust → convert back → blend sub-graph and
     * builds it with this node's output type.
     */
    override generate(builder: THREENodes.NodeBuilder) {
        const outputType = this.getNodeType(builder)

        const adjustArgs = {
            hsvIn: rgb2hsv(this.rgbInput),
            h: this.hue,
            s: this.saturation,
            v: this.value,
        }
        const adjustedHsv = THREENodes.call(adjustFnNode, adjustArgs)

        const blendArgs = {
            rgbIn: this.rgbInput,
            rgbAdj: hsv2rgb(adjustedHsv),
            f: this.fac,
        }
        const finalRgb = THREENodes.call(facAndClampFnNode, blendArgs)

        return finalRgb.build(builder, outputType)
    }
}
