import {DeclareMaterialNode, MaterialInputParameter, MaterialSlot, materialSlots, ThreeNode} from "#material-nodes/declare-material-node"
import {ShaderNode} from "#material-nodes/interfaces/shader-node"
import {ImageResourceSchema, ResolvedResource} from "#material-nodes/material-node-graph"
import {CombineRGB} from "#material-nodes/nodes/combine-rgb"
import {Displacement} from "#material-nodes/nodes/displacement"
import {Mapping} from "#material-nodes/nodes/mapping"
import {Math as MathNode} from "#material-nodes/nodes/math"
import {NormalMap} from "#material-nodes/nodes/normal-map"
import {RGB} from "#material-nodes/nodes/rgb"
import {RGBCurve} from "#material-nodes/nodes/rgb-curve"
import {RGBToBW} from "#material-nodes/nodes/rgb-to-bw"
import {SeparateHSV} from "#material-nodes/nodes/separate-hsv"
import {SeparateRGB} from "#material-nodes/nodes/separate-rgb"
import {SeparateXYZ} from "#material-nodes/nodes/separate-xyz"
import {TexCoord} from "#material-nodes/nodes/tex-coord"
import {TexImage} from "#material-nodes/nodes/tex-image"
import {Value} from "#material-nodes/nodes/value"
import {VectorMath} from "#material-nodes/nodes/vector-math"
import {Context, TextureType} from "#material-nodes/types"
import {CachedNodeGraphResult} from "@cm/graph/evaluators/cached-node-graph-result"
import {GetGraphParamTypes} from "@cm/graph/node-graph"
import {getProperty} from "@cm/graph/utils"
import {mapFields, promiseAllProperties} from "@cm/utils"
import {z} from "zod"

/** Material sub-graph for each output slot of `TextureSet`, keyed like its `returns` schema. */
type TextureSetGraph = Record<
    | "baseColor"
    | "metallic"
    | "specular"
    | "roughness"
    | "anisotropic"
    | "anisotropicRotation"
    | "alpha"
    | "normal"
    | "displacement"
    | "transmission",
    MaterialInputParameter<MaterialSlot>
>

/**
 * Material node exposing a texture set (a list of image resources tagged with a `TextureType`)
 * as the individual material slots declared in `returns`.
 *
 * The `uv` input overrides the default texture coordinates (the `uv` output of a TexCoord node).
 * The `normalStrength` input overrides the `normalStrength` parameter.
 */
