import {DeclareMaterialNode, ThreeNode, cyclesNode, materialSlots} from "@src/materials/declare-material-node"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"
import {z} from "zod"
import {Context, color, vec3} from "@src/materials/types"
import {GetProperty, getAll} from "@src/graph-system/utils"
import {threeConvert, threeRGBColorNode, threeValueNode, threeVec3Node} from "@src/materials/three-utils"
import {TextureSet} from "./texture-set"
import {ParameterValue} from "@src/graph-system/node-graph"

/**
 * Creates a brand-new `MeshPhysicalNodeMaterial` constructed with no arguments.
 * `BsdfPrincipled.toThree` starts from one of these and overrides individual nodes.
 */
export const getDefaultMaterial = (): THREENodes.MeshPhysicalNodeMaterial => {
    return new THREENodes.MeshPhysicalNodeMaterial()
}

/**
 * CPU-side twin of the specular -> IOR node graph built in `BsdfPrincipled.toThree`.
 * The formula inverts F0 = 0.08 * specular with F0 = ((ior - 1) / (ior + 1))^2.
 * Specular is clamped to [1e-9, 1] exactly like the shader version, so comparisons against it reflect
 * the IOR the material will actually use.
 */
const specularToIor = (specular: number): number => 2.0 / (1.0 - Math.sqrt(0.08 * Math.min(Math.max(specular, 1e-9), 1.0))) - 1.0

/**
 * Principled BSDF material node.
 *
 * `toThree` maps connected inputs (node-graph values) and constant parameters onto a fresh
 * `MeshPhysicalNodeMaterial`. For every property, a connected input takes precedence over the parameter,
 * which in turn takes precedence over the material's default node.
 * The schema also accepts anisotropy, tangent, surfaceMixWeight, transmissionRoughness, distribution,
 * subsurfaceMethod and the subsurface group, but the three.js conversion ignores them.
 */
