import {cyclesNode, DeclareMaterialNode, MaterialInputParameter, MaterialNode, MaterialSlot, materialSlots} from "@src/materials/declare-material-node"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"
import {z} from "zod"
import {ImageResourceSchema} from "@src/materials/material-node-graph"
import {getAll, GetProperty, getProperty} from "@src/graph-system/utils"
import {OutputMaterial} from "@src/materials/nodes/output-material"
import {BsdfPrincipled} from "@src/materials/nodes/bsdf-principled"
import {TexImage} from "@src/materials/nodes/tex-image"
import {Mapping} from "@src/materials/nodes/mapping"
import {UVMap} from "@src/materials/nodes/uv-map"
import {CachedNodeGraphResult} from "@src/graph-system/evaluators/cached-node-graph-result"
import {Context} from "@src/materials/types"
import {HeightToNormal} from "@src/materials/nodes/height-to-normal"
import {NormalMap} from "@src/materials/nodes/normal-map"
import {BlendNormals} from "@src/materials/nodes/blend-normals"
import {TextureSet} from "@src/materials/nodes/texture-set"
import {ScannedTransmission} from "@src/materials/nodes/scanned-transmission"

/**
 * Per-mesh render settings that drive displacement-based normal perturbation.
 *
 * Every field is optional. The displacement image resource is read from `displacementUvChannel`.
 * The remaining `displacement*` values are forwarded to the `TexImage` / `HeightToNormal` nodes
 * built by `ApplyMeshRenderSettings`.
 */
const meshRenderSettings = z
    .object({
        displacementImageResource: ImageResourceSchema,
        displacementUvChannel: z.number(),
        displacementMin: z.number(),
        displacementMax: z.number(),
        displacementNormalStrength: z.number(),
        displacementNormalSmoothness: z.number(),
        displacementNormalOriginalResolution: z.boolean(),
    })
    .partial()
type MeshRenderSettings = z.infer<typeof meshRenderSettings>

/**
 * Builds the slot shape shared by the inputs and the returns of `ApplyMeshRenderSettings`.
 *
 * The shape has an optional surface (a three.js `MeshPhysicalNodeMaterial` or a cycles node) plus
 * optional volume and displacement material slots. It is a factory so that each `z.object` below
 * owns its own zod field instances.
 */
const surfaceVolumeDisplacementShape = () => ({
    surface: z.instanceof(THREENodes.MeshPhysicalNodeMaterial).or(cyclesNode).optional(),
    volume: materialSlots.optional(),
    displacement: materialSlots.optional(),
})

/** Slots produced by `ApplyMeshRenderSettings`; same shape as its inputs. */
const applyMeshRenderSettingsReturns = z.object(surfaceVolumeDisplacementShape())

/** Slots consumed by `ApplyMeshRenderSettings`. */
const applyMeshRenderSettingsInputs = z.object(surfaceVolumeDisplacementShape())

/**
 * Node parameters. Without a `displacementImageResource` in `meshRenderSettings`, the node passes
 * its inputs through unchanged.
 */
const applyMeshRenderSettingsParameters = z.object({
    meshRenderSettings: meshRenderSettings.optional(),
})

/** Node type spelled out ahead of the class so `toThree` can annotate its `this`. */
type ApplyMeshRenderSettingsFwd = MaterialNode<
    z.infer<typeof applyMeshRenderSettingsInputs>,
    z.infer<typeof applyMeshRenderSettingsReturns>,
    z.infer<typeof applyMeshRenderSettingsParameters>
>

/**
 * Material node that folds a mesh's displacement map into the surface shading as a normal perturbation.
 *
 * `volume` and `displacement` are always passed through. If `meshRenderSettings.displacementImageResource`
 * is not set, `surface` is passed through as well. Otherwise:
 * 1. The node's input graph is deep-cloned.
 * 2. The `BsdfPrincipled` feeding the surface is located. The search looks through `GetProperty`
 *    links and `ScannedTransmission` wrappers.
 * 3. A normal derived from the displacement texture (`TexImage` -> `HeightToNormal`) is injected into
 *    that BSDF's `normal` input. It goes in as a new `NormalMap` when no normal is linked. When a
 *    `NormalMap` / `TextureSet` normal is already linked, it is blended in via `BlendNormals`.
 * 4. The patched surface graph is compiled to a `MeshPhysicalNodeMaterial`.
 *
 * Throws if the surface does not resolve to a `BsdfPrincipled` "bsdf" output, or if the BSDF's
 * existing normal is not a `NormalMap` / `TextureSet` "normal" output.
 */