export class TextureSet extends DeclareMaterialNode(
    {
        returns: z.object({
            baseColor: materialSlots,
            metallic: materialSlots,
            specular: materialSlots,
            roughness: materialSlots,
            anisotropic: materialSlots,
            anisotropicRotation: materialSlots,
            alpha: materialSlots,
            normal: materialSlots,
            displacement: materialSlots,
            transmission: materialSlots,
        }),
        inputs: z.object({uv: materialSlots.optional(), normalStrength: materialSlots.optional()}),
        parameters: z.object({
            // NOTE(review): the mapAssignment*ImageResourceSlot parameters are not read by this class;
            // maps are selected by `imageResources[].metadata.textureType` instead.
            mapAssignmentAnisotropyImageResourceSlot: z.number().optional(),
            mapAssignmentDiffuseImageResourceSlot: z.number().optional(),
            mapAssignmentDisplacementImageResourceSlot: z.number().optional(),
            mapAssignmentMetalnessImageResourceSlot: z.number().optional(),
            mapAssignmentNormalImageResourceSlot: z.number().optional(),
            mapAssignmentRoughnessImageResourceSlot: z.number().optional(),
            mapAssignmentSpecularStrengthImageResourceSlot: z.number().optional(),
            normalStrength: z.number().optional(),
            // Rotation of the normal map's tangent-space XY, in degrees.
            normalRotation: z.number().optional(),
            // Transmission intensity range that is mapped onto alpha 1..0.
            transmissionMin: z.number().optional(),
            transmissionMax: z.number().optional(),
            textureSetRevisionId: z.string().optional(),
            imageResources: z.array(ImageResourceSchema).optional(),
        }),
    },
    {
        // Each slot of the graph is compiled separately; the results are awaited together.
        toThree: async function (
            this: {
                buildTextureSetGraph: () => TextureSetGraph
            },
            {context},
        ) {
            return promiseAllProperties(mapFields(this.buildTextureSetGraph(), (value) => compileNode<ThreeNode>(value, context)))
        },

        toCycles: async function (
            this: {
                buildTextureSetGraph: () => TextureSetGraph
            },
            {context},
        ) {
            return promiseAllProperties(mapFields(this.buildTextureSetGraph(), (value) => compileNode<ShaderNode>(value, context)))
        },
    },
) {
    /**
     * Builds the material sub-graph for every output slot.
     *
     * The texture set counts as unconfigured when `textureSetRevisionId` is missing, or when the
     * first image resource lacks `widthCm`/`heightCm` metadata. In that case every slot gets a
     * constant default.
     *
     * Otherwise each slot uses, in order of preference:
     * 1. the image resource tagged with the matching `TextureType`;
     * 2. a legacy map (F0, AnisotropyStrength, AnisotropyRotation) where one applies;
     * 3. the default.
     */
    buildTextureSetGraph(): TextureSetGraph {
        const {parameters, ...inputs} = this.parameters
        const {textureSetRevisionId, normalStrength} = parameters
        // Physical dimensions are read from the first image resource's metadata only.
        const {widthCm, heightCm, displacementCm} = parameters.imageResources?.[0]?.metadata ?? {}
        // The node input wins over the scalar parameter; falls back to 1 when neither is set.
        // Shared by the normal-map strength and the displacement scale below.
        const inNormalStrength = inputs.normalStrength ?? createConstantValue(normalStrength ?? 1)

        const createDefaultUV = () => getProperty(new TexCoord({parameters: {}}), "uv")
        const createDefaultBaseColor = () => createConstantColor([0.5, 0.5, 0.5])
        const createDefaultMetallic = () => createConstantValue(0)
        const createDefaultSpecular = () => createConstantValue(0)
        const createDefaultRoughness = () => createConstantValue(1)
        const createDefaultAnisotropic = () => createConstantValue(0)
        const createDefaultAnisotropicRotation = () => createConstantValue(0)
        const createDefaultAlpha = () => createConstantValue(1)
        // Flat normal: color (0.5, 0.5, 1) through a NormalMap node.
        const createDefaultNormal = () =>
            getProperty(new NormalMap({strength: inNormalStrength, color: createConstantColor([0.5, 0.5, 1]), parameters: {}}), "normal")
        const createDefaultDisplacement = () => createConstantValue(0)
        const createDefaultTransmission = () => createConstantColor([0, 0, 0])

        if (textureSetRevisionId === undefined || widthCm === undefined || heightCm === undefined) {
            return {
                baseColor: createDefaultBaseColor(),
                metallic: createDefaultMetallic(),
                specular: createDefaultSpecular(),
                roughness: createDefaultRoughness(),
                anisotropic: createDefaultAnisotropic(),
                anisotropicRotation: createDefaultAnisotropicRotation(),
                alpha: createDefaultAlpha(),
                normal: createDefaultNormal(),
                displacement: createDefaultDisplacement(),
                transmission: createDefaultTransmission(),
            }
        }

        // A single mapping node with scale [widthCm, heightCm, 1] feeds every image lookup.
        const {uv} = inputs
        const mapping = createMappingNode(uv ?? createDefaultUV(), [widthCm, heightCm, 1])

        const createTexImageForTextureType = (textureType: TextureType) => {
            const imageResource = parameters.imageResources?.find((resource) => resource.metadata?.textureType === textureType)
            if (!imageResource) return undefined

            return createTexImageNodeForNodeImageResourceSlot(imageResource, mapping)
        }

        // Decodes an anisotropy color into a signed direction: (rgb - 0.5) * (2, -2, 0).
        // Fresh nodes are created on every call, so each consumer gets its own sub-graph.
        const decodeAnisotropyDirection = (anisotropyMap: MaterialInputParameter<MaterialSlot>) =>
            getProperty(
                new VectorMath({
                    vector: getProperty(
                        new VectorMath({
                            vector: anisotropyMap,
                            parameters: {
                                operation: "ADD",
                                vector_001: {x: -0.5, y: -0.5, z: 0},
                            },
                        }),
                        "vector",
                    ),
                    parameters: {
                        operation: "MULTIPLY",
                        vector_001: {x: 2, y: -2, z: 0},
                    },
                }),
                "vector",
            )

        // Anisotropy strength = length of the decoded direction.
        // NOTE(review): reads the "vector" output of a LENGTH operation; confirm VectorMath exposes
        // the scalar length there (Blender's Vector Math node uses its "Value" socket for LENGTH).
        const createAnisotropicFromAnisotropy = (anisotropyMap: MaterialInputParameter<MaterialSlot> | undefined) => {
            if (!anisotropyMap) {
                return undefined
            }
            return getProperty(
                new VectorMath({
                    vector: decodeAnisotropyDirection(anisotropyMap),
                    parameters: {
                        operation: "LENGTH",
                    },
                }),
                "vector",
            )
        }

        // Anisotropy rotation = atan2(y, x) of the decoded direction, wrapped into [0, 2π) and
        // normalized to [0, 1).
        const createAnisotropicRotationFromAnisotropy = (anisotropyMap: MaterialInputParameter<MaterialSlot> | undefined) => {
            if (!anisotropyMap) {
                return undefined
            }

            const xyz = new SeparateXYZ({
                vector: decodeAnisotropyDirection(anisotropyMap),
                parameters: {},
            })

            const angle = getProperty(
                new MathNode({
                    value: getProperty(xyz, "y"),
                    value_001: getProperty(xyz, "x"),
                    parameters: {operation: "ARCTAN2"},
                }),
                "value",
            )

            // (angle > 0) * angle
            const positivePart = getProperty(
                new MathNode({
                    value: getProperty(
                        new MathNode({
                            value: angle,
                            parameters: {operation: "GREATER_THAN", value_001: 0},
                        }),
                        "value",
                    ),
                    value_001: angle,
                    parameters: {operation: "MULTIPLY"},
                }),
                "value",
            )

            // (angle < 0) * (angle + 2π)
            const wrappedNegativePart = getProperty(
                new MathNode({
                    value: getProperty(
                        new MathNode({
                            value: angle,
                            parameters: {operation: "LESS_THAN", value_001: 0},
                        }),
                        "value",
                    ),
                    value_001: getProperty(
                        new MathNode({
                            value: angle,
                            parameters: {operation: "ADD", value_001: Math.PI * 2},
                        }),
                        "value",
                    ),
                    parameters: {operation: "MULTIPLY"},
                }),
                "value",
            )

            const wrappedAngle = getProperty(
                new MathNode({
                    value: positivePart,
                    value_001: wrappedNegativePart,
                    parameters: {operation: "ADD"},
                }),
                "value",
            )

            return getProperty(
                new MathNode({
                    value: wrappedAngle,
                    parameters: {operation: "DIVIDE", value_001: Math.PI * 2},
                }),
                "value",
            )
        }

        // Inverts the G channel with an RGB curve, applies the optional rotation, and feeds the
        // result into a NormalMap node.
        // NOTE(review): the G inversion presumably converts a DirectX-style normal map to OpenGL
        // style; confirm against the texture source.
        const convertNormal = (normalMap: MaterialInputParameter<MaterialSlot> | undefined, normalRotation: number) => {
            if (!normalMap) {
                return undefined
            }

            return getProperty(
                new NormalMap({
                    color: rotateNormalMapVector(
                        getProperty(
                            new RGBCurve({
                                color: normalMap,
                                parameters: {
                                    fac: 1,
                                    controlPoints: [
                                        [
                                            {x: 0, y: 0},
                                            {x: 1, y: 1},
                                        ], // R
                                        [
                                            {x: 0, y: 1},
                                            {x: 1, y: 0},
                                        ], // G (inverted)
                                        [
                                            {x: 0, y: 0},
                                            {x: 1, y: 1},
                                        ], // B
                                        [
                                            {x: 0, y: 0},
                                            {x: 1, y: 1},
                                        ], // RGB
                                    ],
                                },
                            }),
                            "color",
                        ),
                        normalRotation,
                    ),
                    strength: inNormalStrength,
                    parameters: {},
                }),
                "normal",
            )
        }

        // Displacement around midlevel 0.5, with scale = normal strength * displacementCm
        // (0 when the metadata has no displacementCm).
        // NOTE(review): the displacement scale is tied to the normal-strength input; confirm intended.
        const convertDisplacement = (displacementMap: MaterialInputParameter<MaterialSlot> | undefined) => {
            if (!displacementMap) {
                return undefined
            }

            return getProperty(
                new Displacement({
                    height: displacementMap,
                    scale: getProperty(
                        new MathNode({
                            value: inNormalStrength,
                            parameters: {
                                value_001: displacementCm ?? 0,
                                operation: "MULTIPLY",
                            },
                        }),
                        "value",
                    ),
                    parameters: {midlevel: 0.5},
                }),
                "displacement",
            )
        }

        // Legacy F0 map: metallic is stored in G.
        const createMetallicFromF0 = (colormassF0: MaterialInputParameter<MaterialSlot> | undefined) => {
            if (!colormassF0) {
                return undefined
            }

            return getProperty(
                new SeparateRGB({
                    image: colormassF0,
                    parameters: {},
                }),
                "g",
            )
        }

        // Legacy F0 map: specular is stored in R.
        const createSpecularFromF0 = (colormassF0: MaterialInputParameter<MaterialSlot> | undefined) => {
            if (!colormassF0) {
                return undefined
            }

            return getProperty(
                new SeparateRGB({
                    image: colormassF0,
                    parameters: {},
                }),
                "r",
            )
        }

        // Legacy anisotropy-strength map: (value - 0.5) * 2.
        const createAnisotropicFromAnisotropyStrength = (anisotropyStrengthMap: MaterialInputParameter<MaterialSlot> | undefined) => {
            if (!anisotropyStrengthMap) {
                return undefined
            }

            return getProperty(
                new MathNode({
                    value: getProperty(new MathNode({value: anisotropyStrengthMap, parameters: {operation: "ADD", value_001: -0.5}}), "value"),
                    parameters: {
                        operation: "MULTIPLY",
                        value_001: 2,
                    },
                }),
                "value",
            )
        }

        // Derives alpha from a transmission map's intensity. The intensity is the HSV value, or the
        // RGB-to-BW luminance when `useLuminance` is set.
        const createAlphaFromTransmission = (
            transmissionMap: MaterialInputParameter<MaterialSlot> | undefined,
            minValue: number,
            maxValue: number,
            useLuminance = false,
        ) => {
            if (!transmissionMap) {
                return undefined
            }

            const intensity = useLuminance
                ? getProperty(
                      new RGBToBW({
                          color: transmissionMap,
                          parameters: {},
                      }),
                      "value",
                  )
                : getProperty(
                      new SeparateHSV({
                          image: transmissionMap,
                          parameters: {},
                      }),
                      "v",
                  )

            if (minValue === maxValue) {
                // Degenerate range: 1 / (minValue - maxValue) would be +Infinity. That gives NaN at
                // intensity === maxValue and the inverse of the intended step everywhere else. Use the
                // limiting step function instead: opaque (1) below maxValue, transparent (0) otherwise.
                return getProperty(
                    new MathNode({
                        value: intensity,
                        parameters: {operation: "LESS_THAN", value_001: maxValue},
                    }),
                    "value",
                )
            }

            // Rescale and invert ([minValue...maxValue] -> [1...0]), clamped on the final step.
            return getProperty(
                new MathNode({
                    value: getProperty(
                        new MathNode({
                            value: intensity,
                            parameters: {
                                operation: "ADD",
                                value_001: -maxValue,
                                useClamp: false,
                            },
                        }),
                        "value",
                    ),
                    parameters: {
                        operation: "MULTIPLY",
                        value_001: 1 / (minValue - maxValue),
                        useClamp: true,
                    },
                }),
                "value",
            )
        }

        const transmissionMin = parameters.transmissionMin ?? 0
        const transmissionMax = parameters.transmissionMax ?? 1

        // current set
        const sourceMapDiffuse = createTexImageForTextureType(TextureType.Diffuse)
        const sourceMapMetalness = createTexImageForTextureType(TextureType.Metalness)
        const sourceMapSpecularStrength = createTexImageForTextureType(TextureType.SpecularStrength)
        const sourceMapRoughness = createTexImageForTextureType(TextureType.Roughness)
        const sourceMapAnisotropy = createTexImageForTextureType(TextureType.Anisotropy)
        const sourceMapNormal = createTexImageForTextureType(TextureType.Normal)
        const sourceMapDisplacement = createTexImageForTextureType(TextureType.Displacement)
        // legacy set
        const sourceMapF0 = createTexImageForTextureType(TextureType.F0)
        const sourceMapAnisotropyStrength = createTexImageForTextureType(TextureType.AnisotropyStrength)
        const sourceMapAnisotropyRotation = createTexImageForTextureType(TextureType.AnisotropyRotation)
        // special maps
        const sourceMapMask = createTexImageForTextureType(TextureType.Mask)
        const sourceMapTransmission = createTexImageForTextureType(TextureType.Transmission)

        const normalRotation = parameters.normalRotation ?? 0

        // Assemble each slot: current map, then legacy map, then default.
        const baseColor = sourceMapDiffuse ?? createDefaultBaseColor()
        const metallic = sourceMapMetalness ?? createMetallicFromF0(sourceMapF0) ?? createDefaultMetallic()
        const specular = sourceMapSpecularStrength ?? createSpecularFromF0(sourceMapF0) ?? createDefaultSpecular()
        const roughness = sourceMapRoughness ?? createDefaultRoughness()
        const anisotropic =
            createAnisotropicFromAnisotropy(sourceMapAnisotropy) ??
            createAnisotropicFromAnisotropyStrength(sourceMapAnisotropyStrength) ??
            createDefaultAnisotropic()
        const anisotropicRotation =
            createAnisotropicRotationFromAnisotropy(sourceMapAnisotropy) ?? sourceMapAnisotropyRotation ?? createDefaultAnisotropicRotation()
        const alpha = sourceMapMask ?? createAlphaFromTransmission(sourceMapTransmission, transmissionMin, transmissionMax) ?? createDefaultAlpha()
        const normal = convertNormal(sourceMapNormal, normalRotation) ?? createDefaultNormal()
        const displacement = convertDisplacement(sourceMapDisplacement) ?? createDefaultDisplacement()
        const transmission = sourceMapTransmission ?? createDefaultTransmission()

        return {
            baseColor,
            metallic,
            specular,
            roughness,
            anisotropic,
            anisotropicRotation,
            alpha,
            normal,
            displacement,
            transmission,
        }
    }

    /**
     * Whether the texture set supplies an alpha source, i.e. a Mask or Transmission map.
     *
     * Returns false when the texture set is unconfigured (same condition as in
     * `buildTextureSetGraph`, whose default branch uses a constant alpha of 1).
     */
    hasAlpha(): boolean {
        const {parameters} = this.parameters
        const {textureSetRevisionId, imageResources} = parameters
        const {widthCm, heightCm} = imageResources?.[0]?.metadata ?? {}
        if (textureSetRevisionId === undefined || widthCm === undefined || heightCm === undefined) return false

        return (
            imageResources?.some((resource) => {
                const textureType = resource.metadata?.textureType
                return textureType === TextureType.Mask || textureType === TextureType.Transmission
            }) ?? false
        )
    }
}

