import {MeshBuffers, SceneNodes} from "@cm/template-nodes"
import {Three as THREE} from "@cm/material-nodes/three"

/**
 * Line-segment geometry that outlines the "seams" of a {@link MeshBuffers} mesh.
 *
 * Each triangle edge is matched with an edge of another triangle that runs between
 * the same two (rounded) positions in the opposite winding. A matched edge is emitted
 * only when the two triangles disagree on the property selected by `channel`. Every
 * edge that never finds such a partner (e.g. an open boundary edge) is emitted
 * unconditionally.
 *
 * - `channel === "faces"`: the property is the face ID, so lines separate triangles
 *   belonging to different faces.
 * - any other channel (e.g. `"uv1"`): the property is the pair of UV coordinates at the
 *   edge's two endpoints. They are read from the UV set whose index follows the
 *   two-character channel prefix, so lines follow UV seams.
 *
 * Positions and UVs are compared after rounding to 4 decimal places. Corners that
 * coincide spatially are therefore treated as shared even when they use different
 * vertex indices. The result is a non-indexed `position` attribute holding two
 * vertices per segment.
 */
export class MeshBuffersWireframGeometry extends THREE.BufferGeometry {
    // NOTE(review): this type string does not match the class name. It is kept as-is
    // because code elsewhere may key off it — confirm before renaming.
    override type = "SegmentEdgesGeometry"

    private parameters: {meshBuffers: MeshBuffers}

    /**
     * @param meshBuffers Source mesh. Positions are read as packed xyz triples and UVs as
     *   packed uv pairs, both addressed by the vertex indices stored in `indices`.
     * @param channel `"faces"` to outline face boundaries, or a UV channel name to outline
     *   the seams of that UV set.
     * @throws Error if `faceIDs` (for `"faces"`) or the selected UV set (for UV channels)
     *   does not have the expected length.
     *
     * No `position` attribute is set when `meshBuffers.indices` is empty or the requested
     * UV set does not exist.
     */
    constructor(meshBuffers: MeshBuffers, channel: SceneNodes.WireframeMesh["channel"]) {
        super()

        // Stored before any early return so copy() always has the source buffers.
        this.parameters = {
            meshBuffers,
        }

        if (meshBuffers.indices.length === 0) return

        const {vertices: vertexData, faceIDs, indices, uvs} = meshBuffers
        const isFaceChannel = channel === "faces"

        // UV channels carry the set index after a two-character prefix (e.g. "uv1" -> 1).
        // The face channel never reads UVs, so no UV set is looked up for it.
        const uvData = channel === "faces" ? undefined : uvs.at(parseInt(channel.slice(2), 10))

        if (isFaceChannel) {
            if (indices.length !== faceIDs.length) {
                throw new Error("Indices and faceIDs length mismatch")
            }
        } else {
            // A missing UV set yields an empty geometry rather than an error.
            if (!uvData) return

            // NOTE(review): UVs are read below at vertex indices (values taken from `indices`),
            // but their length is validated against faceIDs.length * 2. This presumably assumes
            // faceIDs holds one entry per vertex — confirm against the MeshBuffers layout.
            if (uvData.length !== faceIDs.length * 2) {
                throw new Error(`Indices and ${channel} length mismatch`)
            }
        }

        // Positions and UVs are compared on a 10^-4 grid, so nearly coincident values hash equally.
        const precisionPoints = 4
        const precision = 10 ** precisionPoints
        const quantize = (value: number) => Math.round(value * precision)

        const indexArr = [0, 0, 0]
        const hashes = ["", "", ""]
        const uvHashes = ["", "", ""]

        // Edges seen once and not yet matched. Each key is "<start hash>_<end hash>" in the
        // winding of the triangle that registered the edge.
        const edgeData = new Map<string, {index0: number; index1: number; propertyHash: string}>()

        const triVertices = [new THREE.Vector3(), new THREE.Vector3(), new THREE.Vector3()]
        const triUvs = [new THREE.Vector2(), new THREE.Vector2(), new THREE.Vector2()]

        const setPosition = (v: THREE.Vector3, index: number) => {
            const offset = index * 3
            v.x = vertexData[offset]
            v.y = vertexData[offset + 1]
            v.z = vertexData[offset + 2]
        }

        const setUvs = (v: THREE.Vector2, index: number) => {
            if (uvData) {
                const offset = index * 2
                v.x = uvData[offset]
                v.y = uvData[offset + 1]
            } else {
                v.x = 0
                v.y = 0
            }
        }

        const positionHash = (v: THREE.Vector3) => `${quantize(v.x)},${quantize(v.y)},${quantize(v.z)}`
        const uvHash = (v: THREE.Vector2) => `${quantize(v.x)},${quantize(v.y)}`

        const vertices: number[] = []
        for (let i = 0; i < indices.length; i += 3) {
            for (let j = 0; j < 3; j++) {
                indexArr[j] = indices[i + j]
                setPosition(triVertices[j], indexArr[j])
                hashes[j] = positionHash(triVertices[j])
            }

            // Skip degenerate triangles whose corners collapse onto the same rounded position.
            if (hashes[0] === hashes[1] || hashes[1] === hashes[2] || hashes[2] === hashes[0]) {
                continue
            }

            // UV hashes only matter for UV channels, and only for triangles that pass the check above.
            if (!isFaceChannel) {
                for (let j = 0; j < 3; j++) {
                    setUvs(triUvs[j], indexArr[j])
                    uvHashes[j] = uvHash(triUvs[j])
                }
            }

            // The face ID is taken at this triangle's first corner position in the index buffer.
            const faceHash = isFaceChannel ? `${faceIDs[i]}` : ""

            for (let j = 0; j < 3; j++) {
                const jNext = (j + 1) % 3
                const edgeHash = `${hashes[j]}_${hashes[jNext]}`
                const reverseEdgeHash = `${hashes[jNext]}_${hashes[j]}`

                const existing = edgeData.get(reverseEdgeHash)
                if (existing) {
                    // Another triangle registered this edge in the opposite winding. Emit it only
                    // when the property differs across the edge (a seam), then retire the entry.
                    const reversePropertyHash = isFaceChannel ? faceHash : `${uvHashes[jNext]}_${uvHashes[j]}`
                    if (existing.propertyHash !== reversePropertyHash) {
                        const v0 = triVertices[j]
                        const v1 = triVertices[jNext]
                        vertices.push(v0.x, v0.y, v0.z, v1.x, v1.y, v1.z)
                    }
                    edgeData.delete(reverseEdgeHash)
                } else if (!edgeData.has(edgeHash)) {
                    // First sighting. A repeat in the same winding before it is matched is ignored.
                    edgeData.set(edgeHash, {
                        index0: indexArr[j],
                        index1: indexArr[jNext],
                        propertyHash: isFaceChannel ? faceHash : `${uvHashes[j]}_${uvHashes[jNext]}`,
                    })
                }
            }
        }

        // Edges that never met a partner (e.g. open boundaries) are always drawn.
        const [a, b] = triVertices
        for (const {index0, index1} of edgeData.values()) {
            setPosition(a, index0)
            setPosition(b, index1)
            vertices.push(a.x, a.y, a.z, b.x, b.y, b.z)
        }

        this.setAttribute("position", new THREE.Float32BufferAttribute(vertices, 3))
    }

    /** Copies the base geometry data plus a shallow copy of the constructor parameters. */
    override copy(source: MeshBuffersWireframGeometry): this {
        super.copy(source)

        this.parameters = {...source.parameters}

        return this
    }
}