export class ApplyMeshRenderSettings extends DeclareMaterialNode(
    {
        returns: applyMeshRenderSettingsReturns,
        inputs: applyMeshRenderSettingsInputs,
        parameters: applyMeshRenderSettingsParameters,
    },
    {
        toThree: async function (this: ApplyMeshRenderSettingsFwd, {get, inputs, parameters, context}) {
            const {volume, displacement} = await getAll(inputs, get)
            const {meshRenderSettings} = parameters

            // Nothing to apply: forward the surface untouched.
            if (!meshRenderSettings?.displacementImageResource) return {surface: await get(inputs.surface), volume, displacement}

            // Patch a deep copy so the in-place updateParameters calls below never touch the caller's graph.
            const clonedGraph = this.clone({cloneSubNode: () => true})

            const {surface} = clonedGraph.parameters
            if (!surface) throw new Error("ApplyMeshRenderSettings: surface input must be defined")

            // Follows the surface link down to the BsdfPrincipled doing the shading. `output` is the
            // property key the parent GetProperty selected (undefined for a direct link).
            const resolveBsdfPrincipled = (node: unknown, output: string | undefined): BsdfPrincipled => {
                if (node instanceof GetProperty) {
                    return resolveBsdfPrincipled(node.parameters.value, node.parameters.key)
                } else if (node instanceof BsdfPrincipled) {
                    if (output !== "bsdf") {
                        throw new Error("ApplyMeshRenderSettings: surface input must be linked to a bsdf material node output")
                    }
                    return node
                } else if (node instanceof ScannedTransmission) {
                    if (output !== "bsdf") {
                        throw new Error("ApplyMeshRenderSettings: surface input must be linked to a bsdf material node output")
                    }
                    // Look through the wrapper to the BSDF it consumes; the wrapper itself stays in the graph.
                    return resolveBsdfPrincipled(node.parameters.bsdf, undefined)
                } else {
                    throw new Error("ApplyMeshRenderSettings: applied surface material must resolve to BsdfPrincipled")
                }
            }

            const bsdf = resolveBsdfPrincipled(surface, undefined)

            const {normal} = bsdf.parameters
            const {
                displacementImageResource,
                displacementUvChannel,
                displacementMin,
                displacementMax,
                displacementNormalStrength,
                displacementNormalSmoothness,
                displacementNormalOriginalResolution,
            } = meshRenderSettings

            //TODO: this doesn't account for primary and secondary UVs having different tangent spaces

            // Normal-map colour derived from the displacement texture sampled on the configured UV channel.
            const color = getProperty(
                new HeightToNormal({
                    color: getProperty(
                        new TexImage({
                            vector: getProperty(
                                new Mapping({
                                    vector: getProperty(new UVMap({parameters: {uvMapIndex: displacementUvChannel}}), "uv"),
                                    parameters: {vectorType: "TEXTURE"},
                                }),
                                "vector",
                            ),
                            parameters: {
                                extension: "REPEAT",
                                interpolation: "Closest",
                                projection: "FLAT",
                                imageResource: displacementImageResource,
                                disableProgressiveLoading: true,
                                forceOriginalResolution: displacementNormalOriginalResolution,
                            },
                        }),
                        "color",
                    ),
                    parameters: {displacementMax, displacementMin, displacementNormalStrength, displacementNormalSmoothness},
                }),
                "color",
            )

            if (!normal) {
                // No existing normal: drive the BSDF normal from the displacement-derived colour alone.
                bsdf.updateParameters({
                    normal: getProperty(
                        new NormalMap({
                            color,
                            parameters: {strength: 1},
                        }),
                        "normal",
                    ),
                })
            } else {
                if (!(normal instanceof GetProperty) || normal.parameters.key !== "normal")
                    throw new Error("ApplyMeshRenderSettings: normal input must be linked to a normal material node output")
                const {value: node} = normal.parameters

                // Feeds the displacement colour into `normalMap`, blending with any colour already linked.
                const patchNormalMap = (normalMap: NormalMap) => {
                    const {color: originalColor} = normalMap.parameters

                    if (!originalColor) {
                        normalMap.updateParameters({
                            color,
                        })
                    } else {
                        normalMap.updateParameters({
                            color: getProperty(
                                new BlendNormals({
                                    color1: originalColor,
                                    color2: color,
                                    parameters: {},
                                }),
                                "color",
                            ),
                        })
                    }
                }

                if (node instanceof NormalMap) {
                    patchNormalMap(node)
                } else if (node instanceof TextureSet) {
                    // Expand the texture set to reach its NormalMap, patch it, and link the BSDF to it directly.
                    const {normal: textureSetNormal} = node.buildTextureSetGraph()

                    if (!(textureSetNormal instanceof GetProperty) || textureSetNormal.parameters.key !== "normal")
                        throw new Error("ApplyMeshRenderSettings: normal input must be linked to a normal material node output")

                    const {value: normalMap} = textureSetNormal.parameters
                    if (!(normalMap instanceof NormalMap)) throw new Error("ApplyMeshRenderSettings: applied normal material must be a NormalMap")

                    patchNormalMap(normalMap)

                    bsdf.updateParameters({
                        normal: getProperty(normalMap, "normal"),
                    })
                } else throw new Error("ApplyMeshRenderSettings: applied normal material must be a NormalMap or a TextureSet")
            }

            // Compile the patched cloned surface root, not the bare BSDF. The BSDF was modified in place,
            // so this includes the new normal AND keeps wrappers such as ScannedTransmission that sit
            // between the surface input and the BSDF. Compiling only the BSDF would silently drop them.
            return {surface: await compileNode<THREENodes.MeshPhysicalNodeMaterial>(surface, context), volume, displacement}
        },
    },
) {}