/** Creates an RGB node holding the given constant color and returns its `color` output. */
function createConstantColor(color: [number, number, number]) {
    const [r, g, b] = color
    const rgbNode = new RGB({parameters: {color: {r, g, b}}})
    return getProperty(rgbNode, "color")
}

/** Creates a Value node holding the given constant scalar and returns its `value` output. */
function createConstantValue(value: number) {
    const valueNode = new Value({parameters: {value}})
    return getProperty(valueNode, "value")
}

/**
 * Wraps `uvNode` in a Mapping node of vector type "TEXTURE" and returns its `vector` output.
 *
 * @param uvNode - texture coordinates to transform
 * @param scale - per-axis scale (default identity)
 * @param rotation - per-axis rotation (default none)
 * @param location - per-axis offset (default none)
 */
function createMappingNode(
    uvNode: MaterialInputParameter<MaterialSlot | undefined>,
    scale: [number, number, number] = [1, 1, 1],
    rotation: [number, number, number] = [0, 0, 0],
    location: [number, number, number] = [0, 0, 0],
) {
    const toXYZ = ([x, y, z]: [number, number, number]) => ({x, y, z})

    const mappingNode = new Mapping({
        vector: uvNode,
        parameters: {
            vectorType: "TEXTURE",
            location: toXYZ(location),
            rotation: toXYZ(rotation),
            scale: toXYZ(scale),
        },
    })
    return getProperty(mappingNode, "vector")
}

