import {
    cyclesNode,
    DeclareMaterialNode,
    DeclareMaterialNodeType,
    MaterialInputParameter,
    MaterialNode,
    MaterialSlot,
    materialSlots,
} from "#material-nodes/declare-material-node"
import {LegacyMaterialConverter} from "#material-nodes/legacy-material-converter"
import {ImageNodeGenerator, isImageNodeGenerator, wrapNodeOutput} from "#material-nodes/material-node-graph"
import {BlendNormals} from "#material-nodes/nodes/blend-normals"
import {BsdfPrincipled} from "#material-nodes/nodes/bsdf-principled"
import {HeightToNormal} from "#material-nodes/nodes/height-to-normal"
import {NormalMap} from "#material-nodes/nodes/normal-map"
import {OutputMaterial} from "#material-nodes/nodes/output-material"
import {ScannedTransmission} from "#material-nodes/nodes/scanned-transmission"
import {TexImage} from "#material-nodes/nodes/tex-image"
import {TextureSet} from "#material-nodes/nodes/texture-set"
import {Context} from "#material-nodes/types"
import {NodeGraph} from "@cm/graph"
import {CachedNodeGraphResult} from "@cm/graph/evaluators/cached-node-graph-result"
import {getAll, GetProperty, getProperty} from "@cm/graph/utils"
import * as THREE from "three"
import * as THREENodes from "three/examples/jsm/nodes/Nodes.js"
import {z} from "zod"

/**
 * Zod schema that accepts any value for which `isImageNodeGenerator` holds.
 * Any other value produces a fatal "Invalid image node generator" issue.
 * The trailing cast gives the schema `ImageNodeGenerator` as its input/output type.
 */
export const ImageNodeGeneratorSchema = z.any().superRefine((value, refinementCtx): value is ImageNodeGenerator => {
    const isGenerator = isImageNodeGenerator(value)
    if (!isGenerator)
        refinementCtx.addIssue({
            code: z.ZodIssueCode.custom,
            message: "Invalid image node generator",
            fatal: true,
        })
    // Failure is signalled through the issue added above; z.NEVER is returned only to satisfy the
    // type-predicate signature (NOTE(review): relies on superRefine ignoring the return value — confirm).
    return z.NEVER
}) as z.ZodEffects<z.ZodAny, ImageNodeGenerator, ImageNodeGenerator>

/**
 * Schema for an image source: an image node generator plus optional metadata.
 * `width`/`height` are used as the image's original resolution when baking displacement.
 */
export const imageGenerator = z.object({
    imageNode: ImageNodeGeneratorSchema,
    metadata: z
        .object({
            width: z.number().optional(),
            height: z.number().optional(),
            legacyId: z.number().optional(),
        })
        .optional(),
})

export type ImageGenerator = z.infer<typeof imageGenerator>

/** Per-mesh render settings; every field here configures how a displacement image is turned into normals. */
const meshRenderSettings = z.object({
    displacementImage: imageGenerator.optional(),
    displacementUvChannel: z.number().optional(),
    displacementMin: z.number().optional(),
    displacementMax: z.number().optional(),
    displacementNormalStrength: z.number().optional(),
    displacementNormalSmoothness: z.number().optional(),
    displacementNormalOriginalResolution: z.boolean().optional(),
})
type MeshRenderSettings = z.infer<typeof meshRenderSettings>

// ApplyMeshRenderSettings consumes and produces the same surface/volume/displacement shape.
// Each call builds fresh schema instances, so inputs and returns do not share schema objects.
const surfaceVolumeDisplacementSchema = () =>
    z.object({
        surface: z.instanceof(THREENodes.MeshPhysicalNodeMaterial).or(cyclesNode).optional(),
        volume: materialSlots.optional(),
        displacement: materialSlots.optional(),
    })

const applyMeshRenderSettingsReturns = surfaceVolumeDisplacementSchema()

const applyMeshRenderSettingsInputs = surfaceVolumeDisplacementSchema()

const applyMeshRenderSettingsParameters = z.object({
    meshRenderSettings: meshRenderSettings.optional(),
})

/** Instance type of the ApplyMeshRenderSettings node, declared up front so its `toThree` can type `this`. */
type ApplyMeshRenderSettingsFwd = MaterialNode<
    z.infer<typeof applyMeshRenderSettingsInputs>,
    z.infer<typeof applyMeshRenderSettingsReturns>,
    z.infer<typeof applyMeshRenderSettingsParameters>
