import {Injectable} from "@angular/core"
import {ThreeMesh} from "@app/template-editor/helpers/three-mesh"
import {SceneManagerService, SceneNodePart} from "@app/template-editor/services/scene-manager.service"
import {ThreeSceneManagerService} from "@app/template-editor/services/three-scene-manager.service"
import {GLTFBuilder} from "@cm/lib/gltf/gltf-builder"
import {IMaterialData, keyForMeshMaterialData} from "@cm/lib/templates/interfaces/material-data"
import {SceneNodes} from "@cm/lib/templates/interfaces/scene-object"
import {TemplateGraph} from "@cm/lib/templates/nodes/template-graph"
import {Parameters} from "@cm/lib/templates/nodes/template-instance"
import * as THREE from "three"
import {LegacyMaterialConverter} from "@cm/lib/materials/legacy-material-converter"
import {MaterialNode} from "@cm/lib/materials/declare-material-node"
import {OutputMaterial} from "@cm/lib/materials/nodes/output-material"
import {Mapping, getMappingMatrix} from "@cm/lib/materials/nodes/mapping"
import {compressMeshGLTF} from "@app/editor/helpers/mesh-processing"
import {firstValueFrom} from "rxjs"
import {UVMap} from "@cm/lib/materials/nodes/uv-map"
import {GetProperty} from "@cm/lib/graph-system/utils"
import {SceneProperties} from "@cm/lib/templates/nodes/scene-properties"
import {Node} from "@cm/lib/templates/node-types"
import {getDefaultMaterial} from "@app/template-editor/services/three-material-manager.service"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"

// Bit depth handed to compressMeshGLTF when Draco-compressing the geometry of every exported mesh.
const DEFAULT_GEOMETRY_COMPRESSION_BIT_DEPTH = 14

