import * as THREENodes from "three/examples/jsm/nodes/Nodes"
import {Color, Vec2, Vec3, Vec4} from "@src/materials/types"
import * as THREE from "three"

/** Converts a plain {@link Color} (its r, g, b components) into a three.js `color` node. */
export const threeRGBColorNode = (color: Color) => {
    const {r, g, b} = color
    return THREENodes.color(new THREE.Color(r, g, b))
}

/** Wraps a scalar number in a three.js `float` node. */
export const threeValueNode = (value: number) => THREENodes.float(value)

/** Copies a plain {@link Vec2} into a THREE.Vector2 and wraps it in a three.js `vec2` node. */
export const threeVec2Node = (value: Vec2) => {
    const {x, y} = value
    return THREENodes.vec2(new THREE.Vector2(x, y))
}

/** Copies a plain {@link Vec3} into a THREE.Vector3 and wraps it in a three.js `vec3` node. */
export const threeVec3Node = (value: Vec3) => {
    const {x, y, z} = value
    return THREENodes.vec3(new THREE.Vector3(x, y, z))
}

/** Copies a plain {@link Vec4} into a THREE.Vector4 and wraps it in a three.js `vec4` node. */
export const threeVec4Node = (value: Vec4) => {
    const {x, y, z, w} = value
    return THREENodes.vec4(new THREE.Vector4(x, y, z, w))
}

/**
 * Applies `converter` to an optional value.
 *
 * @param value the value to convert; `undefined` short-circuits to `undefined`
 * @param converter mapping applied to a present (and accepted) value
 * @param checker optional validator; when given and it returns a falsy result, the value is
 *   rejected and `undefined` is returned without calling `converter`
 * @returns the converted value, or `undefined` if the value was absent or rejected
 */
export const threeConvert = <T, R>(value: T | undefined, converter: (value: T) => R, checker?: (value: T) => boolean): R | undefined => {
    if (value === undefined) {
        return undefined
    }
    // No checker means every present value is accepted.
    const accepted = checker === undefined || checker(value)
    return accepted ? converter(value) : undefined
}

/**
 * GLSL `colorBurn(inputA, inputB, f)`: per-channel colour-burn blend of `inputB` onto `inputA`,
 * weighted by the factor `f`.
 *
 * For each channel c:
 *   d = (1 - f) + f * inputB.c
 *   result.c = 0                                     if d <= 0
 *   result.c = clamp(1 - (1 - inputA.c) / d, 0, 1)   otherwise
 *
 * With f = 0 this reduces to inputA clamped to [0, 1].
 * NOTE(review): the formula appears to follow Blender's "Burn" mix mode; confirm if exact
 * parity with Blender matters.
 *
 * The GLSL is unrolled per channel and kept byte-for-byte as is; it is the runtime shader source.
 */
const colorBurnNode = new THREENodes.FunctionNode(`
vec3 colorBurn(vec3 inputA, vec3 inputB, float f) {
    vec3 resultColor;
 
    float tmp = (1.0 - f) + f * inputB.x;
    if (tmp <= 0.0) {
        resultColor.x = 0.0;
    }
    else if ((tmp = (1.0 - (1.0 - inputA.x) / tmp)) < 0.0) {
        resultColor.x = 0.0;
    }
    else if (tmp > 1.0) {
        resultColor.x = 1.0;
    }
    else {
        resultColor.x = tmp;
    }
 
    tmp = (1.0 - f) + f * inputB.y;
    if (tmp <= 0.0) {
        resultColor.y = 0.0;
    }
    else if ((tmp = (1.0 - (1.0 - inputA.y) / tmp)) < 0.0) {
        resultColor.y = 0.0;
    }
    else if (tmp > 1.0) {
        resultColor.y = 1.0;
    }
    else {
        resultColor.y = tmp;
    }
 
    tmp = (1.0 - f) + f * inputB.z;
    if (tmp <= 0.0) {
        resultColor.z = 0.0;
    }
    else if ((tmp = (1.0 - (1.0 - inputA.z) / tmp)) < 0.0) {
        resultColor.z = 0.0;
    }
    else if (tmp > 1.0) {
        resultColor.z = 1.0;
    }
    else {
        resultColor.z = tmp;
    }
 
    return resultColor;
}
`)

