import {ObjectId, SceneNodes} from "@cm/template-nodes/interfaces/scene-object"
import {anyDifference, objectFieldsDifferent} from "@template-editor/helpers/change-detection"
import {getThreeObjectPart, mathIsEqual, setThreeObjectPart, ThreeObject, updateTransform} from "@template-editor/helpers/three-object"
import {ThreeSceneManagerService} from "@template-editor/services/three-scene-manager.service"
import {Three as THREE} from "@cm/material-nodes/three"
import {ThreeMesh} from "./three-mesh"
import {ThreeMeshCurveControl} from "./three-mesh-curve-control"
import {projectCurvePointsToObject} from "./three-utils"

/** One (item mesh, material index) pair that ThreeSeam renders as a `THREE.InstancedMesh`. */
interface InstancedMeshNode {
    /** Key into ThreeSeam's instanced-mesh cache, as produced by `getInstancedMeshId`. */
    id: number
    /** The ThreeMesh wrapper whose render-object matrix is combined with every instance transform. */
    threeMesh: ThreeMesh
    /** Sub-mesh whose geometry, material and render flags are copied onto the instanced mesh. */
    subMesh: THREE.Mesh<THREE.BufferGeometry, THREE.Material>
    /** When true, a cached instanced mesh for `id` is discarded and recreated. */
    meshDirty: boolean
    /** When true, instance matrices are rewritten even if the instance count did not change. */
    instancesDirty: boolean
}

/**
 * Renders a seam: the meshes listed in `sceneNode.item` are drawn as instances placed along the seam's curve
 * points, after those points have been projected onto the mesh referenced by the seam's mesh curve control.
 *
 * One `THREE.InstancedMesh` is kept per (item mesh, material index) pair. They are children of `renderObject`
 * and are recreated only when their source mesh is dirty or their instance capacity is too small.
 */
export class ThreeSeam extends ThreeObject<SceneNodes.Seam> {
    protected override renderObject = new THREE.Group()

    /** ThreeMesh wrappers for the seam's item meshes, keyed by scene-node id. */
    private objectCache = new Map<ObjectId, ThreeMesh>()
    /** Numeric ids for `${meshId}@${materialIndex}` keys; entries are never removed, so ids are never reused. */
    private instancedMeshIds = new Map<string, number>()
    /** Instanced meshes by instanced-mesh id, with the instance capacity each was allocated with. */
    private instancedMeshCache = new Map<number, {instancedMesh: THREE.InstancedMesh; maxCount: number}>()
    /** Per-instance transforms computed from the projected curve points (presumably seam-local; confirm). */
    private instances: THREE.Matrix4[] = []

    constructor(threeSceneManagerService: ThreeSceneManagerService, onAsyncUpdate: () => void) {
        super(threeSceneManagerService, onAsyncUpdate)

        setThreeObjectPart(this.getRenderObject(), this)
    }

    /**
     * Returns the id for the (mesh, material index) pair, allocating the next sequential id on first use.
     */
    protected getInstancedMeshId(mesh: SceneNodes.Mesh, materialIndex: number) {
        const key = `${mesh.id}@${materialIndex}`
        let id = this.instancedMeshIds.get(key)
        if (id === undefined) {
            id = this.instancedMeshIds.size
            this.instancedMeshIds.set(key, id)
        }
        return id
    }