@Injectable({
    providedIn: "root",
})
export class GltfExportService {
    /**
     * Exports a configured template as a binary glTF (GLB) file.
     *
     * The template graph is cloned. The clone's SceneProperties entries are replaced by a fixed export
     * configuration. The clone and `parameters` are then loaded into `sceneManagerService` and
     * compiled/synced.
     *
     * Every non-procedural mesh of the resulting scene, plus every decal, is exported as one glTF node
     * per material slot:
     *
     * 1. One glTF material is created per distinct material key.
     *    - Slots with material data are baked into 1024x1024 base-color / normal / metallic-roughness
     *      images (plus an emissive image when the material emits) by rendering the live three.js
     *      material.
     *    - Slots without material data get flat PBR factors from the default material for that index.
     * 2. Geometry is Draco-compressed (KHR_draco_mesh_compression). UVs are transformed by the
     *    material's UV mapping matrix so they address the baked textures.
     *
     * Side effects: the scene manager's template graph and instance parameters are replaced, and
     * progressive texture loading is disabled on the material manager. None of these are restored.
     *
     * @param templateGraph template to export (not modified; a clone is used)
     * @param parameters instance parameters applied to the cloned template
     * @param sceneManagerService scene manager used to compile and sync the template
     * @param threeSceneManagerService three.js scene manager providing meshes and the renderer canvas
     * @returns the GLB file contents
     */
    async exportGltfFile(
        templateGraph: TemplateGraph,
        parameters: Parameters,
        sceneManagerService: SceneManagerService,
        threeSceneManagerService: ThreeSceneManagerService,
    ): Promise<ArrayBuffer> {
        // Work on a clone so the caller's template graph is left untouched.
        const clonedTemplateGraph = templateGraph.clone({cloneSubNode: () => true})

        // Replace any scene properties of the template with a fixed export configuration
        // (e.g. subdivision limits of 9999, realtime materials enabled).
        const previousSceneProperties = clonedTemplateGraph.parameters.nodes.parameters.list.filter(
            (node: Node): node is SceneProperties => node instanceof SceneProperties,
        )
        for (const sceneProperty of previousSceneProperties) clonedTemplateGraph.parameters.nodes.removeEntry(sceneProperty)

        const sceneProperties = new SceneProperties({
            maxSubdivisionLevel: 9999,
            maxSubdivisionLevelOnMobile: 9999,
            uiColor: [0, 0, 0],
            uiStyle: "default",
            iconSize: 24,
            enableAr: false,
            enableSalesEnquiry: false,
            textureResolution: "2000px",
            textureFiltering: false,
            enableRealtimeShadows: false,
            enableRealtimeLights: false,
            enableRealtimeMaterials: true,
            enableOnboardingHint: false,
            enableGltfDownload: false,
            enableStlDownload: false,
            enablePdfGeneration: false,
            enableSnapshot: false,
            enableFullscreen: false,
            environmentMapMode: "full",
            showAnnotations: false,
            enableAdaptiveSubdivision: false,
        })
        clonedTemplateGraph.parameters.nodes.addEntry(sceneProperties)

        // Load the export configuration into the scene managers (not restored afterwards).
        sceneManagerService.$templateGraph.set(clonedTemplateGraph)
        sceneManagerService.$instanceParameters.set(parameters)
        threeSceneManagerService.materialManagerService.progressiveTextureLoading.set(false)

        sceneManagerService.compileTemplate()
        await sceneManagerService.sync()

        // Export every non-procedural mesh; decals are exported even when they are procedural.
        const meshes = sceneManagerService
            .$scene()
            .filter((x): x is SceneNodes.Mesh => SceneNodes.Mesh.is(x) && (!x.isProcedural || x.isDecal))
            .map<SceneNodePart>((x) => ({
                sceneNode: x,
                part: "root",
            }))

        const threeMeshes = threeSceneManagerService
            .getObjectsFromSceneNodes(meshes)
            .map((x) => x.threeObject)
            .filter((x): x is ThreeMesh => x instanceof ThreeMesh)

        const builder = new GLTFBuilder({
            generator: "colormass-exportScene",
        })
        builder.useExtension("KHR_draco_mesh_compression")

        const gltfScene = builder.addScene({
            nodes: [],
        })

        // Separate renderer used only for texture baking, bound to the scene renderer's canvas.
        // NOTE(review): it is never disposed here. three.js' WebGLRenderer.dispose() also drops its
        // bookkeeping of GPU resources that shared textures may still reference, so verify before adding it.
        const renderer = new THREE.WebGLRenderer({
            canvas: threeSceneManagerService.getRenderer().domElement,
            powerPreference: "high-performance",
            antialias: false,
            alpha: true,
            premultipliedAlpha: false,
        })
        renderer.shadowMap.enabled = false

        // Pass 1: one glTF material per distinct material key, mapped to [glTF material id, UV mapping matrix].
        const materials = new Map<string, [number, THREE.Matrix4]>()
        for (const threeMesh of threeMeshes) {
            const sceneNode = threeMesh.getSceneNode()
            const {materialMap, meshRenderSettings} = sceneNode

            for (const [materialIndex, materialData] of materialMap) {
                const mesh = threeMesh.getSubMesh(materialIndex)
                // Live three.js material of this slot; it is rendered to bake the textures.
                const material = mesh.material
                if (materialData) {
                    const key = keyForMeshMaterialData(materialData, meshRenderSettings)
                    if (materials.has(key)) continue

                    // The baked images cover the unit square of *mapped* UV space. Its corners in mesh-UV
                    // space are the inverse mapping of (0,0), (1,0), (1,1) and (0,1).
                    const mappingMatrix = getUVMappingMatrix(materialData)
                    const mappingMatrixInverse = mappingMatrix.clone().invert()

                    const uvTopLeft = new THREE.Vector4(0, 0, 0, 1).applyMatrix4(mappingMatrixInverse)
                    const uvTopRight = new THREE.Vector4(1, 0, 0, 1).applyMatrix4(mappingMatrixInverse)
                    const uvBottomRight = new THREE.Vector4(1, 1, 0, 1).applyMatrix4(mappingMatrixInverse)
                    const uvBottomLeft = new THREE.Vector4(0, 1, 0, 1).applyMatrix4(mappingMatrixInverse)

                    const {diffuseMap, normalMap, metallicRoughnessMap, emissionMap} = await renderPbrMaps(
                        renderer,
                        material,
                        new THREE.Vector2(uvTopLeft.x, uvTopLeft.y),
                        new THREE.Vector2(uvTopRight.x, uvTopRight.y),
                        new THREE.Vector2(uvBottomRight.x, uvBottomRight.y),
                        new THREE.Vector2(uvBottomLeft.x, uvBottomLeft.y),
                        1024,
                        1024,
                    )

                    const gltfMaterial = builder.addMaterial({
                        doubleSided: materialData.side === "double",
                    })

                    const diffuseTexture = await makeTexture(builder, diffuseMap)
                    const normalTexture = await makeTexture(builder, normalMap)
                    const metallicRoughnessTexture = await makeTexture(builder, metallicRoughnessMap)
                    gltfMaterial.data.pbrMetallicRoughness = {
                        baseColorTexture: {index: diffuseTexture.id},
                        metallicRoughnessTexture: {index: metallicRoughnessTexture.id},
                    }
                    gltfMaterial.data.normalTexture = {index: normalTexture.id}

                    // Transparent materials: alpha-tested ones become MASK, all others BLEND.
                    if (material.transparent) {
                        if (material.alphaTest > 0) {
                            gltfMaterial.data.alphaMode = "MASK"
                            gltfMaterial.data.alphaCutoff = material.alphaTest
                        } else {
                            gltfMaterial.data.alphaMode = "BLEND"
                        }
                    }

                    if (emissionMap) {
                        // The emission is fully carried by the baked texture, so the factor is neutral.
                        const emissiveTexture = await makeTexture(builder, emissionMap)
                        gltfMaterial.data.emissiveTexture = {index: emissiveTexture.id}
                        gltfMaterial.data.emissiveFactor = [1, 1, 1]
                    }

                    materials.set(key, [gltfMaterial.id, mappingMatrix])
                } else {
                    // Slots without material data use the default material for this index as flat PBR factors.
                    const key = `default-${materialIndex}`

                    if (materials.has(key)) continue

                    const defaultMaterial = getDefaultMaterial(materialIndex)
                    const {color, roughness, metalness} = defaultMaterial

                    const gltfMaterial = builder.addMaterial({})

                    gltfMaterial.data.pbrMetallicRoughness = {
                        baseColorFactor: [color.r, color.g, color.b, 1],
                        metallicFactor: metalness,
                        roughnessFactor: roughness,
                    }

                    materials.set(key, [gltfMaterial.id, new THREE.Matrix4().identity()])
                }
            }
        }

        // Pass 2: one Draco-compressed glTF mesh + node per material slot.
        for (const threeMesh of threeMeshes) {
            const sceneNode = threeMesh.getSceneNode()
            const {materialMap, meshRenderSettings} = sceneNode

            for (const [materialIndex, materialData] of materialMap) {
                const mesh = threeMesh.getSubMesh(materialIndex)

                // Same key derivation as in pass 1.
                const gltfMaterialData = (() => {
                    if (!materialData) return materials.get(`default-${materialIndex}`)

                    const key = keyForMeshMaterialData(materialData, meshRenderSettings)
                    return materials.get(key)
                })()
                const gltfMaterialID = gltfMaterialData?.[0]
                const gltfMaterialMatrix = gltfMaterialData?.[1]

                const positionArray = mesh.geometry.attributes.position.array
                if (!(positionArray instanceof Float32Array)) throw new Error("Position array is not a Float32Array")

                const normalArrayRaw = mesh.geometry.attributes.normal.array
                if (!(normalArrayRaw instanceof Float32Array)) throw new Error("Normal array is not a Float32Array")
                // NOTE(review): for back-sided materials only the normals are flipped. The triangle winding is
                // unchanged and the glTF material is single-sided, so viewers may cull these faces; verify
                // whether the winding should be reversed as well.
                const normalArray = fixNormals(normalArrayRaw, materialData?.side === "back")

                const uvArrayRaw = mesh.geometry.attributes.uv.array
                if (!(uvArrayRaw instanceof Float32Array)) throw new Error("UV array is not a Float32Array")
                // Map mesh UVs into the UV space of the baked textures.
                const uvArray = gltfMaterialMatrix ? transformUVs(uvArrayRaw, gltfMaterialMatrix) : uvArrayRaw

                const geometryCompressionBitDepth = DEFAULT_GEOMETRY_COMPRESSION_BIT_DEPTH

                const dracoFile = await firstValueFrom(
                    compressMeshGLTF(sceneManagerService.workerService, positionArray, normalArray, uvArray, geometryCompressionBitDepth),
                )

                const gltfDracoBuf = builder.addBuffer(
                    {
                        byteLength: dracoFile.byteLength,
                    },
                    dracoFile,
                )
                const gltfDracoView = builder.addBufferView({
                    buffer: gltfDracoBuf.id,
                    byteLength: dracoFile.byteLength,
                })
                const gltfMesh = builder.addMesh({
                    primitives: [
                        {
                            attributes: {
                                POSITION: makeAccessor(builder, positionArray, 3).id,
                                NORMAL: makeAccessor(builder, normalArray, 3).id,
                                TEXCOORD_0: gltfMaterialID !== undefined ? makeAccessor(builder, uvArray, 2).id : undefined,
                            },
                            // Identity index sequence: the geometry is treated as a non-indexed triangle list.
                            indices: makeIndicesAccessor(builder, mesh.geometry.attributes.position).id,
                            material: gltfMaterialID,
                            extensions: {
                                KHR_draco_mesh_compression: {
                                    bufferView: gltfDracoView.id,
                                    attributes: {
                                        POSITION: 0, //TODO: get these attribute IDs from dracoWriter.cpp!
                                        NORMAL: 1,
                                        TEXCOORD_0: gltfMaterialID !== undefined ? 2 : undefined,
                                    },
                                },
                            },
                        },
                    ],
                })
                const gltfNode = builder.addNode({
                    matrix: exportMatrix(threeMesh.getRenderObject().matrix),
                    mesh: gltfMesh.id,
                })
                gltfScene.data.nodes.push(gltfNode.id)

                console.log("Exported mesh", mesh)
            }
        }

        return builder.generateGLB()
    }
}