/**
 * Builds a call to the `colorBurn` GLSL function.
 *
 * @param inputA base colour, bound to the `vec3 inputA` parameter
 * @param inputB blend colour, bound to the `vec3 inputB` parameter
 * @param fac blend factor, bound to the `float f` parameter
 * @returns the function-call node
 */
export function threeColorBurnNode(inputA: THREENodes.Node, inputB: THREENodes.Node, fac: THREENodes.Node): THREENodes.Node {
    // Positional arguments, in the order of colorBurn's GLSL signature.
    const positionalArgs = [inputA, inputB, fac]
    return THREENodes.call(colorBurnNode, positionalArgs)
}

/**
 * GLSL `colorDodge(inputA, inputB, f)`: per-channel colour-dodge blend of `inputB` onto `inputA`,
 * weighted by the factor `f`.
 *
 * The result starts as a copy of inputA. For each channel c with inputA.c > 0:
 *   d = 1 - f * inputB.c
 *   result.c = 1                      if d <= 0
 *   result.c = min(inputA.c / d, 1)   otherwise
 * Channels where inputA.c <= 0 are passed through unchanged.
 *
 * NOTE(review): if this is meant to match Blender's "Dodge" mix mode, double-check the guard.
 * Blender's version is believed to test `!= 0.0` rather than `> 0.0`. The two differ only for
 * negative channels. Confirm intent.
 *
 * The GLSL below is the runtime shader source and is intentionally left untouched.
 */
const colorDodgeNode = new THREENodes.FunctionNode(`
vec3 colorDodge(vec3 inputA, vec3 inputB, float f) {
    vec3 resultColor;
    float tmp;
 
    resultColor = inputA;
 
    if (inputA.x > 0.0) {
        tmp = 1.0 - f * inputB.x;
        if (tmp <= 0.0) {
            resultColor.x = 1.0;
        }
        else if ((tmp = (inputA.x / tmp)) > 1.0) {
            resultColor.x = 1.0;
        }
        else {
            resultColor.x = tmp;
        }
    }
 
    if (inputA.y > 0.0) {
        tmp = 1.0 - f * inputB.y;
        if (tmp <= 0.0) {
            resultColor.y = 1.0;
        }
        else if ((tmp = (inputA.y / tmp)) > 1.0) {
            resultColor.y = 1.0;
        }
        else {
            resultColor.y = tmp;
        }
    }
 
    if (inputA.z > 0.0) {
        tmp = 1.0 - f * inputB.z;
        if (tmp <= 0.0) {
            resultColor.z = 1.0;
        }
        else if ((tmp = (inputA.z / tmp)) > 1.0) {
            resultColor.z = 1.0;
        }
        else {
            resultColor.z = tmp;
        }
    }
 
    return resultColor;
}
`)

/**
 * Builds a call to the `colorDodge` GLSL function.
 *
 * @param inputA base colour, bound to the `vec3 inputA` parameter
 * @param inputB blend colour, bound to the `vec3 inputB` parameter
 * @param fac blend factor, bound to the `float f` parameter
 * @returns the function-call node
 */
export function threeColorDodgeNode(inputA: THREENodes.Node, inputB: THREENodes.Node, fac: THREENodes.Node): THREENodes.Node {
    // Positional arguments, in the order of colorDodge's GLSL signature.
    const positionalArgs = [inputA, inputB, fac]
    return THREENodes.call(colorDodgeNode, positionalArgs)
}

/**
 * GLSL `overlay(inputA, inputB, f)`: per-channel overlay blend of `inputB` onto `inputA`,
 * weighted by the factor `f`.
 *
 * With invF = 1 - f, for each channel c:
 *   result.c = inputA.c * (invF + 2 f inputB.c)                     if inputA.c < 0.5
 *   result.c = 1 - (invF + 2 f (1 - inputB.c)) * (1 - inputA.c)     otherwise
 *
 * At f = 1 this is the usual overlay: 2AB, or 1 - 2(1-A)(1-B). At f = 0 it returns inputA.
 * No clamping is applied to the result.
 *
 * The GLSL below is the runtime shader source and is intentionally left untouched.
 */