export class BsdfPrincipled extends DeclareMaterialNode(
    {
        returns: z.object({bsdf: z.instanceof(THREENodes.MeshPhysicalNodeMaterial).or(cyclesNode)}),
        inputs: z.object({
            baseColor: materialSlots.optional(),
            metallic: materialSlots.optional(),
            roughness: materialSlots.optional(),
            specular: materialSlots.optional(),
            specularTint: materialSlots.optional(),
            ior: materialSlots.optional(),
            sheen: materialSlots.optional(),
            sheenTint: materialSlots.optional(),
            clearcoat: materialSlots.optional(),
            clearcoatRoughness: materialSlots.optional(),
            clearcoatNormal: materialSlots.optional(),
            anisotropic: materialSlots.optional(),
            anisotropicRotation: materialSlots.optional(),
            normal: materialSlots.optional(),
            tangent: materialSlots.optional(),
            surfaceMixWeight: materialSlots.optional(),
            emission: materialSlots.optional(),
            emissionStrength: materialSlots.optional(),
            alpha: materialSlots.optional(),
            transmission: materialSlots.optional(),
            transmissionRoughness: materialSlots.optional(),
            subsurface: materialSlots.optional(),
            subsurfaceColor: materialSlots.optional(),
            subsurfaceRadius: materialSlots.optional(),
            subsurfaceIor: materialSlots.optional(),
            subsurfaceAnisotropy: materialSlots.optional(),
        }),
        parameters: z.object({
            baseColor: color.optional(),
            metallic: z.number().optional(),
            roughness: z.number().optional(),
            specular: z.number().optional(),
            specularTint: z.number().optional(),
            ior: z.number().optional(),
            sheen: z.number().optional(),
            sheenTint: z.number().optional(),
            clearcoat: z.number().optional(),
            clearcoatRoughness: z.number().optional(),
            clearcoatNormal: vec3.optional(),
            anisotropic: z.number().optional(),
            anisotropicRotation: z.number().optional(),
            normal: vec3.optional(),
            tangent: vec3.optional(),
            surfaceMixWeight: z.number().optional(),
            distribution: z.enum(["GGX", "MULTI_GGX"]).optional(),
            subsurfaceMethod: z.enum(["BURLEY", "RANDOM_WALK"]).optional(),
            emission: color.optional(),
            emissionStrength: z.number().optional(),
            alpha: z.number().optional(),
            transmission: z.number().optional(),
            transmissionRoughness: z.number().optional(),
            subsurface: z.number().optional(),
            subsurfaceColor: color.optional(),
            subsurfaceRadius: vec3.optional(),
            subsurfaceIor: z.number().optional(),
            subsurfaceAnisotropy: z.number().optional(),
        }),
    },
    {
        /**
         * Builds a `MeshPhysicalNodeMaterial` from the BSDF's inputs and parameters.
         * `get` resolves a single input value to a three node (presumably `undefined` when the input is unset;
         * verify against `DeclareMaterialNode`).
         * @returns `{bsdf}` holding the configured material.
         */
        toThree: async ({get, inputs, parameters}) => {
            const material = getDefaultMaterial()

            // Split off the inputs this conversion ignores, so `getAll` only resolves the ones used below.
            // `alpha` is also split off because it is resolved separately by `evaluateAlpha`.
            const {
                anisotropic,
                anisotropicRotation,
                tangent,
                surfaceMixWeight,
                alpha,
                transmissionRoughness,
                subsurface,
                subsurfaceColor,
                subsurfaceRadius,
                subsurfaceIor,
                subsurfaceAnisotropy,
                ...relevantInputs
            } = inputs

            const {
                baseColor,
                metallic,
                roughness,
                specular,
                specularTint,
                ior,
                sheen,
                sheenTint,
                clearcoat,
                clearcoatRoughness,
                clearcoatNormal,
                normal,
                emission,
                emissionStrength,
                transmission,
            } = await getAll(relevantInputs, get)

            material.colorNode = baseColor ?? threeConvert(parameters.baseColor, threeRGBColorNode) ?? material.colorNode
            material.roughnessNode = roughness ?? threeConvert(parameters.roughness, threeValueNode) ?? material.roughnessNode
            material.metalnessNode = metallic ?? threeConvert(parameters.metallic, threeValueNode) ?? material.metalnessNode

            const specularValue = specular ?? threeConvert(parameters.specular, threeValueNode)
            const iorValue = ior ?? threeConvert(parameters.ior, threeValueNode) ?? material.iorNode

            // Specular and IOR both drive `iorNode`; specular wins below. Warn when both are connected inputs, or
            // when both parameters are given (including 0) and imply IORs that differ by more than 0.05.
            const specularParam = parameters.specular
            const iorParam = parameters.ior
            if (
                (specular && ior) ||
                (specularParam !== undefined && iorParam !== undefined && Math.abs(specularToIor(specularParam) - iorParam) > 0.05)
            )
                console.warn("Material uses both specular and ior! Preferring specular.")

            if (specularValue) {
                // Node-graph version of specularToIor: ior = 2 / (1 - sqrt(0.08 * clamp(specular, 1e-9, 1))) - 1
                material.iorNode = THREENodes.sub(
                    THREENodes.div(
                        threeValueNode(2.0),
                        THREENodes.sub(
                            threeValueNode(1.0),
                            THREENodes.sqrt(THREENodes.mul(threeValueNode(0.08), THREENodes.clamp(specularValue, threeValueNode(1e-9), threeValueNode(1.0)))),
                        ),
                    ),
                    threeValueNode(1.0),
                )
            } else if (iorValue) material.iorNode = iorValue

            // Specular tint blends the specular colour from white towards the base colour.
            const specularTintValue = specularTint ?? threeConvert(parameters.specularTint, threeValueNode, (val) => val > 0)
            if (specularTintValue && material.colorNode)
                material.specularColorNode = THREENodes.mix(threeRGBColorNode({r: 1, g: 1, b: 1}), material.colorNode, specularTintValue)

            const sheenValue = sheen ?? threeConvert(parameters.sheen, threeValueNode, (val) => val > 0) ?? material.sheenNode
            if (sheenValue) {
                const sheenTintValue = sheenTint ?? threeConvert(parameters.sheenTint, threeValueNode)
                // NOTE(review): assumes three's sheenNode is the sheen colour already scaled by its intensity.
                // As with specular tint, the tint blends white towards the base colour, so a tint of 0 gives white
                // sheen rather than black. The sheen weight is applied in every case.
                const sheenColor =
                    sheenTintValue && material.colorNode
                        ? THREENodes.mix(threeRGBColorNode({r: 1, g: 1, b: 1}), material.colorNode, sheenTintValue)
                        : undefined
                material.sheenNode = sheenColor ? THREENodes.mul(sheenColor, sheenValue) : sheenValue
            }

            material.clearcoatNode = clearcoat ?? threeConvert(parameters.clearcoat, threeValueNode, (val) => val > 0) ?? material.clearcoatNode
            if (material.clearcoatNode) {
                material.clearcoatRoughnessNode =
                    clearcoatRoughness ?? threeConvert(parameters.clearcoatRoughness, threeValueNode) ?? material.clearcoatRoughnessNode
                // An all-zero parameter normal is treated as "not set".
                material.clearcoatNormalNode =
                    clearcoatNormal ??
                    threeConvert(parameters.clearcoatNormal, threeVec3Node, (val) => val.x !== 0 || val.y !== 0 || val.z !== 0) ??
                    material.clearcoatNormalNode
            }

            material.normalNode =
                normal ?? threeConvert(parameters.normal, threeVec3Node, (val) => val.x !== 0 || val.y !== 0 || val.z !== 0) ?? material.normalNode //TODO: apply displacement even if original material does not have a normal map

            // Black parameter emission is treated as "no emission"; strength, when present, scales the colour.
            const emissionValue =
                emission ?? threeConvert(parameters.emission, threeRGBColorNode, (val) => val.r > 0 || val.g > 0 || val.b > 0) ?? material.emissiveNode
            if (emissionValue) {
                const emissionStrengthValue = emissionStrength ?? threeConvert(parameters.emissionStrength, threeValueNode)
                if (emissionStrengthValue) material.emissiveNode = THREENodes.mul(emissionValue, emissionStrengthValue)
                else material.emissiveNode = emissionValue
            }

            /**
             * Resolves the alpha input. An "alpha" property read from a TextureSet without an alpha channel
             * resolves to `undefined`, so it cannot mark the material transparent.
             */
            const evaluateAlpha = async (alphaInput: ParameterValue<ThreeNode | undefined, Context>): Promise<THREENodes.Node | undefined> => {
                if (alphaInput instanceof GetProperty) {
                    const {key, value} = alphaInput.parameters
                    if (value instanceof TextureSet && key === "alpha" && !value.hasAlpha()) return undefined
                }
                return get(alphaInput)
            }

            // Parameter alpha >= 0.95 and transmission <= 0.05 are treated as opaque / non-transmissive.
            const alphaValue =
                (await evaluateAlpha(inputs.alpha)) ?? threeConvert(parameters.alpha, threeValueNode, (val) => val < 0.95) ?? material.opacityNode
            const transmissionValue = transmission ?? threeConvert(parameters.transmission, threeValueNode, (val) => val > 0.05) ?? material.transmissionNode

            if (alphaValue || transmissionValue) {
                material.transparent = true
                if (alphaValue) {
                    if (transmissionValue) console.warn("Material uses both alpha and transmission! Preferring alpha.")
                    material.opacityNode = alphaValue
                } else if (transmissionValue) {
                    material.transmissionNode = transmissionValue
                    //@ts-ignore
                    material.transmission = 0.01 //This is just to indicate to three that the material is translucent
                }
            }

            return {bsdf: material}
        },
    },
) {}