/**
 * Derives a single UV mapping matrix for a material, used to bake its textures and remap mesh UVs.
 *
 * The legacy material graph is converted and all of its nodes collected. A starting Mapping node is then
 * chosen among:
 *  - Mapping nodes reached from a UVMap on UV channel 0 (or with no channel index) through a GetProperty node;
 *  - Mapping nodes whose `vector` input is undefined (presumably they fall back to the default UVs — TODO confirm).
 *
 * From that node the chain is followed via GetProperty nodes to further Mapping nodes, composing their
 * matrices. At every step where several candidates exist, the one with the largest reciprocal x*y scale
 * wins.
 *
 * NOTE(review): `parents` appears to denote the nodes that consume a node's output (UVMap -> GetProperty ->
 * Mapping); confirm against the graph-system API.
 *
 * @param materialData material whose `materialGraph` is inspected
 * @returns the composed mapping matrix, or the identity when no Mapping node qualifies
 * @throws Error when a GetProperty node on the inspected paths does not have exactly one parent
 */
const getUVMappingMatrix = (materialData: IMaterialData) => {
    const materialConverter = new LegacyMaterialConverter()
    const materialGraph = materialConverter.convertMaterialGraph(materialData.materialGraph)
    const materialNodes = getAllMaterialNodes(materialGraph)

    // Returns the single parent of a GetProperty node; more or fewer parents is treated as an error.
    const getPropertyParent = (node: GetProperty) => {
        if (node.parents.size !== 1) throw new Error("Get property node must have exactly one parent")
        const [propertyParent] = node.parents
        return propertyParent
    }

    // Picks the candidate with the highest score; returns undefined for an empty list.
    const getLargestMappingNode = (mappingNodes: Mapping[]) => {
        // Score = (1 / scale.x) * (1 / scale.y) of the decomposed mapping matrix.
        const scaleArea = (mappingNode: Mapping) => {
            const mappingMatrix = getMappingMatrix(mappingNode.parameters.parameters)
            const scale = new THREE.Vector3()
            mappingMatrix.decompose(new THREE.Vector3(), new THREE.Quaternion(), scale)
            scale.x = 1.0 / scale.x
            scale.y = 1.0 / scale.y
            scale.z = 1.0 / scale.z

            return scale.x * scale.y
        }

        // Strict `<` keeps the earlier candidate on ties. The length guard avoids reduce() on an empty array.
        return mappingNodes.length > 1 ? mappingNodes.reduce((prev, curr) => (scaleArea(prev) < scaleArea(curr) ? curr : prev)) : mappingNodes.at(0)
    }

    let mappingNode = getLargestMappingNode(
        (() => {
            const mappingNodeCandidates = new Set<Mapping>()

            // Mapping nodes fed (through a GetProperty node) by a UVMap on UV channel 0 or an unspecified channel.
            const uvMaps = materialNodes.filter(
                (x): x is UVMap => x instanceof UVMap && (x.parameters.parameters.uvMapIndex === undefined || x.parameters.parameters.uvMapIndex === 0),
            )
            for (const uvMap of uvMaps) {
                for (const parent of uvMap.parents) {
                    if (parent instanceof GetProperty) {
                        const propertyParent = getPropertyParent(parent)
                        if (propertyParent instanceof Mapping) mappingNodeCandidates.add(propertyParent)
                    }
                }
            }

            // Mapping nodes with no `vector` input connected.
            const unConnectedMappingNodes = materialNodes.filter((x): x is Mapping => x instanceof Mapping && x.parameters.vector === undefined)
            for (const mappingNode of unConnectedMappingNodes) {
                mappingNodeCandidates.add(mappingNode)
            }

            return [...mappingNodeCandidates]
        })(),
    )

    // Follow the chain of Mapping nodes, left-multiplying each later mapping onto the accumulated matrix.
    // NOTE(review): there is no guard against a cycle of Mapping nodes here.
    let mappingMatrix = new THREE.Matrix4().identity()
    while (mappingNode) {
        mappingMatrix = getMappingMatrix(mappingNode.parameters.parameters).multiply(mappingMatrix)

        const nextMappingNodeCandidates = new Set<Mapping>()
        for (const parent of mappingNode.parents) {
            if (parent instanceof GetProperty) {
                const propertyParent = getPropertyParent(parent)
                if (propertyParent instanceof Mapping) nextMappingNodeCandidates.add(propertyParent)
            }
        }

        mappingNode = getLargestMappingNode([...nextMappingNodeCandidates])
    }

    return mappingMatrix
}