/**
 * Samples `imageResource` at the coordinates from `mappingNode`, using REPEAT extension,
 * "Closest" interpolation and FLAT projection, and returns the TexImage node's `color` output.
 */
function createTexImageNodeForNodeImageResourceSlot(imageResource: ResolvedResource, mappingNode: MaterialInputParameter<MaterialSlot>) {
    const texImageNode = new TexImage({
        vector: mappingNode,
        parameters: {
            extension: "REPEAT",
            interpolation: "Closest",
            projection: "FLAT",
            imageResource,
        },
    })
    return getProperty(texImageNode, "color")
}

/**
 * Splits a color into its R, G and B channels through a single SeparateRGB node.
 *
 * Returns a fixed-length `[r, g, b]` tuple (via `as const`), so destructuring yields three
 * defined elements even under `noUncheckedIndexedAccess`.
 */
function separateRGB(v: MaterialInputParameter<MaterialSlot>) {
    const node = new SeparateRGB({image: v, parameters: {}})
    return [getProperty(node, "r"), getProperty(node, "g"), getProperty(node, "b")] as const
}

/** Joins three scalar channels into a color through a CombineRGB node; returns its `image` output. */
function combineRGB(r: MaterialInputParameter<MaterialSlot>, g: MaterialInputParameter<MaterialSlot>, b: MaterialInputParameter<MaterialSlot>) {
    return getProperty(new CombineRGB({r, g, b, parameters: {}}), "image")
}