    /**
     * Synchronizes the instanced meshes with `meshNodes`.
     * - Creates or recreates instanced meshes as needed.
     * - Rewrites instance matrices when they are dirty or the instance count changed.
     * - Removes and disposes instanced meshes whose id is not in `meshNodes`.
     *
     * @returns true if anything visible changed and a re-render is needed.
     */
    protected updateInstancedMeshes(meshNodes: InstancedMeshNode[]) {
        // Every cached id starts out as a deletion candidate; ids still referenced by meshNodes are kept.
        const toDelete = new Set<number>(this.instancedMeshCache.keys())

        let needsUpdate = false

        const getInstancedMesh = (meshNode: InstancedMeshNode) => {
            const {id, subMesh, meshDirty} = meshNode
            const {geometry, material, visible, receiveShadow, castShadow, renderOrder} = subMesh

            const cachedInstancedMesh = this.instancedMeshCache.get(id)
            if (cachedInstancedMesh) {
                const {instancedMesh, maxCount} = cachedInstancedMesh
                // Reuse unless the source mesh changed or the instance buffer cannot hold all instances.
                if (!meshDirty && maxCount >= this.instances.length) return instancedMesh

                this.renderObject.remove(instancedMesh)
                this.instancedMeshCache.delete(id)
                instancedMesh.dispose()
            }

            // Over-allocate so that moderate growth of the instance count does not force a rebuild.
            const maxCount = Math.max(this.instances.length * 2, 10)
            const instancedMesh = new THREE.InstancedMesh(geometry, material, maxCount)
            setThreeObjectPart(instancedMesh, this, `group${id}`)
            instancedMesh.visible = visible
            instancedMesh.receiveShadow = receiveShadow
            instancedMesh.castShadow = castShadow
            instancedMesh.renderOrder = renderOrder

            this.renderObject.add(instancedMesh)
            this.instancedMeshCache.set(id, {instancedMesh, maxCount})
            return instancedMesh
        }

        // Reused for every instance: setMatrixAt copies the matrix into the instance buffer.
        const scratchMatrix = new THREE.Matrix4()

        for (const meshNode of meshNodes) {
            const instancedMesh = getInstancedMesh(meshNode)

            // A freshly created instanced mesh reports count === maxCount, so it always takes this branch.
            if (meshNode.instancesDirty || instancedMesh.count !== this.instances.length) {
                const {matrix} = meshNode.threeMesh.getRenderObject()

                // instance transform * mesh transform (same as matrix.clone().premultiply(instance)).
                this.instances.forEach((instance, index) => {
                    instancedMesh.setMatrixAt(index, scratchMatrix.multiplyMatrices(instance, matrix))
                })

                instancedMesh.instanceMatrix.needsUpdate = true
                instancedMesh.count = this.instances.length

                needsUpdate = true
            }

            toDelete.delete(meshNode.id)
        }

        for (const id of toDelete) {
            const cachedInstancedMesh = this.instancedMeshCache.get(id)
            if (cachedInstancedMesh) {
                const {instancedMesh} = cachedInstancedMesh
                this.renderObject.remove(instancedMesh)
                this.instancedMeshCache.delete(id)
                instancedMesh.dispose()
                needsUpdate = true
            }
        }

        return needsUpdate
    }