>

/**
 * Material node with no inputs that exposes an already-built three.js node,
 * passed in as the `color` parameter, as its `color` output.
 */
class ProvideThreeTextureNode extends DeclareMaterialNode(
    {
        returns: z.object({color: materialSlots}),
        inputs: z.object({}),
        parameters: z.object({color: z.instanceof(THREENodes.Node)}),
    },
    {
        toThree: async ({parameters: {color}}) => ({color}),
    },
) {}

/**
 * Material node that applies per-mesh render settings to a material when compiling for three.js.
 *
 * Without `meshRenderSettings.displacementImage`, the evaluated inputs are forwarded unchanged.
 *
 * Otherwise:
 * - The node graph is deep-cloned.
 * - The surface is resolved to its underlying `BsdfPrincipled`, either directly or through `ScannedTransmission`.
 * - A normal derived from the displacement image via `HeightToNormal` is wired into the BSDF's normal input.
 *   If a `NormalMap` already has a color, the two are blended with `BlendNormals`.
 * - The patched surface is compiled.
 *
 * Volume and displacement are always passed through.
 *
 * Throws if:
 * - the surface/normal graph does not have one of the supported shapes, or
 * - a bake is required and the context lacks `threeRenderer`/`threeDefaultFloatTextureType`.
 */