/**
 * Collects every node reachable from a material output through `children` links.
 *
 * The traversal is breadth-first. Each node is returned once, in discovery order, with the output node first.
 * Nodes that were already visited are not re-queued. This keeps shared sub-graphs from being expanded
 * repeatedly and guarantees termination even if the graph contains a cycle.
 *
 * @param materialGraph root (output) node of the material graph
 * @returns all reachable nodes, including `materialGraph` itself
 */
const getAllMaterialNodes = (materialGraph: OutputMaterial) => {
    const visited = new Set<MaterialNode>()
    const pending = new Set<MaterialNode>([materialGraph])
    while (pending.size > 0) {
        // Sets iterate in insertion order, so this takes the oldest pending node (FIFO).
        const [node] = pending

        pending.delete(node)
        visited.add(node)
        for (const child of node.children) {
            const childNode = child as MaterialNode
            if (!visited.has(childNode)) pending.add(childNode)
        }
    }

    return [...visited]
}

/**
 * Returns a sanitised copy of a packed xyz normal array for glTF export, which requires unit-length normals.
 *
 * - Degenerate normals (squared length below 1e-3, NaN or infinite) are replaced by +Z.
 * - All other normals whose squared length differs from 1 by more than 1e-4 are rescaled to unit
 *   length, whether they are too short or too long.
 * - When `invert` is set, every resulting normal is negated.
 *
 * @param array packed normals, 3 floats per vertex; not modified
 * @param invert negate all normals after sanitising
 * @returns a new Float32Array of the same length
 */