    /**
     * Applies `sceneNode` to the render objects, in three steps:
     * 1. Applies the transform.
     * 2. Recomputes the instance transforms when the curve points or the mesh curve control change.
     * 3. Syncs the item meshes and their instanced meshes.
     *
     * @returns true if anything changed.
     * @throws Error if the referenced mesh curve control or its mesh cannot be found.
     */
    override setup(sceneNode: SceneNodes.Seam) {
        // Set by the curve-points handler below; read when the instanced meshes are rebuilt in the same call.
        let instancesDirty = false

        return anyDifference([
            objectFieldsDifferent(
                sceneNode,
                this.parameters,
                ["transform"],
                (valueA, valueB) => mathIsEqual(valueA, valueB),
                ({transform}) => {
                    updateTransform(transform, this.renderObject)
                },
            ),
            // NOTE(review): the projection below reads this.renderObject.matrix, but a transform-only change does not
            // re-run it. Confirm whether the instances should be recomputed when the transform changes.
            objectFieldsDifferent(
                sceneNode,
                this.parameters,
                ["curvePoints", "meshCurveControlId"],
                (a, b) => {
                    if (a === null || typeof a === "string" || b === null || typeof b === "string") return a === b
                    // Compare (by reference) every array the handler below reads, including scales.
                    return a.points === b.points && a.normals === b.normals && a.tangents === b.tangents && a.scales === b.scales
                },
                ({curvePoints, meshCurveControlId}) => {
                    //We need to update the world matrix of the scene to get the correct hit points even before the first render
                    this.threeSceneManagerService.getSceneReference().updateMatrixWorld()

                    const meshCurveControl = this.threeSceneManagerService.getThreeObject(meshCurveControlId)
                    if (!meshCurveControl || !(meshCurveControl instanceof ThreeMeshCurveControl))
                        throw new Error(`Mesh curve control with id ${meshCurveControlId} not found`)

                    const {meshId} = meshCurveControl.getSceneNode()
                    const meshObject = this.threeSceneManagerService.getThreeObject(meshId)
                    if (!meshObject || !(meshObject instanceof ThreeMesh)) throw new Error(`Mesh object with id ${meshId} not found`)
                    const meshScene = meshObject.getRenderObject()

                    const segmentLength = 0.1
                    const queryRange = 2 * segmentLength

                    this.instances.length = 0
                    if (curvePoints) {
                        const {hitPoints, hitNormals, valid} = projectCurvePointsToObject(curvePoints, this.renderObject.matrix, queryRange, meshScene)
                        const {scales, tangents} = curvePoints

                        for (let i = 0; i < valid.length; i++) {
                            // Skip curve points that did not hit the target mesh.
                            if (!valid[i]) continue

                            const position = hitPoints[i]
                            const normal = hitNormals[i]
                            // Make the tangent orthogonal to the surface normal, then complete an orthonormal frame.
                            const tangent = new THREE.Vector3(tangents[i * 3], tangents[i * 3 + 1], tangents[i * 3 + 2]).projectOnPlane(normal).normalize()
                            const bitangent = new THREE.Vector3().crossVectors(normal, tangent).normalize()

                            // Basis columns: x = bitangent, y = normal, z = tangent (scaled by this point's scale).
                            const matrix = new THREE.Matrix4()
                                .makeBasis(bitangent, normal, tangent)
                                .setPosition(position)
                                .scale(new THREE.Vector3(1, 1, scales[i]))

                            this.instances.push(matrix)
                        }
                    }

                    instancesDirty = true
                },
            ),
            (() => {
                // ThreeMeshes not referenced by the current scene node are disposed at the end.
                const toDelete = new Set<ObjectId>(this.objectCache.keys())

                // Item meshes are always rendered with receiveRealtimeShadows forced off.
                const noReceiveShadows = (meshes: SceneNodes.Mesh[]): SceneNodes.Mesh[] => meshes.map((x) => ({...x, receiveRealtimeShadows: false}))

                const getThreeMesh = (mesh: SceneNodes.Mesh) => {
                    const cachedThreeMesh = this.objectCache.get(mesh.id)
                    if (cachedThreeMesh) return cachedThreeMesh

                    const threeMesh = new ThreeMesh(this.threeSceneManagerService, () => {
                        // Async update of this ThreeMesh: rebuild from the *current* scene node. Only the instanced
                        // meshes belonging to the mesh that finished updating are marked dirty.
                        // NOTE(review): `instancesDirty` is captured from the setup() call that created this ThreeMesh,
                        // so later async updates see that call's final value, not the latest one.
                        const instancedMeshData: InstancedMeshNode[] = []
                        for (const itemMesh of noReceiveShadows(this.getSceneNode().item)) {
                            const currentMesh = getThreeMesh(itemMesh)

                            const meshDirty = currentMesh === threeMesh
                            for (const {subMesh, materialIndex} of currentMesh.getSubMeshes()) {
                                instancedMeshData.push({
                                    id: this.getInstancedMeshId(itemMesh, materialIndex),
                                    meshDirty,
                                    instancesDirty,
                                    threeMesh: currentMesh,
                                    subMesh,
                                })
                            }
                        }

                        const needsUpdate = this.updateInstancedMeshes(instancedMeshData)
                        if (needsUpdate) this.onAsyncUpdate()
                    })
                    this.objectCache.set(mesh.id, threeMesh)

                    return threeMesh
                }

                const instancedMeshData: InstancedMeshNode[] = []
                for (const mesh of noReceiveShadows(sceneNode.item)) {
                    const threeMesh = getThreeMesh(mesh)

                    const meshDirty = threeMesh.update(mesh)
                    for (const {subMesh, materialIndex} of threeMesh.getSubMeshes()) {
                        instancedMeshData.push({
                            id: this.getInstancedMeshId(mesh, materialIndex),
                            meshDirty,
                            instancesDirty,
                            threeMesh,
                            subMesh,
                        })
                    }

                    toDelete.delete(mesh.id)
                }

                const needsUpdate = this.updateInstancedMeshes(instancedMeshData)

                for (const id of toDelete) {
                    const threeMesh = this.objectCache.get(id)
                    if (threeMesh) {
                        this.objectCache.delete(id)
                        threeMesh.dispose(true)
                    }
                }

                return needsUpdate
            })(),
        ])
    }

    /**
     * Returns the instanced mesh tagged with part `group${index}`, where `index` is an instanced-mesh id
     * as produced by `getInstancedMeshId`.
     *
     * @throws Error if no such instanced mesh is currently a child of the render object.
     */
    getSubMesh(index: number) {
        for (const child of this.renderObject.children) {
            if (child instanceof THREE.InstancedMesh) {
                const threeObjectPart = getThreeObjectPart(child)

                if (threeObjectPart && threeObjectPart.part === `group${index}`) return child as THREE.InstancedMesh<THREE.BufferGeometry, THREE.Material>
            }
        }

        throw new Error(`Sub-mesh with index ${index} not found`)
    }

    /**
     * Releases the instanced meshes (removing them from the render group) and the ThreeMesh wrappers of the
     * item meshes. `final` is not used here.
     */
    override dispose(final: boolean) {
        // Instanced meshes own their instance-matrix buffers; geometry and material belong to the ThreeMeshes below.
        for (const {instancedMesh} of this.instancedMeshCache.values()) {
            this.renderObject.remove(instancedMesh)
            instancedMesh.dispose()
        }
        this.instancedMeshCache.clear()

        for (const threeMesh of this.objectCache.values()) threeMesh.dispose(true)
        this.objectCache.clear()
    }

    /** Selection has no visual effect on a seam. */
    override onSelectionChange(selected: boolean) {}
}
