import {ImageColorSpace} from "@src/api-gql/data-object"
import {
    assertNodeOfType,
    GetOutputNodeType,
    ImageResource,
    IMaterialGraph,
    isNodeOfType,
    isOutputWrapper,
    isResolvedResourceWithTransientDataObject,
    MaterialGraphNode,
    MaterialGraphRootNode,
    NodeType,
    transformMaterialGraph,
    unwrapNodeOutput,
    UnwrappedMaterialGraphNode,
    wrapNodeOutput,
    WrappedMaterialGraphNode,
} from "@src/materials/material-node-graph"
import {deepCopy} from "@src/utils/utils"

/** Constant value that `setInputOrParam` stores in a node's `parameters` map (i.e. anything that is not a graph node). */
type ParamValue = number | string | number[] | boolean

/**
 * UV transform applied by `transformOffsetUVs`.
 *
 * `horizontal`/`vertical` are passed straight through as the Mapping node's Location
 * (units presumably UV units — TODO confirm). `rotation` is in degrees: it is converted with
 * `PI / 180` for the Mapping node and divided by 360 for the anisotropic rotation adjustment.
 */
export type UVOffset = {
    /** Mapping Location X; per the note in `offsetUVs`, positive values shift the texture to the left. */
    horizontal: number
    /** Mapping Location Y; per the note in `offsetUVs`, positive values shift the texture down. */
    vertical: number
    /** Rotation in degrees; per the note in `offsetUVs`, positive values rotate the texture CCW. */
    rotation: number
}

/** Inputs for `transformDecalMask`, consumed by `setupMask`. */
export type DecalMaskData = {
    /** Optional mask image: its luminance (optionally inverted) times its alpha forms the mask. */
    maskImage?: ImageResource
    /** Optional image whose alpha channel additionally multiplies the mask. */
    colorOverlayImage?: ImageResource
    /** Mapping Scale X for the mask textures; also stored as `widthCm` metadata on each image resource. */
    widthCm: number
    /** Mapping Scale Y for the mask textures; also stored as `heightCm` metadata on each image resource. */
    heightCm: number
    /** When true, the mask image's luminance is replaced by `1 - luminance` before the alpha multiply. */
    invert: boolean
}

/**
 * Assigns `value` to socket `name` of `node`.
 *
 * Graph nodes (objects carrying a `nodeType`) are stored as links in `node.inputs`; any other
 * non-null value is stored as a constant in `node.parameters`. The target map is created on
 * demand. `null`/`undefined` values are ignored and leave the node unchanged.
 *
 * NOTE(review): writing one map does not clear an entry of the same name in the other map, and
 * `getInputOrParam` prefers the input, so a parameter written over an existing input is shadowed.
 */
function setInputOrParam(node: UnwrappedMaterialGraphNode, name: string, value: WrappedMaterialGraphNode | ParamValue): void {
    // Reject nullish values before the object test: `typeof null === "object"`, and
    // `"nodeType" in null` would throw a TypeError instead of being skipped.
    if (value == null) return

    if (typeof value === "object" && "nodeType" in value) {
        if (!node.inputs) {
            node.inputs = {}
        }
        node.inputs[name] = value
    } else {
        if (!node.parameters) {
            node.parameters = {}
        }
        node.parameters[name] = value
    }
}

/**
 * Reads socket `name` of `node`: a truthy linked input wins, otherwise the constant stored in
 * `parameters` (or `undefined` when neither exists).
 */
function getInputOrParam<T = ParamValue>(node: UnwrappedMaterialGraphNode, name: string): WrappedMaterialGraphNode | T | undefined {
    const linked = node.inputs?.[name]
    return linked ? linked : node.parameters?.[name]
}

/**
 * Builds a `ShaderNodeMath` node running `op` (stored as `internal.operation`) and returns its
 * "Value" output. `a` feeds socket "Value" and `b` feeds "Value_001"; each is linked or stored
 * as a constant by `setInputOrParam`.
 */
function scalarMathOp(a: WrappedMaterialGraphNode | number, b: WrappedMaterialGraphNode | number, op: string) {
    const mathNode: MaterialGraphNode = {
        nodeType: "ShaderNodeMath",
        inputs: {},
        parameters: {"internal.operation": op},
    }
    const operands = [
        ["Value", a],
        ["Value_001", b],
    ] as const
    for (const [socket, operand] of operands) {
        setInputOrParam(mathNode, socket, operand)
    }
    return wrapNodeOutput(mathNode, "Value")
}