/**
 * Applies the Math-node operation `op` to two scalar operands and returns the node's `value` output.
 * Numeric operands are first wrapped in constant Value nodes (`a` before `b`).
 */
function scalarMathOp(
    a: MaterialInputParameter<MaterialSlot> | number,
    b: MaterialInputParameter<MaterialSlot> | number,
    op: GetGraphParamTypes<MathNode>["parameters"]["operation"],
) {
    const toSlot = (operand: MaterialInputParameter<MaterialSlot> | number) => (typeof operand === "number" ? createConstantValue(operand) : operand)

    const first = toSlot(a)
    const second = toSlot(b)
    const mathNode = new MathNode({value: first, value_001: second, parameters: {operation: op}})
    return getProperty(mathNode, "value")
}

/** Scalar `a + b` as a Math node (numbers become constant Value nodes). */
function sadd(a: MaterialInputParameter<MaterialSlot> | number, b: MaterialInputParameter<MaterialSlot> | number): ReturnType<typeof scalarMathOp> {
    const operation = "ADD"
    return scalarMathOp(a, b, operation)
}

/** Scalar `a - b` as a Math node (numbers become constant Value nodes). */
function ssub(a: MaterialInputParameter<MaterialSlot> | number, b: MaterialInputParameter<MaterialSlot> | number): ReturnType<typeof scalarMathOp> {
    const operation = "SUBTRACT"
    return scalarMathOp(a, b, operation)
}