export class ApplyMeshRenderSettings extends (DeclareMaterialNode(
    {
        returns: applyMeshRenderSettingsReturns,
        inputs: applyMeshRenderSettingsInputs,
        parameters: applyMeshRenderSettingsParameters,
    },
    {
        toThree: async function (this: ApplyMeshRenderSettingsFwd, {get, inputs, parameters, context}) {
            const {volume, displacement} = await getAll(inputs, get)
            const {meshRenderSettings} = parameters

            // No displacement image: forward the evaluated inputs unchanged.
            if (!meshRenderSettings?.displacementImage) return {surface: await get(inputs.surface), volume, displacement}

            // Patch a deep clone so rewiring the BSDF normal below does not mutate the original graph.
            const clonedGraph = this.clone({cloneSubNode: () => true})

            const {surface} = clonedGraph.parameters
            if (!surface) throw new Error("ApplyMeshRenderSettings: surface input must be defined")

            // Walks from the surface input down to the BsdfPrincipled node whose normal input gets patched.
            // `output` is the property key through which `node` was reached (undefined at the root and below a ScannedTransmission).
            const resolveBsdfPrincipled = (node: unknown, output: string | undefined): BsdfPrincipled => {
                if (node instanceof GetProperty) {
                    return resolveBsdfPrincipled(node.parameters.value, node.parameters.key)
                } else if (node instanceof BsdfPrincipled) {
                    if (output !== "bsdf") {
                        throw new Error("ApplyMeshRenderSettings: surface input must be linked to a bsdf material node output")
                    }
                    return node
                } else if (node instanceof ScannedTransmission) {
                    if (output !== "bsdf") {
                        throw new Error("ApplyMeshRenderSettings: surface input must be linked to a bsdf material node output")
                    }
                    return resolveBsdfPrincipled(node.parameters.bsdf, undefined)
                } else {
                    throw new Error("ApplyMeshRenderSettings: applied surface material must resolve to BsdfPrincipled")
                }
            }

            const bsdf = resolveBsdfPrincipled(surface, undefined)

            const {normal} = bsdf.parameters
            const {
                displacementImage,
                displacementUvChannel,
                displacementMin,
                displacementMax,
                displacementNormalStrength,
                displacementNormalSmoothness,
                displacementNormalOriginalResolution,
            } = meshRenderSettings

            //TODO: this doesn't account for primary and secondary UVs having different tangent spaces

            // Produces a graph yielding the displacement (height) texture, sampled with the configured UV channel.
            // If the generator yields a plain TexImage it is used directly. Otherwise the generator graph is baked
            // into a single-channel (RedFormat) render target, which is then sampled.
            const getDisplacementTexture = async (): Promise<NodeGraph<MaterialSlot, Context>> => {
                const materialConverter = new LegacyMaterialConverter()

                // UVs of the configured channel, passed through a TEXTURE-type Mapping node.
                const uv = wrapNodeOutput<"UVMap">(
                    {
                        nodeType: "UVMap",
                        parameters: {
                            "internal.uv_map_index": displacementUvChannel,
                        },
                    },
                    "UV",
                )
                const vec = wrapNodeOutput<"Mapping">(
                    {
                        nodeType: "Mapping",
                        inputs: {
                            Vector: uv,
                        },
                        parameters: {
                            "internal.vector_type": "TEXTURE",
                        },
                    },
                    "Vector",
                )

                const displacementImageTest = materialConverter.convertWrappedNode(
                    displacementImage.imageNode.generator({
                        uv: vec,
                        extension: "REPEAT",
                        interpolation: "Closest",
                        projection: "FLAT",
                        disableProgressiveLoading: true,
                        forceOriginalResolution: displacementNormalOriginalResolution,
                    }).color,
                )

                //Case 1: Displacement is provided as a texture, we do not need to bake it
                if (displacementImageTest instanceof GetProperty && displacementImageTest.parameters.key === "color") {
                    if (displacementImageTest.parameters.value instanceof TexImage) return displacementImageTest as NodeGraph<MaterialSlot, Context>
                }

                //Case 2: Displacement is provided as a node graph, we need to bake it.
                // The graph is regenerated against UV channel 0 so that it is evaluated over the bake plane's UVs.
                const convertedDisplacementImage = materialConverter.convertWrappedNode(
                    displacementImage.imageNode.generator({
                        uv: wrapNodeOutput<"UVMap">(
                            {
                                nodeType: "UVMap",
                                parameters: {
                                    "internal.uv_map_index": 0,
                                },
                            },
                            "UV",
                        ),
                        extension: "REPEAT",
                        interpolation: "Closest",
                        projection: "FLAT",
                        disableProgressiveLoading: true,
                        forceOriginalResolution: displacementNormalOriginalResolution,
                    }).color,
                )

                const {threeRenderer, threeDefaultFloatTextureType} = context
                if (!threeRenderer || !threeDefaultFloatTextureType)
                    throw new Error("ApplyMeshRenderSettings: threeRenderer and threeDefaultFloatTextureType must be defined in context")

                const bakedMaterial = new THREENodes.MeshBasicNodeMaterial()
                bakedMaterial.colorNode = await compileNode<THREENodes.Node>(convertedDisplacementImage, context)

                // Bake resolution selection:
                // - Use the image's original size when displacementNormalOriginalResolution is set or the context requests original textures.
                // - Otherwise use the context's fixed texture resolution.
                // - Callers fall back to 2048 when neither is known.
                const getResolution = (originalResolution: number | undefined) => {
                    if (displacementNormalOriginalResolution || context.textureResolution === "original") return originalResolution
                    else
                        switch (context.textureResolution) {
                            case "500px":
                                return 500
                            case "1000px":
                                return 1000
                            case "2000px":
                                return 2000
                        }
                    return undefined
                }

                const width = getResolution(displacementImage.metadata?.width) ?? 2048
                const height = getResolution(displacementImage.metadata?.height) ?? 2048

                const filter = context.forceFiltering === "nearest" ? THREE.NearestFilter : THREE.LinearFilter

                const renderTarget = new THREE.WebGLRenderTarget(width, height, {
                    type: threeDefaultFloatTextureType,
                    minFilter: filter,
                    magFilter: filter,
                    format: THREE.RedFormat,
                    colorSpace: THREE.NoColorSpace,
                })

                context.onThreeCreatedTexture?.(renderTarget)

                const renderScene = new THREE.Scene()

                const planeMesh = new THREE.Mesh(new THREE.PlaneGeometry(1.0, 1.0), bakedMaterial)
                renderScene.add(planeMesh)

                // Orthographic frustum spans exactly the unit plane, so the plane covers the whole render target.
                const renderCamera = new THREE.OrthographicCamera(-0.5, 0.5, 0.5, -0.5, 0.1, 2)
                renderCamera.position.z = 1

                const previousRenderTarget = threeRenderer.getRenderTarget()
                try {
                    threeRenderer.setRenderTarget(renderTarget)
                    threeRenderer.compile(renderScene, renderCamera)
                    threeRenderer.render(renderScene, renderCamera)
                } finally {
                    // Restore the shared renderer's target and free the temporary bake resources even if rendering throws.
                    threeRenderer.setRenderTarget(previousRenderTarget)
                    planeMesh.geometry.dispose()
                    planeMesh.material.dispose()
                }

                const textureNode = THREENodes.texture(
                    renderTarget.texture,
                    await compileNode<THREENodes.Node>(materialConverter.convertWrappedNode(vec), context),
                )

                return getProperty(new ProvideThreeTextureNode({parameters: {color: textureNode}}), "color")
            }

            // Normal-map color derived from the displacement heights.
            const color = getProperty(
                new HeightToNormal({
                    color: await getDisplacementTexture(),
                    parameters: {displacementMax, displacementMin, displacementNormalStrength, displacementNormalSmoothness},
                }),
                "color",
            )

            if (!normal)
                bsdf.updateParameters({
                    normal: getProperty(
                        new NormalMap({
                            color,
                            parameters: {strength: 1},
                        }),
                        "normal",
                    ),
                })
            else {
                if (!(normal instanceof GetProperty) || normal.parameters.key !== "normal")
                    throw new Error("ApplyMeshRenderSettings: normal input must be linked to a normal material node output")
                const {value: node} = normal.parameters

                // Uses the displacement-derived color directly if the NormalMap has no color.
                // Otherwise blends it with the existing color.
                const patchNormalMap = (normalMap: NormalMap) => {
                    const {color: originalColor} = normalMap.parameters

                    if (!originalColor)
                        normalMap.updateParameters({
                            color,
                        })
                    else {
                        normalMap.updateParameters({
                            color: getProperty(
                                new BlendNormals({
                                    color1: originalColor,
                                    color2: color,
                                    parameters: {},
                                }),
                                "color",
                            ),
                        })
                    }
                }

                if (node instanceof NormalMap) patchNormalMap(node)
                else if (node instanceof TextureSet) {
                    const {normal} = node.buildTextureSetGraph()

                    if (!(normal instanceof GetProperty) || normal.parameters.key !== "normal")
                        throw new Error("ApplyMeshRenderSettings: normal input must be linked to a normal material node output")

                    const {value: normalMap} = normal.parameters
                    if (!(normalMap instanceof NormalMap)) throw new Error("ApplyMeshRenderSettings: applied normal material must be a NormalMap")

                    patchNormalMap(normalMap)

                    bsdf.updateParameters({
                        normal: getProperty(normalMap, "normal"),
                    })
                } else throw new Error("ApplyMeshRenderSettings: applied normal material must be a NormalMap or a TextureSet")
            }

            // Compile the cloned surface root, which now contains the patched BSDF, rather than the bare BSDF.
            // This keeps wrappers such as ScannedTransmission, matching the no-displacement path above.
            return {surface: await compileNode<THREENodes.MeshPhysicalNodeMaterial>(surface, context), volume, displacement}
        },
    },
) as DeclareMaterialNodeType<typeof applyMeshRenderSettingsReturns, typeof applyMeshRenderSettingsInputs, typeof applyMeshRenderSettingsParameters>) {}