/**
 * Builds a `ShaderNodeVectorMath` node running `op` (stored as `internal.operation`) and returns
 * its "Vector" output. `a` feeds socket "Vector" and `b` feeds "Vector_001"; each is linked or
 * stored as a constant by `setInputOrParam`.
 */
function vectorMathOp(a: WrappedMaterialGraphNode | number[], b: WrappedMaterialGraphNode | number[], op: string) {
    const mathNode: MaterialGraphNode = {
        nodeType: "ShaderNodeVectorMath",
        inputs: {},
        parameters: {"internal.operation": op},
    }
    const operands = [
        ["Vector", a],
        ["Vector_001", b],
    ] as const
    for (const [socket, operand] of operands) {
        setInputOrParam(mathNode, socket, operand)
    }
    return wrapNodeOutput(mathNode, "Vector")
}

/**
 * Scalar `a + b` where either operand may be a graph socket or a plain number.
 *
 * Adding 0 returns the other operand unchanged, and two plain numbers are folded into a number.
 * A ShaderNodeMath "ADD" node is therefore only emitted when at least one operand is a socket.
 */
function sadd(a: WrappedMaterialGraphNode | number, b: WrappedMaterialGraphNode | number) {
    if (a === 0.0) return b
    else if (b === 0.0) return a
    // Constant operands: compute directly instead of adding a redundant math node to the graph.
    else if (typeof a === "number" && typeof b === "number") return a + b
    return scalarMathOp(a, b, "ADD")
}

/**
 * Scalar `a - b` where either operand may be a graph socket or a plain number.
 *
 * Subtracting 0 returns `a` unchanged, and two plain numbers are folded into a number.
 * A ShaderNodeMath "SUBTRACT" node is therefore only emitted when at least one operand is a socket.
 */
function ssub(a: WrappedMaterialGraphNode | number, b: WrappedMaterialGraphNode | number) {
    if (b === 0.0) return a
    // Constant operands: compute directly instead of adding a redundant math node to the graph.
    else if (typeof a === "number" && typeof b === "number") return a - b
    return scalarMathOp(a, b, "SUBTRACT")
}

/**
 * Scalar `a * b` where either operand may be a graph socket or a plain number.
 *
 * Multiplying by 1 returns the other operand, multiplying by 0 returns 0, and two plain numbers
 * are folded into a number. A ShaderNodeMath "MULTIPLY" node is therefore only emitted when at
 * least one operand is a socket.
 */
function smul(a: WrappedMaterialGraphNode | number, b: WrappedMaterialGraphNode | number) {
    if (a === 1.0) return b
    else if (a === 0.0) return 0.0
    else if (b === 1.0) return a
    else if (b === 0.0) return 0.0
    // Constant operands: compute directly instead of adding a redundant math node to the graph.
    else if (typeof a === "number" && typeof b === "number") return a * b
    return scalarMathOp(a, b, "MULTIPLY")
}

/**
 * Scalar `a / b` where either operand may be a graph socket or a plain number.
 *
 * Dividing by 1 returns `a` unchanged, and two plain numbers are folded when the divisor is
 * non-zero. A literal zero divisor is deliberately left to the ShaderNodeMath "DIVIDE" node, so
 * the renderer's own divide-by-zero behaviour applies rather than JavaScript's Infinity/NaN.
 */
function sdiv(a: WrappedMaterialGraphNode | number, b: WrappedMaterialGraphNode | number) {
    if (b === 1.0) return a
    else if (typeof a === "number" && typeof b === "number" && b !== 0) return a / b
    return scalarMathOp(a, b, "DIVIDE")
}

/** Component-wise `a + b` through a ShaderNodeVectorMath node; returns its "Vector" output. */
function vadd(a: WrappedMaterialGraphNode | number[], b: WrappedMaterialGraphNode | number[]): MaterialGraphNode {
    const operation = "ADD"
    return vectorMathOp(a, b, operation)
}