const fixNormals = (array: Float32Array, invert: boolean): Float32Array => {
    const elemSize = 3
    const count = array.length / elemSize
    const newArray = new Float32Array(array.length)

    for (let i = 0; i < count; i++) {
        const idx = i * elemSize
        let [x, y, z] = array.subarray(idx, idx + elemSize)

        const magnitudeSquared = x * x + y * y + z * z

        if (!Number.isFinite(magnitudeSquared) || magnitudeSquared < 1e-3) {
            // Degenerate normal: no usable direction, fall back to +Z.
            x = 0
            y = 0
            z = 1
        } else if (Math.abs(magnitudeSquared - 1) > 1e-4) {
            const scale = 1 / Math.sqrt(magnitudeSquared)
            x *= scale
            y *= scale
            z *= scale
        }

        if (invert) {
            x = -x
            y = -y
            z = -z
        }

        newArray[idx] = x
        newArray[idx + 1] = y
        newArray[idx + 2] = z
    }

    return newArray
}

/**
 * Returns a copy of a packed UV array with every coordinate transformed by `matrix`.
 *
 * Each (u, v) is treated as the homogeneous point (u, v, 0, 1). The x/y components of the product are
 * stored, with no perspective divide.
 *
 * @param array packed UVs, 2 floats per vertex; not modified
 * @param matrix transform applied to every UV
 * @returns a new Float32Array of the same length
 */
const transformUVs = (array: Float32Array, matrix: THREE.Matrix4): Float32Array => {
    const elemSize = 2
    const count = array.length / elemSize
    const newArray = new Float32Array(array.length)
    // One scratch vector for the whole loop instead of a fresh allocation per vertex.
    const scratch = new THREE.Vector4()

    for (let i = 0; i < count; i++) {
        const idx = i * elemSize
        scratch.set(array[idx], array[idx + 1], 0, 1).applyMatrix4(matrix)

        newArray[idx] = scratch.x
        newArray[idx + 1] = scratch.y
    }

    return newArray
}