const overlayNode = new THREENodes.FunctionNode(`
vec3 overlay(vec3 inputA, vec3 inputB, float f) {
    vec3 resultColor;
    float invF = 1.0 - f;
 
    if (inputA.x < 0.5) {
        resultColor.x = inputA.x * (invF + 2.0 * f * inputB.x);
    }
    else {
        resultColor.x = 1.0 - (invF + 2.0 * f * (1.0 - inputB.x)) * (1.0 - inputA.x);
    }
 
    if (inputA.y < 0.5) {
        resultColor.y = inputA.y * (invF + 2.0 * f * inputB.y);
    }
    else {
        resultColor.y = 1.0 - (invF + 2.0 * f * (1.0 - inputB.y)) * (1.0 - inputA.y);
    }
 
    if (inputA.z < 0.5) {
        resultColor.z = inputA.z * (invF + 2.0 * f * inputB.z);
    }
    else {
        resultColor.z = 1.0 - (invF + 2.0 * f * (1.0 - inputB.z)) * (1.0 - inputA.z);
    }
 
    return resultColor;
}
`)

/**
 * Builds a call to the `overlay` GLSL function.
 *
 * @param inputA base colour, bound to the `vec3 inputA` parameter
 * @param inputB blend colour, bound to the `vec3 inputB` parameter
 * @param fac blend factor, bound to the `float f` parameter
 * @returns the function-call node
 */
export function threeOverlayNode(inputA: THREENodes.Node, inputB: THREENodes.Node, fac: THREENodes.Node): THREENodes.Node {
    // Positional arguments, in the order of overlay's GLSL signature.
    const positionalArgs = [inputA, inputB, fac]
    return THREENodes.call(overlayNode, positionalArgs)
}

/**
 * GLSL `rgb2hsv(rgb)`: converts an RGB colour to HSV, returned as vec3(h, s, v).
 *
 * - v (z) = the largest channel
 * - s (y) = (max - min) / max, or 0 when max == 0
 * - h (x) = the sextant-based hue divided by 6, with +1 added when negative. This lands in the
 *   0..1 range for ordinary colours. It is 0 whenever s == 0 (grey input).
 *
 * The hue branch is fully braced. The normalisation (`/= 6`, wrap) and the difference vector `c`
 * only run for chromatic input, so `c` never evaluates 0.0 / 0.0 when max == min.
 */
const rgbToHsvNode = new THREENodes.FunctionNode(`
vec3 rgb2hsv(vec3 rgb) {
    float rgbmax = max(rgb.x, max(rgb.y, rgb.z));
    float rgbmin = min(rgb.x, min(rgb.y, rgb.z));
    float delta = rgbmax - rgbmin;
    vec3 hsv = vec3(0.0, 0.0, rgbmax);

    if (rgbmax != 0.0) {
        hsv.y = delta / rgbmax;
    }

    if (hsv.y != 0.0) {
        vec3 c = (vec3(rgbmax, rgbmax, rgbmax) - rgb) / delta;
        if (rgb.x == rgbmax) {
            hsv.x = c.z - c.y;
        }
        else if (rgb.y == rgbmax) {
            hsv.x = 2.0 + c.x - c.z;
        }
        else {
            hsv.x = 4.0 + c.y - c.x;
        }

        hsv.x /= 6.0;
        if (hsv.x < 0.0) {
            hsv.x += 1.0;
        }
    }

    return hsv;
}
`)

/**
 * GLSL `hsv2rgb(hsv)`: converts vec3(h, s, v) back to RGB. The hue is expected in the 0..1 range,
 * as produced by rgb2hsv; h == 1 is treated as 0.
 *
 * - s == 0 yields the grey vec3(v, v, v).
 * - Otherwise the hue picks one of six sextants via floor(h * 6).
 *
 * The s == 0 case is braced so the sextant code is skipped entirely for grey input. Running it
 * with s == 0 and a non-finite hue would make `q` and `t` NaN (0 * NaN), which can leak into the
 * result.
 * NOTE(review): hues outside [0, 1] are not wrapped. Any floor(h * 6) outside 0..4 falls into
 * the last (magenta-red) sextant.
 */