/**
 * Evaluates the material node graph rooted at `node` through a `CachedNodeGraphResult`.
 *
 * The result is cast to `T` without a runtime check, so the caller must pick `T` to match what the
 * graph actually produces.
 */
const compileNode = async <T extends MaterialSlot>(node: MaterialInputParameter<MaterialSlot>, context: Context): Promise<T> =>
    (await new CachedNodeGraphResult(node, context).run()) as T

/**
 * Routes an output material's surface, volume and displacement through an `ApplyMeshRenderSettings`
 * node configured with `meshRenderSettings`.
 *
 * If `meshRenderSettings.displacementImageResource` is unset, there is nothing to apply and the
 * original `material` instance is returned as-is (not a copy). Otherwise a deep clone of `material`
 * is returned whose three slots read from the new node.
 */
export const applyMeshRenderSettings = (material: OutputMaterial, meshRenderSettings: MeshRenderSettings) => {
    const {displacementImageResource} = meshRenderSettings
    if (!displacementImageResource) return material

    const outputCopy = material.clone({cloneSubNode: () => true})

    // Capture the current slot links before they are cleared below.
    const {surface, volume, displacement} = outputCopy.parameters

    // Unlink the slots before re-linking them through the wrapper node.
    // NOTE(review): presumably this keeps the subgraphs from having two parents at once — confirm.
    outputCopy.updateParameters({surface: undefined, volume: undefined, displacement: undefined})

    const settingsNode = new ApplyMeshRenderSettings({
        surface,
        volume,
        displacement,
        parameters: {meshRenderSettings},
    })

    const surfaceLink = getProperty(settingsNode, "surface")
    const volumeLink = getProperty(settingsNode, "volume")
    const displacementLink = getProperty(settingsNode, "displacement")

    outputCopy.updateParameters({
        surface: surfaceLink,
        volume: volumeLink,
        displacement: displacementLink,
    })

    return outputCopy
}