/** Component-wise `a - b` through a ShaderNodeVectorMath node; returns its "Vector" output. */
function vsub(a: WrappedMaterialGraphNode | number[], b: WrappedMaterialGraphNode | number[]): MaterialGraphNode {
    const operation = "SUBTRACT"
    return vectorMathOp(a, b, operation)
}

/** Component-wise `a * b` through a ShaderNodeVectorMath node; returns its "Vector" output. */
function vmul(a: WrappedMaterialGraphNode | number[], b: WrappedMaterialGraphNode | number[]): MaterialGraphNode {
    const operation = "MULTIPLY"
    return vectorMathOp(a, b, operation)
}

/** Component-wise `a / b` through a ShaderNodeVectorMath node; returns its "Vector" output. */
function vdiv(a: WrappedMaterialGraphNode | number[], b: WrappedMaterialGraphNode | number[]): MaterialGraphNode {
    const operation = "DIVIDE"
    return vectorMathOp(a, b, operation)
}

/**
 * Splits colour socket `v` into its channels through one ShaderNodeSeparateRGB node.
 * @returns the node's "R", "G" and "B" outputs, in that order
 */
function separateRGB(v: WrappedMaterialGraphNode) {
    const splitter: MaterialGraphNode = {nodeType: "ShaderNodeSeparateRGB", inputs: {Image: v}, parameters: {}}
    const red = wrapNodeOutput(splitter, "R")
    const green = wrapNodeOutput(splitter, "G")
    const blue = wrapNodeOutput(splitter, "B")
    return [red, green, blue] as const
}

/**
 * Splits vector socket `v` into its components through one ShaderNodeSeparateXYZ node.
 * @returns the node's "X", "Y" and "Z" outputs, in that order
 */
function separateXYZ(v: WrappedMaterialGraphNode) {
    const splitter: MaterialGraphNode = {nodeType: "ShaderNodeSeparateXYZ", inputs: {Vector: v}, parameters: {}}
    const xOut = wrapNodeOutput(splitter, "X")
    const yOut = wrapNodeOutput(splitter, "Y")
    const zOut = wrapNodeOutput(splitter, "Z")
    return [xOut, yOut, zOut] as const
}

/**
 * Builds a ShaderNodeCombineRGB node from three channels and returns its "Image" output.
 * Each channel may be a socket (linked as an input) or a number (stored as a parameter).
 */
function combineRGB(r: WrappedMaterialGraphNode | number, g: WrappedMaterialGraphNode | number, b: WrappedMaterialGraphNode | number) {
    const combiner: MaterialGraphNode = {nodeType: "ShaderNodeCombineRGB", inputs: {}, parameters: {}}
    const channels = [
        ["R", r],
        ["G", g],
        ["B", b],
    ] as const
    for (const [socket, source] of channels) {
        setInputOrParam(combiner, socket, source)
    }
    return wrapNodeOutput(combiner, "Image")
}

/**
 * Builds a ShaderNodeCombineXYZ node from three components and returns its "Vector" output.
 * Each component may be a socket (linked as an input) or a number (stored as a parameter).
 */
function combineXYZ(x: WrappedMaterialGraphNode | number, y: WrappedMaterialGraphNode | number, z: WrappedMaterialGraphNode | number) {
    const combiner: MaterialGraphNode = {nodeType: "ShaderNodeCombineXYZ", inputs: {}, parameters: {}}
    const components = [
        ["X", x],
        ["Y", y],
        ["Z", z],
    ] as const
    for (const [socket, source] of components) {
        setInputOrParam(combiner, socket, source)
    }
    return wrapNodeOutput(combiner, "Vector")
}

/**
 * If `nodeToTest` is an output wrapper around a node of type `findNodeType`, returns that inner
 * node together with the wrapped output socket name; otherwise returns `undefined`.
 */
function checkIsOutputForNode(nodeToTest: MaterialGraphNode, findNodeType: NodeType): {fromNode: UnwrappedMaterialGraphNode; fromSocket: string} | undefined {
    if (isOutputWrapper(nodeToTest)) {
        const [fromNode, fromSocket] = unwrapNodeOutput(nodeToTest)
        if (isNodeOfType(fromNode, findNodeType)) {
            return {fromNode, fromSocket}
        }
    }
    return undefined
}

