import {DeclareMaterialNode, DeclareMaterialNodeType, materialSlots} from "#material-nodes/declare-material-node"
import {threeValueNode} from "#material-nodes/three-utils"
import {getAll} from "@cm/graph/utils"
import * as THREENodes from "three/examples/jsm/nodes/Nodes.js"
import {z} from "zod"

// GLSL helper code shared with `heightDerivToNormal` below.
// `derivativeSampler` estimates the gradient of the texture's red channel at `uv` using central differences:
// it samples at uv ± (dx, 0) and uv ± (0, dy) (offsets in UV units) and returns
// ((h(+dx) - h(-dx)) / (2*dx), (h(+dy) - h(-dy)) / (2*dy)) multiplied by `scale`.
// NOTE(review): a dx or dy of 0 divides by zero inside the shader; callers should pass non-zero steps.
const heightToNormalCommon = new THREENodes.CodeNode(`
vec2 derivativeSampler(sampler2D tex, float dx, float dy, vec2 uv, float scale) {
    vec2 px2 = uv + vec2(dx, 0.0);
    vec2 px1 = uv - vec2(dx, 0.0);

    vec2 py2 = uv + vec2(0.0, dy);
    vec2 py1 = uv - vec2(0.0, dy);

    float dx2 = texture2D(tex, px2).r;
    float dx1 = texture2D(tex, px1).r;

    float dy2 = texture2D(tex, py2).r;
    float dy1 = texture2D(tex, py1).r;

    return vec2((dx2-dx1)/(2.0*dx), (dy2-dy1)/(2.0*dy)) * scale;
}
`)

// GLSL function turning the sampled height gradient into an encoded normal.
// - `sc` = |dFdx(uv)| / |dFdx(pt)|: the screen-space change in UV divided by the change in `pt` (the caller passes the
//   view-space position). It rescales the per-UV-unit gradient returned by `derivativeSampler`.
//   NOTE(review): if the position derivative is zero this divides by zero; confirm that cannot happen for rendered fragments.
// - The result normalize(vec3(-D.x, -D.y, 1.0)) is remapped from [-1, 1] to [0, 1] via `* .5 + .5`, i.e. encoded like a
//   normal-map texel (presumably consumed as a tangent-space normal map; confirm with the consumer).
// `heightToNormalCommon` is passed as an include so `derivativeSampler` is emitted alongside this function.
const heightToNormal = new THREENodes.FunctionNode(
    `
vec3 heightDerivToNormal(vec3 pt, vec2 uv, sampler2D tex, float dx, float dy, float scale) {
    vec3 dpt_dx = vec3(dFdx(pt.x), dFdx(pt.y), dFdx(pt.z)); // Workaround for Adreno 3XX dFd*( vec3 ) bug. See #9988
    float sc = length(dFdx(uv)) / length(dpt_dx);
    vec2 D = derivativeSampler(tex, dx, dy, uv, scale) * sc;
    return normalize(vec3(-D.x, -D.y, 1.0))*.5+.5;
}
`,
    [heightToNormalCommon],
)

/**
 * Temp node that derives a normal from a height texture by finite differences.
 *
 * The generated expression calls the `heightDerivToNormal` GLSL function. That function samples the red channel of the
 * texture around the current UV and returns `normalize(vec3(-D.x, -D.y, 1.0)) * 0.5 + 0.5`, a vec3 in [0, 1].
 */
class HeightToNormalNode extends THREENodes.TempNode {
    /**
     * @param color - Texture node holding the height map. Its `uvNode`, when set, provides the sampling coordinates;
     *   otherwise UV channel 0 is used.
     * @param scale - Node multiplying the sampled height gradient (the normal strength).
     * @param stepSize - Finite-difference offset in texels. It is divided by the texture width/height to get the UV offset.
     *   NOTE(review): a value of 0 makes the shader divide by zero; callers should pass a non-zero step.
     */
    constructor(
        public color: THREENodes.TextureNode,
        public scale: THREENodes.Node,
        public stepSize: number,
    ) {
        super("vec3")
    }