/**
 * Scalar `a * b`, folded at graph-build time where possible.
 *
 * - Multiplying by the constant 1 returns the other operand unchanged.
 * - Multiplying by the constant 0 returns the number 0.
 * - Two numeric constants are multiplied directly.
 *
 * Only when at least one operand is a node, and no identity applies, is a MULTIPLY Math node
 * created.
 */
function smul(a: MaterialInputParameter<MaterialSlot> | number, b: MaterialInputParameter<MaterialSlot> | number) {
    if (a === 1.0) return b
    if (a === 0.0) return 0.0
    if (b === 1.0) return a
    if (b === 0.0) return 0.0
    // Both operands are plain numbers: fold instead of building Value + Math nodes.
    if (typeof a === "number" && typeof b === "number") return a * b
    return scalarMathOp(a, b, "MULTIPLY")
}

/**
 * Rotates the XY part of a normal-map color about the Z axis by `angleDegrees`.
 *
 * R and G are re-centred around 0.5, rotated, and shifted back; B passes through unchanged.
 * An angle of exactly 0 returns `normalMap` itself without adding any nodes.
 *
 * With θ = angleDegrees in radians, the graph computes
 *   x' = x·cos θ + y·sin θ
 *   y' = y·cos θ − x·sin θ
 * In a frame with x pointing right and y pointing up, this is a clockwise rotation by θ.
 * NOTE(review): this was previously documented as a counter-clockwise rotation. The visible
 * handedness depends on the caller's G-channel convention; confirm which is intended.
 */