/**
 * Routes UV socket `uv` through a POINT-type Mapping node that applies `offset`.
 * Location is `[horizontal, vertical, 0]`; the Z rotation is `-rotation`, converted from degrees
 * to radians.
 * @returns the Mapping node's "Vector" output
 */
function offsetUVs(uv: WrappedMaterialGraphNode, offset: UVOffset): MaterialGraphNode {
    const {horizontal, vertical, rotation} = offset
    const rotationZ = -rotation * (Math.PI / 180)
    return wrapNodeOutput(
        {
            nodeType: "Mapping",
            inputs: {Vector: uv},
            parameters: {
                // Expected platform behavior:
                //  - Positive horizontal offset shifts the texture to the left
                //  - Positive vertical offset shifts the texture down
                //  - Positive rotation rotates texture CCW around unshifted UV origin
                "internal.vector_type": "POINT",
                Location: [horizontal, vertical, 0],
                Rotation: [0, 0, rotationZ],
            },
        },
        "Vector",
    )
}

/** Returns the `internal.uv_map_index` parameter of a UV node, defaulting to channel 0 when it is absent or null. */
function getUVChannelForUVNode(node: UnwrappedMaterialGraphNode) {
    const uvMapIndex = node.parameters?.["internal.uv_map_index"]
    return uvMapIndex ?? 0
}

/**
 * Rotates a tangent-space normal-map colour in its R/G plane by `angleDegrees`.
 *
 * R and G are re-centred around 0.5 (x = R - 0.5, y = G - 0.5) and transformed as
 *   x' = x·cos θ + y·sin θ,   y' = y·cos θ − x·sin θ.
 * They are then shifted back by 0.5, and B passes through unchanged. An angle of exactly 0
 * returns `normalMap` itself without adding any nodes.
 *
 * NOTE(review): the previous comment claimed that increasing `angleDegrees` rotates the vector
 * CCW. Taking R as +X and G as +Y in a right-handed frame, the formula above is a clockwise
 * rotation for positive θ. The ShaderNodeNormalMap branch of `transformOffsetUVs` passes
 * `-offset.rotation`. Confirm the intended direction (or a flipped green-channel convention)
 * before changing either side.
 *
 * @param normalMap colour socket carrying the encoded normal (channels presumably in [0, 1])
 * @param angleDegrees rotation angle in degrees
 * @returns a ShaderNodeCombineRGB "Image" output, or `normalMap` when the angle is 0
 */
export function rotateNormalMapVector(normalMap: WrappedMaterialGraphNode, angleDegrees: number) {
    if (angleDegrees === 0) {
        return normalMap
    }
    const theta = (angleDegrees * Math.PI) / 180
    const [red, green, blue] = separateRGB(normalMap)
    const x = ssub(red, 0.5)
    const y = ssub(green, 0.5)
    const sinT = Math.sin(theta)
    const cosT = Math.cos(theta)
    const outRed = sadd(sadd(smul(x, cosT), smul(y, sinT)), 0.5)
    const outGreen = sadd(ssub(smul(y, cosT), smul(x, sinT)), 0.5)
    return combineRGB(outRed, outGreen, blue)
}

/**
 * Builds a new copy of `node`, recursing into its inputs through `traverseFn`.
 *
 * - Output wrappers (per `isOutputWrapper`) are re-created with `wrapNodeOutput` around the
 *   traversed inner node (`inputs["node"]`), keeping the same output name (`parameters["name"]`).
 * - Any other node gets its `nodeType`, a `deepCopy` of `parameters`, and a fresh `inputs` object
 *   whose values are mapped through `traverseFn`. `resolvedResources` is carried over by reference
 *   (not copied).
 *
 * Optional fields are only set on the copy when the original has them as own properties, so
 * absent fields stay absent.
 *
 * @param node node to copy
 * @param traverseFn transform applied to every input node and to the wrapped node of an output wrapper
 * @returns the new node, passed through `assertNodeOfType` against the original `nodeType`
 */
