// @ts-strict-ignore
import * as THREENodes from "three/examples/jsm/nodes/Nodes"

/**
 * GLSL helper `heightDerivToNormalMap(pt, uv, heightDeriv)` that turns a height-field
 * derivative into an encoded normal-map colour.
 *
 * `heightDeriv` is multiplied by `sc`. `sc` is the ratio between the screen-space x-derivative
 * of `uv` and the screen-space x-derivative of `pt`. Presumably this converts a per-UV slope
 * into a slope per unit of `pt` (TODO confirm the units of `heightDeriv` against callers).
 * The normal (-D.x, -D.y, 1) is normalized and then remapped from [-1, 1] to [0, 1] via
 * `*.5+.5`. The return value is therefore an encoded normal-map value, not a raw direction.
 *
 * Uses dFdx, so it can only be evaluated in a fragment shader.
 * NOTE(review): `sc` divides by length(dpt_dx). It is undefined when `pt` does not change
 * across the pixel in screen x; confirm this cannot happen for the intended geometry.
 * NOTE(review): only x-derivatives are used for `sc`. This assumes the UV/position scale is
 * the same in both screen directions; confirm.
 */
const heightDerivToNormalMapFnNode = new THREENodes.FunctionNode(
    `
vec3 heightDerivToNormalMap(vec3 pt, vec2 uv, vec2 heightDeriv) {
    vec3 dpt_dx = vec3(dFdx(pt.x), dFdx(pt.y), dFdx(pt.z)); // Workaround for Adreno 3XX dFd*( vec3 ) bug. See #9988
    float sc = length(dFdx(uv)) / length(dpt_dx);
    vec2 D = heightDeriv * sc;
    return normalize(vec3(-D.x, -D.y, 1.0))*.5+.5;
}
`,
)

/**
 * Node producing an encoded (0..1) normal-map value from a height-field derivative.
 *
 * It calls the `heightDerivToNormalMap` GLSL function with the view-space position, `uv` and
 * `heightDeriv`. That function relies on dFdx, so a constant `vec3( 0.0 )` is emitted instead
 * when the node is built outside the fragment stage.
 */
export class HeightDerivativeToNormalMapNode extends THREENodes.TempNode {
    /** View-space position node passed to the shader function as `pt`. */
    private position: THREENodes.PositionNode

    /**
     * @param heightDeriv Node supplying the height derivative (GLSL parameter type `vec2`).
     * @param uv Node supplying the UV coordinates used to rescale the derivative.
     */
    constructor(
        public heightDeriv: THREENodes.Node,
        public uv: THREENodes.Node,
    ) {
        super("vec3")
        this.position = THREENodes.positionView
    }

    override generate(builder: THREENodes.NodeBuilder, output?: string | null) {
        const nodeType = this.getNodeType(builder)

        // Screen-space derivatives are unavailable outside the fragment stage: fall back to a constant.
        if (builder.getShaderStage() != "fragment") {
            console.warn("HeightDerivativeToNormalMapNode is not compatible with non-fragment shader.")
            return builder.format("vec3( 0.0 )", nodeType!, output as THREENodes.NodeTypeOption)
        }

        // Keys must match the GLSL parameter names of heightDerivToNormalMap.
        const shaderArgs = {pt: this.position, uv: this.uv, heightDeriv: this.heightDeriv}
        return THREENodes.call(heightDerivToNormalMapFnNode, shaderArgs).build(builder, nodeType)
    }
}

/**
 * GLSL helper `derivativeSampler(tex, uv)` returning an approximate gradient of the red
 * channel of `tex` at `uv`.
 *
 * It uses forward differences with a fixed UV step of 1e-3 along u and along v. The
 * differences are multiplied by 1e3, the reciprocal of that step, so the result approximates
 * (d tex.r / du, d tex.r / dv). Keep the step and the scale factor in sync if either changes.
 * NOTE(review): the step is in UV units and does not depend on texture resolution. For
 * low-resolution textures the three samples may land in the same texel or be smoothed by
 * filtering; confirm 1e-3 suits the textures used.
 */
const derivativeSamplerNode = new THREENodes.FunctionNode(`
vec2 derivativeSampler(sampler2D tex, vec2 uv) {
    vec2 p0 = uv;
    vec2 p1 = uv + vec2(1e-3, 0.0);
    vec2 p2 = uv + vec2(0.0, 1e-3);
    float d0 = texture2D(tex, p0).r;
    float d1 = texture2D(tex, p1).r;
    float d2 = texture2D(tex, p2).r;
    return vec2(d1-d0, d2-d0) * 1e3;
}
`)

/**
 * Node returning the approximate UV-space gradient of a texture's red channel, via the
 * `derivativeSampler` GLSL function.
 *
 * Only supported in the fragment stage. In any other stage it logs a warning and emits a
 * zero vector of the node's own type (`vec2`).
 */
export class DerivativeSamplerNode extends THREENodes.TempNode {
    /**
     * @param tex Texture whose red channel is differentiated.
     * @param uv Node supplying the UV coordinates at which the gradient is evaluated.
     */
    constructor(
        public tex: THREENodes.TextureNode,
        public uv: THREENodes.Node,
    ) {
        super("vec2")
    }

    override generate(builder: THREENodes.NodeBuilder, output?: string | null) {
        const type = this.getNodeType(builder)
        const stage = builder.getShaderStage()

        if (stage == "fragment") {
            return THREENodes.call(derivativeSamplerNode, {tex: THREENodes.convert(this.tex, "texture"), uv: this.uv}).build(builder, type)
        }

        // Report the actual stage: NodeBuilder exposes it via getShaderStage(); the old
        // `builder.shader` lookup is not set by this code and logged "undefined".
        console.warn("DerivativeSampler is not compatible with " + stage + " shader.")
        // The fallback literal must match the node type (vec2); emitting a vec3 literal while
        // declaring it vec2 produced an un-converted vec3 where a vec2 was expected.
        return builder.format("vec2( 0.0 )", type!, output as THREENodes.NodeTypeOption)
    }
}

/**
 * GLSL helper `viewPosToDirection(eye_pos, surf_norm)` that expresses `eye_pos` in a
 * per-pixel tangent frame (S, T, N) and returns it normalized. The caller in this file passes
 * a view-space position.
 *
 * S and T are built from the screen-space derivatives of `eye_pos` (q0, q1) and of the `vUv`
 * varying (st0, st1). N is the normalized `surf_norm`. `eye_pos * tsn` multiplies by the
 * matrix from the left, so the components are dot(eye_pos, S), dot(eye_pos, T) and
 * dot(eye_pos, N).
 *
 * Uses dFdx/dFdy, so it can only be evaluated in a fragment shader.
 * NOTE(review): the function reads the global `vUv` instead of taking a UV parameter. It
 * presumably only compiles when the generated shader declares a `vUv` varying; confirm.
 * NOTE(review): the 1e-9 offsets on st0/st1 look intended to avoid exactly-zero UV
 * derivatives; confirm.
 * NOTE(review): the handedness correction (`scale`) is commented out, so mirrored UVs may
 * yield a flipped frame.
 */
const viewPositionToUVSpaceViewDirectionNode = new THREENodes.FunctionNode(`
vec3 viewPosToDirection(vec3 eye_pos, vec3 surf_norm) {
    // Workaround for Adreno 3XX dFd*( vec3 ) bug. See #9988

    vec3 q0 = vec3( dFdx( eye_pos.x ), dFdx( eye_pos.y ), dFdx( eye_pos.z ) );
    vec3 q1 = vec3( dFdy( eye_pos.x ), dFdy( eye_pos.y ), dFdy( eye_pos.z ) );
    vec2 st0 = dFdx( vUv.st ) + vec2(1e-9,0.);
    vec2 st1 = dFdy( vUv.st ) + vec2(0.,1e-9);

    // float scale = sign( st1.t * st0.s - st0.t * st1.s ); // we do not care about the magnitude

    vec3 S = normalize( ( q0 * st1.t - q1 * st0.t ) ); // * scale );
    vec3 T = normalize( ( - q0 * st1.s + q1 * st0.s ) ); //* scale );
    vec3 N = normalize( surf_norm );

    mat3 tsn = mat3( S, T, N );
    return normalize( eye_pos * tsn );
}
`)

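// Alternative (currently unused) tangent-frame construction for viewPosToDirection, based on
// cross products with the surface normal; q0/q1/st0/st1 refer to the locals of that function: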
// vec3 N = normalize(surf_norm);
// vec3 dp2perp = cross( q1, N );
// vec3 dp1perp = cross( N, q0 );
// vec3 T = dp2perp * st0.x + dp1perp * st1.x;
// vec3 B = dp2perp * st0.y + dp1perp * st1.y;
// float invmax = inversesqrt( max( dot(T,T), dot(B,B) ) );
// mat3 tsn = mat3( T * invmax, B * invmax, N );
// return normalize( eye_pos * tsn );

/**
 * Upper bound on parallax ray-march iterations. It is interpolated into the GLSL source of
 * `parallax_uv` as the constant loop limit, so changing it requires regenerating the shader.
 */
const MAX_LAYERS = 16

// see https://apoorvaj.io/exploring-bump-mapping-with-webgl/
/**
 * GLSL helper implementing parallax occlusion mapping.
 * Signature: `parallax_uv(tex_height, uv, view_dir, uv_scale, height_scale, num_layers)`.
 *
 * The march starts at layer height 1.0 and lowers the layer by 1/num_layers per iteration.
 * Each iteration advances the UV by view_dir.xy * height_scale / (view_dir.z * num_layers).
 * The height is read from the red channel of `tex_height` at `cur_uv * uv_scale`. `uv_scale`
 * affects only the lookups; the returned UV is unscaled. The march stops at the first sample
 * whose height exceeds the current layer height, or after MAX_LAYERS iterations. The result
 * is then linearly interpolated between the last two UVs.
 *
 * NOTE(review): if num_layers > MAX_LAYERS the loop ends before the layer height reaches 0,
 * which truncates the march.
 * NOTE(review): divides by view_dir.z and by (next - prev). Both can be zero (grazing view,
 * or equal height differences); confirm callers avoid these cases.
 */
const parallaxFunctionNode = new THREENodes.FunctionNode(`
vec2 parallax_uv(sampler2D tex_height, vec2 uv, vec3 view_dir, vec2 uv_scale, float height_scale, float num_layers)
{
    float layer_height = 1.0 / num_layers;
    float cur_layer_height = 1.0;
    vec2 delta_uv = view_dir.xy * height_scale / (view_dir.z * num_layers);
    vec2 cur_uv = uv;
    vec2 prev_uv;

    float height_from_tex = texture2D(tex_height, cur_uv * uv_scale).r;
    float prev_height_from_tex;

    for (int i = 0; i < ${MAX_LAYERS}; i++) {
        cur_layer_height -= layer_height;
        prev_uv = cur_uv;
        cur_uv += delta_uv;
        prev_height_from_tex = height_from_tex;
        height_from_tex = texture2D(tex_height, cur_uv * uv_scale).r;
        if (height_from_tex > cur_layer_height) {
            break;
        }
    }

    // Parallax occlusion mapping
    float next = (height_from_tex - cur_layer_height);
    float prev = (prev_height_from_tex - cur_layer_height - layer_height);
    float weight = next / (next - prev);
    return mix(cur_uv, prev_uv, weight);
}
`)

/**
 * Node that displaces `uv` using parallax occlusion mapping (the `parallax_uv` GLSL function).
 * The view direction is computed per pixel by `viewPosToDirection`.
 */
export class UVDisplacementNode extends THREENodes.TempNode {
    /**
     * @param displacementTexture Height map; its red channel is sampled during the ray march.
     * @param uv Base UV coordinates to displace. The returned UV is in the same (unscaled) space.
     * @param mapWidth Texture lookups use u / mapWidth (uv_scale.x = 1 / mapWidth).
     * @param mapHeight Texture lookups use v / mapHeight (uv_scale.y = 1 / mapHeight).
     * @param depthScale Passed to the shader as `height_scale`; scales the per-layer UV offset.
     * @param numLayers Number of ray-march layers. Clamped to [1, MAX_LAYERS] when the shader
     *   is generated, because the GLSL loop never runs more than MAX_LAYERS iterations.
     */
    constructor(
        public displacementTexture: THREENodes.TextureNode,
        public uv: THREENodes.UVNode,
        public mapWidth: number,
        public mapHeight: number,
        public depthScale: number,
        public numLayers: number = MAX_LAYERS,
    ) {
        super("vec2")
    }

    override generate(builder: THREENodes.NodeBuilder) {
        const type = this.getNodeType(builder)

        // NOTE(review): normalLocal is an object-space normal, while positionView is view-space.
        // viewPosToDirection combines both in one tangent frame. normalView may be what was
        // intended; confirm before changing, because it alters the rendered result.
        const normal = THREENodes.normalLocal
        const viewPosition = THREENodes.positionView

        // Uniform values are read from the current properties each time the shader is generated.
        const uvScaleX = THREENodes.uniform(1.0 / this.mapWidth)
        const uvScaleY = THREENodes.uniform(1.0 / this.mapHeight)
        const depthScale = THREENodes.uniform(this.depthScale)
        // The shader computes layer_height = 1 / num_layers, but its loop is capped at MAX_LAYERS.
        // A larger count would stop the march before it reaches the bottom of the height field.
        // A count <= 0 would divide by zero. Clamp to the range the shader can actually honour.
        const clampedLayers = Math.min(Math.max(this.numLayers, 1), MAX_LAYERS)
        const numLayers = THREENodes.uniform(clampedLayers)

        return THREENodes.call(parallaxFunctionNode, {
            tex_height: THREENodes.convert(this.displacementTexture, "texture"),
            uv: this.uv,
            view_dir: THREENodes.call(viewPositionToUVSpaceViewDirectionNode, {eye_pos: viewPosition, surf_norm: normal}),
            uv_scale: new THREENodes.JoinNode([uvScaleX, uvScaleY]),
            height_scale: depthScale,
            num_layers: numLayers,
        }).build(builder, type)
    }
}