const hsvToRgbNode = new THREENodes.FunctionNode(`
vec3 hsv2rgb(vec3 hsv) {
    float h = hsv.x;
    float s = hsv.y;
    float v = hsv.z;
    vec3 rgb;

    if (s == 0.0) {
        rgb = vec3(v, v, v);
    }
    else {
        if (h == 1.0) {
            h = 0.0;
        }
        h *= 6.0;
        float i = floor(h);
        float f = h - i;
        float p = v * (1.0 - s);
        float q = v * (1.0 - s * f);
        float t = v * (1.0 - s * (1.0 - f));

        if (i == 0.0) {
            rgb = vec3(v, t, p);
        }
        else if (i == 1.0) {
            rgb = vec3(q, v, p);
        }
        else if (i == 2.0) {
            rgb = vec3(p, v, t);
        }
        else if (i == 3.0) {
            rgb = vec3(p, q, v);
        }
        else if (i == 4.0) {
            rgb = vec3(t, p, v);
        }
        else {
            rgb = vec3(v, p, q);
        }
    }

    return rgb;
}
`)

/**
 * Builds a call to the `rgb2hsv` GLSL function. The argument is bound by name to its `rgb`
 * parameter.
 *
 * @param rgb colour node to convert
 * @returns the function-call node (vec3 h, s, v)
 */
export function threeRgbToHsvNode(rgb: THREENodes.Node) {
    const namedArgs = {rgb}
    return THREENodes.call(rgbToHsvNode, namedArgs)
}

/**
 * Builds a call to the `hsv2rgb` GLSL function. The argument is bound by name to its `hsv`
 * parameter.
 *
 * @param hsv node holding vec3(h, s, v)
 * @returns the function-call node (vec3 rgb)
 */
export function threeHsvToRgbNode(hsv: THREENodes.Node) {
    const namedArgs = {hsv}
    return THREENodes.call(hsvToRgbNode, namedArgs)
}

/**
 * Presumed width, in texels, of the LUT textures sampled by `applyLut`. The per-lookup half-texel
 * offset is derived from it.
 */
export const lutSize = 256

// One GLSL texture lookup per RGBA channel. Each lookup adds half a texel (0.5 / lutSize) to the
// channel value; the stated rationale is that the LUT is sampled with THREE.NearestFilter.
// NOTE(review): for a channel value of 1.0 the coordinate becomes 1 + 0.5 / lutSize, so the
// result depends on the texture's wrap mode (clamp-to-edge keeps it on the last texel). Confirm
// against where the LUT texture is created.
const lutChannelLookups = ["r", "g", "b", "a"]
    .map((channel) => `        texture(lut, vec2(rgbaIn.${channel} + ${0.5 / lutSize}, 0.5)).${channel}`)
    .join(",\n")

/**
 * GLSL `applyLut(rgbaIn, lut, fac)`: replaces each channel of `rgbaIn` with the same channel of
 * the LUT sample at (channel value + half texel, 0.5). It then returns
 * mix(rgbaIn, remapped, fac).
 */
const applyLut = new THREENodes.FunctionNode(`
vec4 applyLut(vec4 rgbaIn, sampler2D lut, float fac) {
    vec4 rgbaOut = vec4(
${lutChannelLookups}
    );
    return mix(rgbaIn, rgbaOut, fac);
}
`)

/**
 * vec4 node that remaps `rgbaInput` through `lutTexture` using the module's `applyLut` GLSL
 * function, then blends the remapped colour with the input by `fac`.
 */
export class ApplyLUTNode extends THREENodes.TempNode {
    constructor(
        public rgbaInput: THREENodes.Node,
        public lutTexture: THREE.Texture,
        public fac: THREENodes.Node,
    ) {
        // The output of this node is always a vec4.
        super("vec4")
    }

    override generate(builder: THREENodes.NodeBuilder) {
        const outputType = this.getNodeType(builder)

        // Wrap the raw THREE.Texture in a texture node so it can be passed as applyLut's
        // `sampler2D lut` argument.
        const lut = THREENodes.convert(THREENodes.texture(this.lutTexture), "texture")

        return THREENodes
            .call(applyLut, {rgbaIn: this.rgbaInput, lut, fac: this.fac})
            .build(builder, outputType)
    }
}