function transformCommon<T extends NodeType | GetOutputNodeType>(
    node: MaterialGraphNode<T>,
    traverseFn: <U extends NodeType | GetOutputNodeType>(node: MaterialGraphNode<U>) => MaterialGraphNode<U>,
) {
    let transformedNode: MaterialGraphNode
    if (isOutputWrapper(node)) {
        // Re-wrap the traversed inner node; the wrapper itself is always a fresh object.
        transformedNode = wrapNodeOutput(traverseFn(node.inputs["node"]), node.parameters["name"])
    } else {
        transformedNode = {
            nodeType: node.nodeType,
        }

        // Parameters are deep-copied so later edits on the copy do not touch the source graph.
        if (node.hasOwnProperty("parameters")) transformedNode.parameters = deepCopy(node.parameters)
        // A present-but-falsy `inputs` value is kept as-is via the `&&`; otherwise every input is traversed.
        if (node.hasOwnProperty("inputs"))
            transformedNode.inputs =
                node.inputs &&
                Object.fromEntries(
                    Object.entries(node.inputs).map(([inputKey, inputNode]) => {
                        return [inputKey, traverseFn(inputNode)]
                    }),
                )
        // Shared by reference with the source node (no copy is made).
        if (node.hasOwnProperty("resolvedResources")) {
            transformedNode.resolvedResources = node.resolvedResources
        }
    }
    return assertNodeOfType<T>(transformedNode, node.nodeType)
}

/**
 * Returns a graph with `offset` applied to everything sampled through UV channel 0.
 *
 * Each root handed over by `transformMaterialGraph` is rebuilt with `transformCommon`. Results are
 * memoised per original node, so shared sub-graphs stay shared. Along the way:
 * - the "UV" output of each UVMap node using channel 0 is routed through a Mapping node (`offsetUVs`);
 * - the linked "Color" input of each ShaderNodeNormalMap is rotated by `-offset.rotation` degrees;
 * - each ShaderNodeTextureSet gets `internal.normal_rotation = -offset.rotation`;
 * - each BsdfPrincipled gets `offset.rotation / 360` added to its "Anisotropic Rotation"
 *   (link or parameter, treating a missing parameter as 0).
 *
 * @throws Error if a matching UVMap output is unexpectedly not an output wrapper
 */
export function transformOffsetUVs(graph: IMaterialGraph, offset: UVOffset): IMaterialGraph {
    const transformFn = (node: MaterialGraphRootNode): MaterialGraphRootNode => {
        const nodeMap = new Map<MaterialGraphNode, MaterialGraphNode>()
        const traverse = <T extends NodeType | GetOutputNodeType>(node: MaterialGraphNode<T>): MaterialGraphNode<T> => {
            if (!node) return node
            let transformedNode = nodeMap.get(node)
            if (transformedNode) return assertNodeOfType<T>(transformedNode, node.nodeType)

            transformedNode = transformCommon(node, traverse)
            const outputInfo = checkIsOutputForNode(transformedNode, "UVMap")
            if (outputInfo && outputInfo.fromSocket === "UV" && getUVChannelForUVNode(outputInfo.fromNode) === 0) {
                if (!isOutputWrapper(transformedNode)) throw new Error("Expected UVMap to be an output wrapper")
                transformedNode = offsetUVs(transformedNode, offset)
            } else if (isNodeOfType(transformedNode, "ShaderNodeNormalMap")) {
                const normalMap = transformedNode.inputs?.["Color"]
                if (normalMap) {
                    // rotate normal vector in the opposite direction, compared to UV vector
                    // (increasing rotation values rotate texture CCW, so normal map should be rotated CW)
                    setInputOrParam(transformedNode, "Color", rotateNormalMapVector(normalMap, -offset.rotation))
                }
            } else if (isNodeOfType(transformedNode, "ShaderNodeTextureSet")) {
                // setInputOrParam creates `parameters` when the node has none; a `parameters!`
                // assertion would throw in that case.
                // NOTE(review): this overwrites any existing normal_rotation instead of composing with it
                // (unlike the Mapping and anisotropic branches) — confirm that is intended.
                setInputOrParam(transformedNode, "internal.normal_rotation", -offset.rotation)
            } else if (isNodeOfType(transformedNode, "BsdfPrincipled")) {
                const anisoRotParamName = "Anisotropic Rotation"
                // Anisotropic rotation is adjusted by the UV rotation expressed as a fraction of a full turn.
                const anisoRotAdj = offset.rotation / 360
                const anisoRotInput = transformedNode.inputs?.[anisoRotParamName]
                if (anisoRotInput) {
                    setInputOrParam(transformedNode, anisoRotParamName, sadd(anisoRotInput, anisoRotAdj))
                } else {
                    const anisoRotParam = transformedNode.parameters?.[anisoRotParamName] ?? 0
                    setInputOrParam(transformedNode, anisoRotParamName, anisoRotParam + anisoRotAdj)
                }
            }
            nodeMap.set(node, transformedNode)
            return assertNodeOfType<T>(transformedNode, node.nodeType)
        }
        return traverse(node)
    }

    return transformMaterialGraph(graph, transformFn)
}