/**
 * Embeds an encoded image in the GLB binary and wraps it in a glTF texture.
 *
 * Creates a buffer, a buffer view, an image and a texture in `builder`. No sampler is attached, so
 * glTF's default sampling applies.
 *
 * @param builder glTF builder receiving the new entries
 * @param blob PNG or JPEG image data
 * @returns the texture entry created in the builder
 * @throws Error when the blob is neither PNG nor JPEG
 */
const makeTexture = async (builder: GLTFBuilder, blob: Blob) => {
    const isSupportedMimeType = (candidate: string): candidate is "image/png" | "image/jpeg" =>
        candidate === "image/png" || candidate === "image/jpeg"

    const mimeType = blob.type
    if (!isSupportedMimeType(mimeType)) throw new Error(`Unsupported image type: ${mimeType}`)

    const bytes = new Uint8Array(await blob.arrayBuffer())
    const {byteLength} = bytes

    const imageBuffer = builder.addBuffer({byteLength}, bytes)
    const imageBufferView = builder.addBufferView({buffer: imageBuffer.id, byteLength})
    const image = builder.addImage({bufferView: imageBufferView.id, mimeType})

    return builder.addTexture({source: image.id})
}

/**
 * Registers a FLOAT accessor for a packed attribute array, including per-component min/max bounds.
 *
 * Only accessor metadata is added (type, component type, count, min, max); no buffer view is attached here.
 *
 * @param builder glTF builder receiving the accessor
 * @param array packed attribute values, `elemSize` floats per element
 * @param elemSize number of components per element (1 to 4)
 * @returns the accessor entry created in the builder
 * @throws Error("Invalid element size") when `elemSize` is not 1, 2, 3 or 4
 */
const makeAccessor = (builder: GLTFBuilder, array: Float32Array, elemSize: number) => {
    const accessorTypes = ["SCALAR", "VEC2", "VEC3", "VEC4"] as const
    const type = accessorTypes[elemSize - 1]
    if (type === undefined) throw new Error("Invalid element size")

    const count = array.length / elemSize

    // [min, max] of one component across all elements.
    const componentBounds = (component: number): [number, number] => {
        let lo = Infinity
        let hi = -Infinity
        for (let element = 0; element < count; element++) {
            const value = array[element * elemSize + component]
            lo = Math.min(lo, value)
            hi = Math.max(hi, value)
        }
        return [lo, hi]
    }
    const bounds = Array.from({length: elemSize}, (_, component) => componentBounds(component))

    return builder.addAccessor({
        type,
        componentType: GLTFBuilder.ComponentType.FLOAT,
        count,
        min: bounds.map(([lo]) => lo),
        max: bounds.map(([, hi]) => hi),
    })
}

/**
 * Registers an UNSIGNED_INT scalar index accessor with one index per vertex of `positionAttr`.
 *
 * The geometry is treated as a non-indexed triangle list, i.e. the indices are conceptually 0..count-1.
 * Only accessor metadata is added; no buffer view or index data is attached. The index data is presumably
 * supplied by the primitive's KHR_draco_mesh_compression payload — TODO confirm against the glTF consumers.
 *
 * @param builder glTF builder receiving the accessor
 * @param positionAttr position attribute; only its `count` is read
 * @returns the accessor entry created in the builder
 */
const makeIndicesAccessor = (builder: GLTFBuilder, positionAttr: THREE.BufferAttribute | THREE.InterleavedBufferAttribute) => {
    return builder.addAccessor({
        type: "SCALAR",
        componentType: GLTFBuilder.ComponentType.UNSIGNED_INT,
        count: positionAttr.count,
    })
}

/**
 * Converts a three.js object matrix into a glTF node matrix, scaling from centimeters to meters.
 *
 * @param matrix the object's transform matrix (not modified)
 * @returns the matrix elements as produced by `Matrix4.toArray()`, or undefined when the scaled matrix is
 *   exactly the identity (glTF validation requires the node matrix to be omitted in that case)
 */
const exportMatrix = (matrix: THREE.Matrix4) => {
    const centimetersToMeters = 0.01
    // scale * matrix: the unit conversion is applied after the object's own transform.
    const nodeMatrix = new THREE.Matrix4().makeScale(centimetersToMeters, centimetersToMeters, centimetersToMeters).multiply(matrix)

    return nodeMatrix.equals(new THREE.Matrix4()) ? undefined : nodeMatrix.toArray()
}

/**
 * Builds an indexed two-triangle quad in the z = 0 plane with +Z normals and per-corner UVs.
 *
 * Corners are stored in the order top-left, top-right, bottom-right, bottom-left. The triangles are
 * (0, 1, 2) and (2, 3, 0).
 *
 * @returns a new BufferGeometry with `position`, `normal` and `uv` attributes and an index
 */