function rotateNormalMapVector(normalMap: MaterialInputParameter<MaterialSlot>, angleDegrees: number) {
    if (angleDegrees === 0) return normalMap

    const theta = (angleDegrees * Math.PI) / 180
    const sinTheta = Math.sin(theta)
    const cosTheta = Math.cos(theta)

    const [nr, ng, nb] = separateRGB(normalMap)
    // Move R/G from [0, 1] to [-0.5, 0.5] so the rotation pivots around the flat normal.
    const nx = ssub(nr, 0.5)
    const ny = ssub(ng, 0.5)

    const rotatedR = sadd(sadd(smul(nx, cosTheta), smul(ny, sinTheta)), 0.5)
    const rotatedG = sadd(ssub(smul(ny, cosTheta), smul(nx, sinTheta)), 0.5)
    return combineRGB(rotatedR, rotatedG, nb)
}

/**
 * Evaluates one material sub-graph with a CachedNodeGraphResult for the given context.
 * The result is cast to the caller-chosen backend type `T`; no runtime check is performed.
 */
const compileNode = async <T extends MaterialSlot>(node: MaterialInputParameter<MaterialSlot>, context: Context): Promise<T> => {
    const graphResult = new CachedNodeGraphResult(node, context)
    return (await graphResult.run()) as T
}