/**
 * Maps an image resource's colour space to the TexImage `colorspace_settings.name` value.
 * The transient data object is read when `isResolvedResourceWithTransientDataObject` says one is
 * present, otherwise the main data object. Linear maps to "Linear"; everything else maps to "sRGB".
 */
function internalColorSpaceForImageResource(imageResource: ImageResource): "Linear" | "sRGB" {
    const dataObject = isResolvedResourceWithTransientDataObject(imageResource) ? imageResource.transientDataObject : imageResource.mainDataObject
    if (dataObject.imageColorSpace === ImageColorSpace.Linear) {
        return "Linear"
    }
    return "sRGB"
}

/**
 * Multiplies a desaturated copy of `color` by an overlay image.
 *
 * The image is sampled through UV channel 0 via a TEXTURE-type Mapping node with Scale
 * `[size[0], size[1], 1]` and REPEAT extension. The same two values are stored as
 * `widthCm`/`heightCm` metadata on the resolved resource, so `size` is presumably
 * [width, height] in cm. `color` first goes through a ShaderNodeHueSaturation node with
 * Saturation 0, then a MULTIPLY ShaderNodeMixRGB with the image colour.
 *
 * @param useAlpha when true, the multiplied colour is mixed back over the original `color`
 *   using the image's alpha as the factor; when false it is returned directly
 * @returns the "Color" output of the final ShaderNodeMixRGB node
 */
function colorOverlay(color: WrappedMaterialGraphNode, imageResource: ImageResource, size: [number, number], useAlpha: boolean) {
    const [widthCm, heightCm] = size

    const uvChannel0 = wrapNodeOutput<"UVMap">({nodeType: "UVMap", parameters: {"internal.uv_map_index": 0}}, "UV")
    const tiledUv = wrapNodeOutput<"Mapping">(
        {
            nodeType: "Mapping",
            inputs: {Vector: uvChannel0},
            parameters: {"internal.vector_type": "TEXTURE", Scale: [widthCm, heightCm, 1]},
        },
        "Vector",
    )

    const overlayTexture: MaterialGraphNode<"TexImage"> = {
        nodeType: "TexImage",
        inputs: {Vector: tiledUv},
        parameters: {
            "internal.image.colorspace_settings.name": internalColorSpaceForImageResource(imageResource),
            "internal.interpolation": "Linear",
            "internal.extension": "REPEAT",
        },
        resolvedResources: [{...imageResource, metadata: {widthCm, heightCm}}],
    }

    // Strip the base colour's saturation so the overlay image supplies the hue.
    const grayscaleBase = wrapNodeOutput<"ShaderNodeHueSaturation">(
        {
            nodeType: "ShaderNodeHueSaturation",
            inputs: {Color: color},
            parameters: {Hue: 0.5, Value: 1.0, Saturation: 0.0, Fac: 1.0},
        },
        "Color",
    )
    const tinted = wrapNodeOutput<"ShaderNodeMixRGB">(
        {
            nodeType: "ShaderNodeMixRGB",
            inputs: {Color1: grayscaleBase, Color2: wrapNodeOutput(overlayTexture, "Color")},
            parameters: {"internal.blend_type": "MULTIPLY", Fac: 1.0},
        },
        "Color",
    )
    if (!useAlpha) {
        return tinted
    }

    // Blend between the untouched colour and the tinted one using the overlay's alpha.
    return wrapNodeOutput<"ShaderNodeMixRGB">(
        {
            nodeType: "ShaderNodeMixRGB",
            inputs: {Color1: color, Color2: tinted, Fac: wrapNodeOutput(overlayTexture, "Alpha")},
            parameters: {"internal.blend_type": "MIX"},
        },
        "Color",
    )
}