const createQuadGeometry = (
    positionTopLeft: THREE.Vector2,
    positionTopRight: THREE.Vector2,
    positionBottomRight: THREE.Vector2,
    positionBottomLeft: THREE.Vector2,
    uvTopLeft: THREE.Vector2,
    uvTopRight: THREE.Vector2,
    uvBottomRight: THREE.Vector2,
    uvBottomLeft: THREE.Vector2,
) => {
    const corners = [
        [positionTopLeft, uvTopLeft],
        [positionTopRight, uvTopRight],
        [positionBottomRight, uvBottomRight],
        [positionBottomLeft, uvBottomLeft],
    ] as const

    const positions = corners.flatMap(([position]) => [position.x, position.y, 0])
    const normals = corners.flatMap(() => [0, 0, 1])
    const uvs = corners.flatMap(([, uv]) => [uv.x, uv.y])

    const geometry = new THREE.BufferGeometry()
    geometry.setIndex([0, 1, 2, 2, 3, 0])
    geometry.setAttribute("position", new THREE.Float32BufferAttribute(positions, 3))
    geometry.setAttribute("normal", new THREE.Float32BufferAttribute(normals, 3))
    geometry.setAttribute("uv", new THREE.Float32BufferAttribute(uvs, 2))

    return geometry
}

/**
 * Bakes the PBR channels of `material` into images.
 *
 * A unit quad carrying the material is rendered through an orthographic camera into an off-screen render
 * target, once per channel. For every pass the fragment shader is patched via `onBeforeCompile` so that
 * `gl_FragColor` is finally overwritten with one value:
 *  - `diffuseColor`;
 *  - the shading `normal`, remapped to [0, 1];
 *  - (0, roughness, metalness) — roughness in G and metalness in B, the glTF metallicRoughness layout;
 *  - `totalEmissiveRadiance`.
 *
 * The quad's corners carry the given UVs, so each image covers exactly that UV region.
 *
 * All passes are rendered strictly one after another, because they share one cloned material whose shader
 * hook is swapped per pass. Only the image encoding runs concurrently. The renderer's previous render target
 * is restored, and the temporary geometry/material are disposed, even if a pass fails.
 *
 * @param renderer renderer used for the off-screen passes
 * @param material material to bake; it is cloned and never modified
 * @param uvTopLeft UV assigned to the quad corner at (-0.5, -0.5)
 * @param uvTopRight UV assigned to the quad corner at (0.5, -0.5)
 * @param uvBottomRight UV assigned to the quad corner at (0.5, 0.5)
 * @param uvBottomLeft UV assigned to the quad corner at (-0.5, 0.5)
 * @param width output image width in pixels
 * @param height output image height in pixels
 * @returns one image blob per channel; `emissionMap` is undefined when the material has no emission.
 *   The diffuse map is PNG for transparent materials and JPEG otherwise; all other maps are JPEG.
 */