    /**
     * Builds the shader code for this node.
     *
     * @throws Error if the texture's image does not expose a positive, finite `width` and `height`. Both are required to
     *   convert `stepSize` from texels to UV units.
     */
    override generate(builder: THREENodes.NodeBuilder) {
        const type = this.getNodeType(builder)

        const {value: texture} = this.color

        const width = texture.image?.width
        const height = texture.image?.height

        // The step is divided by these sizes. A size of 0, NaN or Infinity would bake a non-finite constant into the
        // shader, so reject it here, not only missing/non-numeric sizes.
        const isUsableSize = (size: unknown): size is number => typeof size === "number" && Number.isFinite(size) && size > 0
        if (!isUsableSize(width) || !isUsableSize(height)) throw new Error("Height to normal node requires texture with defined width and height")

        // Wrap the same texture object in a new texture node, passed to the GLSL function's `sampler2D tex` parameter.
        const colorTextureNode = THREENodes.convert(THREENodes.texture(texture), "texture")

        return THREENodes.call(heightToNormal, {
            pt: THREENodes.positionView,
            uv: this.color.uvNode ?? new THREENodes.UVNode(0),
            tex: colorTextureNode,
            // Convert the texel step into UV units along each axis.
            dx: THREENodes.float(this.stepSize / width),
            dy: THREENodes.float(this.stepSize / height),
            scale: this.scale,
        }).build(builder, type)
    }
}

// Output slot: `color` carries the generated normal node. `toThree` leaves it undefined when the input is not a texture node.
const ReturnTypeSchema = z.object({color: materialSlots.optional()})
// Input slot: `color` is the height source. `toThree` requires it to resolve to a THREE `TextureNode`.
const InputTypeSchema = z.object({color: materialSlots.optional()})
// Parameters (all optional), as used by `toThree`:
// - displacementMax / displacementMin: when both are set, their difference is the default normal strength (else 1).
// - displacementNormalStrength: explicit normal strength; takes precedence over the displacement range.
// - displacementNormalSmoothness: finite-difference step in texels (defaults to 1).
const ParametersTypeSchema = z.object({
    displacementMax: z.number().optional(),
    displacementMin: z.number().optional(),
    displacementNormalStrength: z.number().optional(),
    displacementNormalSmoothness: z.number().optional(),
})
/**
 * Material node that converts a height (displacement) texture into a normal map.
 *
 * - Input `color` must resolve to a THREE `TextureNode`. Otherwise a warning is logged and the output `color` is undefined.
 * - Normal strength is `displacementNormalStrength` when given. Otherwise it is `displacementMax - displacementMin` when
 *   both bounds are given, and `1` when they are not.
 * - `displacementNormalSmoothness` is the finite-difference step in texels (default `1`).
 */
export class HeightToNormal extends (DeclareMaterialNode(
    {
        returns: ReturnTypeSchema,
        inputs: InputTypeSchema,
        parameters: ParametersTypeSchema,
    },
    {
        toThree: async ({get, inputs, parameters}) => {
            const resolvedInputs = await getAll(inputs, get)
            const heightSource = resolvedInputs.color
            const {
                displacementMax: rangeMax,
                displacementMin: rangeMin,
                displacementNormalStrength: explicitStrength,
                displacementNormalSmoothness: smoothness,
            } = parameters

            // Without both bounds there is no displacement range, so fall back to a neutral strength of 1.
            const rangeStrength = rangeMax === undefined || rangeMin === undefined ? 1 : rangeMax - rangeMin
            const strengthNode = threeValueNode(explicitStrength ?? rangeStrength)

            // Finite differences need to resample neighbouring texels, which only a texture node allows.
            if (heightSource instanceof THREENodes.TextureNode) {
                const stepSize = smoothness ?? 1
                return {color: new HeightToNormalNode(heightSource, strengthNode, stepSize)}
            }

            console.warn("Height to normal node cannot use non-texture node as input")
            return {color: undefined}
        },
    },
) as DeclareMaterialNodeType<typeof ReturnTypeSchema, typeof InputTypeSchema, typeof ParametersTypeSchema>) {}