/**
 * Returns the colour socket of a Principled or Diffuse BSDF node, or `undefined` for any other
 * node type. The node itself is not modified.
 *
 * A linked "Base Color" input wins over a linked "Color" input, and an existing link is returned
 * as-is. Without a link, and only when `convertParameterToNode` is true, the constant colour is
 * wrapped in a new ShaderNodeRGB node and that node's "Color" output is returned. The constant is
 * the "Base Color" parameter, else the "Color" parameter, else `defaultValue`.
 */
function getColorInputForBRDFNode(node: MaterialGraphNode, convertParameterToNode = true, defaultValue = [1, 1, 1]) {
    if (!(isNodeOfType(node, "BsdfPrincipled") || isNodeOfType(node, "ShaderNodeBsdfDiffuse"))) return undefined

    let colorInput = node.inputs?.["Base Color"] ?? node.inputs?.["Color"]
    if (colorInput || !convertParameterToNode) {
        return colorInput
    }

    const constantColor = node.parameters?.["Base Color"] ?? node.parameters?.["Color"] ?? defaultValue
    // Only hit when the chosen constant is falsy (e.g. a null `defaultValue` and no colour parameter).
    if (!constantColor) {
        return undefined
    }

    colorInput = wrapNodeOutput({nodeType: "ShaderNodeRGB", parameters: {Color: constantColor}}, "Color")
    return colorInput
}

/**
 * Links `colorInput` to the colour socket of a Principled or Diffuse BSDF node, or removes that
 * link when `colorInput` is undefined. Other node types are left untouched.
 *
 * The socket is chosen with the same precedence `getColorInputForBRDFNode` reads with, so a colour
 * read from one socket is written back to that same socket:
 *   linked "Base Color" → linked "Color" → "Base Color" parameter → "Color" parameter.
 * When the node has none of these, Diffuse nodes use "Color" (Blender's Diffuse BSDF names its
 * colour socket "Color") and Principled nodes use "Base Color".
 */
function setColorInputForBRDFNode(node: MaterialGraphNode, colorInput: WrappedMaterialGraphNode | undefined): void {
    if (!(isNodeOfType(node, "BsdfPrincipled") || isNodeOfType(node, "ShaderNodeBsdfDiffuse"))) return

    let inputName: string
    if (node.inputs?.["Base Color"]) {
        inputName = "Base Color"
    } else if (node.inputs?.["Color"]) {
        // keep the old 'Color' input name when it is the one already linked
        inputName = "Color"
    } else if (node.parameters?.["Base Color"] != null) {
        inputName = "Base Color"
    } else if (node.parameters?.["Color"] != null) {
        inputName = "Color"
    } else {
        inputName = isNodeOfType(node, "ShaderNodeBsdfDiffuse") ? "Color" : "Base Color"
    }

    if (!node.inputs) node.inputs = {}
    if (colorInput) {
        node.inputs[inputName] = colorInput
    } else {
        delete node.inputs[inputName]
    }
}

/**
 * Returns a graph in which every Principled/Diffuse BSDF colour is passed through `colorOverlay`
 * with the given image, size and alpha mode.
 *
 * Each root handed over by `transformMaterialGraph` is rebuilt with `transformCommon`. Results are
 * memoised per original node, so shared sub-graphs are rebuilt once.
 */
export function transformColorOverlay(graph: IMaterialGraph, imageResource: ImageResource, size: [number, number], useAlpha = true): IMaterialGraph {
    const transformFn = (root: MaterialGraphRootNode): MaterialGraphRootNode => {
        const visited = new Map<MaterialGraphNode, MaterialGraphNode>()
        const traverse = <T extends NodeType | GetOutputNodeType>(node: MaterialGraphNode<T>): MaterialGraphNode<T> => {
            if (!node) return node
            const cached = visited.get(node)
            if (cached) return assertNodeOfType<T>(cached, node.nodeType)

            const result: MaterialGraphNode = transformCommon(node, traverse)
            const baseColor = getColorInputForBRDFNode(result)
            if (baseColor) {
                setColorInputForBRDFNode(result, colorOverlay(baseColor, imageResource, size, useAlpha))
            }
            visited.set(node, result)
            return assertNodeOfType<T>(result, node.nodeType)
        }
        return traverse(root)
    }

    return transformMaterialGraph(graph, transformFn)
}