const renderPbrMaps = async (
    renderer: THREE.WebGLRenderer,
    material: THREE.Material,
    uvTopLeft: THREE.Vector2,
    uvTopRight: THREE.Vector2,
    uvBottomRight: THREE.Vector2,
    uvBottomLeft: THREE.Vector2,
    width: number,
    height: number,
) => {
    type OutputMapType = "diffuse" | "normal" | "metallicRoughness" | "emission"

    const renderScene = new THREE.Scene()

    // Bake from a clone so the live material is untouched. Alpha testing is disabled so that texels below
    // the cutoff still receive their channel values instead of being discarded.
    const patchedMaterial = material.clone()
    patchedMaterial.alphaTest = 0

    const planeGeometry = createQuadGeometry(
        new THREE.Vector2(-0.5, -0.5),
        new THREE.Vector2(0.5, -0.5),
        new THREE.Vector2(0.5, 0.5),
        new THREE.Vector2(-0.5, 0.5),
        uvTopLeft,
        uvTopRight,
        uvBottomRight,
        uvBottomLeft,
    )
    const planeMesh = new THREE.Mesh(planeGeometry, patchedMaterial)
    renderScene.add(planeMesh)

    // The camera frustum matches the quad exactly, so the quad fills the whole render target.
    const renderCamera = new THREE.OrthographicCamera(-0.5, 0.5, 0.5, -0.5, 0.1, 2)
    renderCamera.position.z = 1

    const isColorSet = (color: THREE.Color) => color.r !== 0 || color.g !== 0 || color.b !== 0
    const hasEmission =
        (material instanceof THREENodes.MeshStandardNodeMaterial && !!material.emissiveNode) ||
        (material instanceof THREE.MeshStandardMaterial && isColorSet(material.emissive))

    // GLSL expression written to gl_FragColor for the given pass.
    const outputExpression = (outputMapName: OutputMapType) => {
        switch (outputMapName) {
            case "diffuse":
                return "diffuseColor"
            case "normal":
                return "vec4((normal) * 0.5 + 0.5, 1.0)"
            case "metallicRoughness":
                return "vec4(0.0, roughnessFactor, metalnessFactor, 1.0)"
            case "emission":
                return "vec4(totalEmissiveRadiance, 1.0)"
            default:
                throw new Error(`Unsupported output map name: ${outputMapName}`)
        }
    }

    // Renders one pass synchronously and returns its RGBA8 pixels as read back by readRenderTargetPixels.
    const renderPixels = (outputMapName: OutputMapType): Uint8Array => {
        const renderTarget = new THREE.WebGLRenderTarget(width, height, {
            generateMipmaps: false,
            minFilter: THREE.NearestFilter,
            magFilter: THREE.NearestFilter,
            format: THREE.RGBAFormat,
            colorSpace: outputMapName === "diffuse" ? THREE.SRGBColorSpace : THREE.NoColorSpace,
        })
        try {
            renderer.setRenderTarget(renderTarget)

            // A distinct cache key per pass keeps three.js from reusing the program of another pass.
            patchedMaterial.customProgramCacheKey = () => material.customProgramCacheKey() + outputMapName
            patchedMaterial.onBeforeCompile = (shader, glRenderer) => {
                const origHook = material.onBeforeCompile
                if (origHook) origHook(shader, glRenderer)

                // Inject the override before the last closing brace of the fragment shader
                // (presumably the end of main(), after all lighting code has run).
                const patchPosition = shader.fragmentShader.lastIndexOf("}")
                if (patchPosition < 0) throw new Error("Failed to find location to patch shader!")

                shader.fragmentShader =
                    shader.fragmentShader.slice(0, patchPosition) +
                    `gl_FragColor = linearToOutputTexel(${outputExpression(outputMapName)});` +
                    shader.fragmentShader.slice(patchPosition)
            }

            // compile() re-resolves the program, so the pass-specific cache key and patch take effect.
            renderer.compile(renderScene, renderCamera)
            renderer.render(renderScene, renderCamera)

            const pixels = new Uint8Array(4 * width * height)
            renderer.readRenderTargetPixels(renderTarget, 0, 0, width, height, pixels)
            return pixels
        } finally {
            renderTarget.dispose()
        }
    }

    // Encodes RGBA8 pixels as an image blob via a 2D canvas.
    const encodePixels = async (pixels: Uint8Array, mimeType: "image/png" | "image/jpeg"): Promise<Blob> => {
        const canvas = document.createElement("canvas")
        canvas.width = width
        canvas.height = height

        const ctx = canvas.getContext("2d")
        if (!ctx) throw new Error("Failed to get 2d context from canvas")

        ctx.putImageData(new ImageData(new Uint8ClampedArray(pixels), width, height), 0, 0)

        return new Promise<Blob>((resolve, reject) =>
            canvas.toBlob((blob) => (blob ? resolve(blob) : reject(new Error("Failed to create blob"))), mimeType),
        )
    }

    const previousRenderTarget = renderer.getRenderTarget()
    const pixels = (() => {
        try {
            // Object literal properties are evaluated in order, so the passes run sequentially.
            return {
                diffuse: renderPixels("diffuse"),
                normal: renderPixels("normal"),
                metallicRoughness: renderPixels("metallicRoughness"),
                emission: hasEmission ? renderPixels("emission") : undefined,
            }
        } finally {
            planeGeometry.dispose()
            patchedMaterial.dispose()
            renderer.setRenderTarget(previousRenderTarget)
        }
    })()

    const [diffuseMap, normalMap, metallicRoughnessMap, emissionMap] = await Promise.all([
        // PNG keeps the alpha channel of transparent materials; JPEG is used everywhere else.
        encodePixels(pixels.diffuse, patchedMaterial.transparent ? "image/png" : "image/jpeg"),
        encodePixels(pixels.normal, "image/jpeg"),
        encodePixels(pixels.metallicRoughness, "image/jpeg"),
        pixels.emission ? encodePixels(pixels.emission, "image/jpeg") : undefined,
    ])

    return {
        diffuseMap,
        normalMap,
        metallicRoughnessMap,
        emissionMap,
    }
}