/**
 * Evaluates `node` in `context` through a fresh `CachedNodeGraphResult` and returns the result as `T`.
 * The cast is unchecked: callers choose `T` to match the node they pass.
 */
const compileNode = async <T extends MaterialSlot>(node: MaterialInputParameter<MaterialSlot>, context: Context): Promise<T> => {
    const evaluation = new CachedNodeGraphResult(node, context)
    return (await evaluation.run()) as T
}

/**
 * Routes the surface, volume and displacement of a copy of `material` through an `ApplyMeshRenderSettings`
 * node configured with `meshRenderSettings`.
 *
 * When no displacement image is configured, `material` itself is returned without being cloned.
 */
export const applyMeshRenderSettings = (material: OutputMaterial, meshRenderSettings: MeshRenderSettings) => {
    if (!meshRenderSettings.displacementImage) return material

    const patchedMaterial = material.clone({cloneSubNode: () => true})

    const {surface, volume, displacement} = patchedMaterial.parameters
    // Clear the inputs before handing them to the settings node.
    // NOTE(review): presumably this detaches them from the output node first — confirm.
    patchedMaterial.updateParameters({surface: undefined, volume: undefined, displacement: undefined})

    const settingsNode = new ApplyMeshRenderSettings({
        surface,
        volume,
        displacement,
        parameters: {meshRenderSettings},
    })

    patchedMaterial.updateParameters({
        surface: getProperty(settingsNode, "surface"),
        volume: getProperty(settingsNode, "volume"),
        displacement: getProperty(settingsNode, "displacement"),
    })

    return patchedMaterial
}