/**
 * Builds the scalar decal mask described by `data`.
 *
 * Every image is sampled through UV channel 0 via a TEXTURE-type Mapping node with Scale
 * `[widthCm, heightCm, 1]` and EXTEND extension.
 * - `maskImage`: its luminance (ShaderNodeRGBToBW), replaced by `1 - luminance` when `invert` is
 *   set, is multiplied by the image's alpha.
 * - `colorOverlayImage`: its alpha further multiplies the mask.
 * Without any image the mask is the constant 1.0.
 *
 * @returns a graph socket, or a plain number when no image contributes
 */
function setupMask(data: DecalMaskData) {
    const {widthCm, heightCm} = data

    // Samples `imageResource` over the decal area and exposes its colour and alpha outputs.
    function sampleMaskImage(imageResource: ImageResource) {
        const uvChannel0 = wrapNodeOutput<"UVMap">({nodeType: "UVMap", parameters: {"internal.uv_map_index": 0}}, "UV")
        const decalUv = wrapNodeOutput<"Mapping">(
            {
                nodeType: "Mapping",
                inputs: {Vector: uvChannel0},
                parameters: {"internal.vector_type": "TEXTURE", Scale: [widthCm, heightCm, 1]},
            },
            "Vector",
        )
        const texture: MaterialGraphNode<"TexImage"> = {
            nodeType: "TexImage",
            inputs: {Vector: decalUv},
            parameters: {
                "internal.image.colorspace_settings.name": internalColorSpaceForImageResource(imageResource),
                "internal.extension": "EXTEND",
                "internal.interpolation": "Linear",
            },
            resolvedResources: [{...imageResource, metadata: {widthCm, heightCm}}],
        }
        const rgb = wrapNodeOutput(texture, "Color")
        const alpha = wrapNodeOutput(texture, "Alpha")
        return {rgb, alpha}
    }

    let maskValue: MaterialGraphNode | number = 1.0
    if (data.maskImage) {
        const maskSample = sampleMaskImage(data.maskImage)
        // this handles both JPEG (RGB) and PNG (RGBA) cases:
        const luminance: MaterialGraphNode | number = wrapNodeOutput(
            {
                nodeType: "ShaderNodeRGBToBW",
                inputs: {Color: maskSample.rgb},
            },
            "Value",
        )
        const shapedLuminance: MaterialGraphNode | number = data.invert ? ssub(1, luminance) : luminance
        maskValue = smul(shapedLuminance, maskSample.alpha)
    }

    if (data.colorOverlayImage) {
        maskValue = smul(maskValue, sampleMaskImage(data.colorOverlayImage).alpha)
    }

    return maskValue
}

/**
 * Returns a graph in which every BsdfPrincipled node's "Alpha" is multiplied by the decal mask
 * built from `maskData` (see `setupMask`). A missing Alpha is treated as 1.0.
 *
 * Each root handed over by `transformMaterialGraph` is rebuilt with `transformCommon`. Results are
 * memoised per original node, and `setupMask` is called once per BsdfPrincipled node.
 */
export function transformDecalMask(graph: IMaterialGraph, maskData: DecalMaskData): IMaterialGraph {
    const transformFn = (root: MaterialGraphRootNode): MaterialGraphRootNode => {
        const visited = new Map<MaterialGraphNode, MaterialGraphNode>()
        const traverse = <T extends NodeType | GetOutputNodeType>(node: MaterialGraphNode<T>): MaterialGraphNode<T> => {
            if (!node) return node
            const cached = visited.get(node)
            if (cached) return assertNodeOfType<T>(cached, node.nodeType)

            const result: MaterialGraphNode = transformCommon(node, traverse)
            if (isNodeOfType(result, "BsdfPrincipled")) {
                const currentAlpha = getInputOrParam<number>(result, "Alpha") ?? 1.0
                setInputOrParam(result, "Alpha", smul(currentAlpha, setupMask(maskData)))
            }
            visited.set(node, result)
            return assertNodeOfType<T>(result, node.nodeType)
        }
        return traverse(root)
    }

    return transformMaterialGraph(graph, transformFn)
}
