import {Inlet, NotReady, Outlet} from "@src/templates/runtime-graph/slots"
import {TypeDescriptors} from "@src/templates/runtime-graph/type-descriptors"
import {NodeClassImpl} from "@src/templates/runtime-graph/types"
import {Matrix4, Vector3} from "@src/math"
import * as THREE from "three"
import {BoundsData} from "@src/geometry-processing/mesh-data"

/** Short alias for the runtime-graph type descriptor factories. */
const TD = TypeDescriptors

/** A control point of the sampled curve, as supplied on the `controlPoints` inlet. */
type ControlPoint = {
    /** Location the curve is built through (used as a Catmull-Rom control point). */
    position: Vector3
    /** Direction at this point; sample normals are interpolated between neighbouring control points. */
    normal: Vector3
    /** When true, the curve is split into separate Catmull-Rom pieces at this point. */
    corner: boolean
}

/**
 * Inlet/outlet layout of the `SampleCurve` node.
 */
const sampleCurveDescriptor = {
    closed: TD.inlet(TD.Primitive<boolean>()),
    allowScaling: TD.inlet(TD.Primitive<boolean>()),
    controlPoints: TD.inlet(
        TD.Array({
            // Compares two control points field by field: position, normal and corner flag.
            deepCompare: (a: ControlPoint, b: ControlPoint) => {
                return a.position.equals(b.position) && a.normal.equals(b.normal) && a.corner === b.corner
            },
        }),
    ),
    // Either a symmetric item length, or item bounds whose z-extent determines the item length.
    segmentLength: TD.inlet(TD.Identity<number | BoundsData>()),
    // Reuses the exported `CurvePoints` shape rather than repeating it inline, so the two cannot drift apart.
    curvePoints: TD.outlet(TD.Nullable(TD.Identity<CurvePoints>())),
}

/** Copies the components of a project `Vector3` into a freshly allocated `THREE.Vector3`. */
function toThreeVector(vector: Vector3): THREE.Vector3 {
    const {x, y, z} = vector
    return new THREE.Vector3(x, y, z)
}

/**
 * A `THREE.CurvePath` whose samples also report which sub-curve produced them (`curveId`, an index
 * into `curves`) and the curve-local parameter `t` at which that sub-curve was evaluated.
 */
class IndexedCurvePath extends THREE.CurvePath<THREE.Vector3> {
    /**
     * Evaluates the path at `t`, interpreted as a fraction of the path's total arc length.
     *
     * @param t - Fraction of `getLength()`; expected to lie in [0, 1].
     * @param type - Whether to evaluate the position or the tangent.
     * @param optionalTarget - Vector to write the result into instead of allocating a new one.
     * @returns The evaluated vector, the index of the sub-curve it lies on, and that sub-curve's local parameter.
     * @throws Error when no sub-curve reaches the requested distance (e.g. `t` > 1 or NaN).
     */
    private getIndexedVector(t: number, type: "position" | "tangent", optionalTarget?: THREE.Vector3) {
        const targetDistance = t * this.getLength()
        const accumulatedLengths = this.getCurveLengths()

        // First sub-curve whose cumulative end distance reaches the target distance.
        const curveId = accumulatedLengths.findIndex((accumulated) => accumulated >= targetDistance)
        if (curveId === -1) throw new Error("Invalid t value")

        const curve = this.curves[curveId]
        const curveLength = curve.getLength()
        // Arc-length fraction within this sub-curve; a zero-length curve collapses to its start.
        const remaining = accumulatedLengths[curveId] - targetDistance
        const u = curveLength === 0 ? 0 : 1 - remaining / curveLength

        // Convert the arc-length fraction into the curve's own parameter. The `distance` argument is left
        // undefined so three.js derives the target arc length from `u`.
        const curveT = curve.getUtoTmapping(u, undefined as unknown as number)
        const vector = type === "position" ? curve.getPoint(curveT, optionalTarget) : curve.getTangent(curveT, optionalTarget)

        return {vector, curveId, t: curveT}
    }

    /**
     * Samples the path at `divisions + 1` evenly spaced arc-length fractions `i / divisions`.
     * When `autoClose` is set and any sample exists, the first sample is repeated at the end.
     */
    getSpacedIndexedVector(divisions: number, type: "position" | "tangent") {
        const samples = Array.from({length: divisions + 1}, (_, i) => this.getIndexedVector(i / divisions, type))
        return this.autoClose && samples.length > 0 ? [...samples, samples[0]] : samples
    }

    /**
     * Samples the path at each arc-length fraction in `tSamples`, preserving their order.
     * When `autoClose` is set and any sample exists, the first sample is repeated at the end.
     */
    sampleIndexedVector(tSamples: number[], type: "position" | "tangent") {
        const samples = tSamples.map((t) => this.getIndexedVector(t, type))
        return this.autoClose && samples.length > 0 ? [...samples, samples[0]] : samples
    }
}

/**
 * Per-sample output of `SampleCurve`. The vector arrays hold flat xyz triples; the scalar
 * arrays hold one value per sample.
 */
export type CurvePoints = {
    /** Sample positions as flat `[x, y, z, ...]` triples. */
    points: Float32Array
    /** Interpolated control-point normals as flat `[x, y, z, ...]` triples. */
    normals: Float32Array
    /** Curve tangents at each sample as flat `[x, y, z, ...]` triples. */
    tangents: Float32Array
    /** Fractional control-point index per sample: the span's starting index plus the position within the span. */
    segments: Float32Array
    /** Per-sample scale factor (1 unless scaling to fit is enabled). */
    scales: Float32Array
}

/**
 * Sampled curve output (or `null` when no curve exists) paired with a `Matrix4`.
 *
 * NOTE(review): not referenced in the code shown here; presumably the matrix is the transform the
 * curve points are expressed in — confirm against consumers.
 */
export type SampleCurveData = {curvePoints: CurvePoints | null; matrix: Matrix4}

/** A control point tagged with its position in the node's `controlPoints` input. */
type IndexedControlPoint = ControlPoint & {index: number}

/**
 * Splits `controlPoints` into centripetal Catmull-Rom pieces at corner points and appends them to `curvePath`.
 *
 * @returns The control points each appended curve was built from, in the same order as `curvePath.curves`,
 *          so samples can later be mapped back to control points. Empty when fewer than two points are given.
 */
function buildCurves(curvePath: IndexedCurvePath, controlPoints: readonly ControlPoint[], closed: boolean): IndexedControlPoint[][] {
    const indexedControlPoints: IndexedControlPoint[] = controlPoints.map((controlPoint, index) => ({...controlPoint, index}))
    const controlPointsPerCurve: IndexedControlPoint[][] = []

    const addCurve = (curveControlPoints: IndexedControlPoint[], closedCurve: boolean) => {
        curvePath.add(
            new THREE.CatmullRomCurve3(
                curveControlPoints.map(({position}) => toThreeVector(position)),
                closedCurve,
                "centripetal",
            ),
        )
        controlPointsPerCurve.push(curveControlPoints)
    }

    // A curve needs at least two control points.
    if (indexedControlPoints.length < 2) return controlPointsPerCurve

    const firstCornerIndex = indexedControlPoints.findIndex((point) => point.corner)

    if (firstCornerIndex === -1) {
        // No corners: one smooth curve. Catmull-Rom is only closed into a loop with more than two points.
        addCurve(indexedControlPoints, closed && indexedControlPoints.length > 2)
        return controlPointsPerCurve
    }

    // For closed paths, rotate the list so it starts at the first corner; the wrap-around then ends on that corner.
    const orderedControlPoints = closed
        ? indexedControlPoints.slice(firstCornerIndex).concat(indexedControlPoints.slice(0, firstCornerIndex))
        : indexedControlPoints
    const lastIndex = orderedControlPoints.length - 1

    let currentCurvePoints: IndexedControlPoint[] = []
    for (let i = 0; i <= lastIndex; i++) {
        const point = orderedControlPoints[i]
        const isLast = i === lastIndex
        currentCurvePoints.push(point)

        if ((point.corner && i !== 0) || isLast) {
            // Closed path ending on a smooth point: continue this piece on to the first (corner) point.
            if (isLast && closed && !point.corner && currentCurvePoints.length >= 2) {
                currentCurvePoints.push(orderedControlPoints[0])
            }
            if (currentCurvePoints.length >= 2) addCurve(currentCurvePoints, false)
            // The break point is shared: it ends this piece and starts the next one.
            currentCurvePoints = [point]
        }
    }

    // Closed path ending on a corner: add the final piece from that corner back to the first point.
    if (closed && currentCurvePoints.length === 1 && currentCurvePoints[0].corner) {
        addCurve([currentCurvePoints[0], orderedControlPoints[0]], false)
    }

    return controlPointsPerCurve
}

/**
 * Chooses where items described by `segmentLength` are placed along `curvePath`.
 *
 * A number is a symmetric item length. For bounds, the item extends from its origin back to the
 * AABB's minimum z and forward to its maximum z (assumes `aabb` is `[min, max]` — confirm).
 *
 * @returns Parallel arrays: arc-length fractions of the whole path (`tSamples`) and per-item scale factors (`scales`).
 *          Both are empty when the item length is not positive and finite, or the path has zero length.
 */
function layoutSamples(
    curvePath: IndexedCurvePath,
    segmentLength: number | BoundsData,
    allowScaling: boolean,
): {tSamples: number[]; scales: number[]} {
    const startOffset = typeof segmentLength === "number" ? segmentLength / 2 : -segmentLength.aabb[0][2]
    const endOffset = typeof segmentLength === "number" ? segmentLength / 2 : segmentLength.aabb[1][2]
    const itemLength = startOffset + endOffset

    const tSamples: number[] = []
    const scales: number[] = []

    const totalCurveLength = curvePath.getLength()

    // A zero item length (e.g. flat bounds) would make `currentCurveLength / itemLength` infinite and the sampling
    // loop below unbounded; a negative or non-finite one (e.g. empty bounds) or a zero-length path yields NaN t values
    // that `sampleIndexedVector` rejects. Place nothing in those cases.
    if (!(itemLength > 0 && Number.isFinite(itemLength) && totalCurveLength > 0)) return {tSamples, scales}

    const accumulatedCurveLengths = curvePath.getCurveLengths()

    curvePath.curves.forEach((curve, i) => {
        const currentCurveLength = curve.getLength()
        // Curves too short for one whole item are skipped unless items may be scaled to fit.
        if (currentCurveLength > itemLength || allowScaling) {
            const curveLengthOffset = i > 0 ? accumulatedCurveLengths[i - 1] : 0
            const numSamples = allowScaling
                ? Math.max(Math.floor(currentCurveLength / itemLength), 1.0)
                : Math.floor(currentCurveLength / itemLength)
            // Leftover length is distributed as equal gaps before, between and after the items.
            const unusedSpace = currentCurveLength - numSamples * itemLength
            const addedStepSize = unusedSpace / (numSamples + 1)
            const stepSize = itemLength + addedStepSize
            // NOTE(review): with allowScaling the gaps above are still derived from the unscaled itemLength, so when
            // more than one item fits, neighbouring scaled items overlap slightly. Confirm whether scaled items should
            // tile the curve without gaps instead.
            const currentScale = allowScaling ? currentCurveLength / (numSamples * itemLength) : 1.0

            for (let j = 0; j < numSamples; j++) {
                tSamples.push((curveLengthOffset + startOffset + addedStepSize + j * stepSize) / totalCurveLength)
                scales.push(currentScale)
            }
        }
    })

    return {tSamples, scales}
}

/**
 * Maps a curve-local parameter `t` back to the control points the curve was built from.
 *
 * @returns The control points bracketing `t` and the fractional position `localT` between them.
 */
function locateControlSpan(controlPoints: IndexedControlPoint[], closed: boolean, t: number) {
    // Same span count as CatmullRomCurve3.getPoint: a closed curve has one extra span back to the first point.
    const index = t * (controlPoints.length - (closed ? 0 : 1))
    const baseIndex = Math.floor(index)
    return {
        // Wrapped so that t === 1 on a closed curve maps to the first control point instead of past the end.
        start: controlPoints[baseIndex % controlPoints.length],
        end: controlPoints[Math.ceil(index) % controlPoints.length],
        localT: index - baseIndex,
    }
}

export class SampleCurve implements NodeClassImpl<typeof sampleCurveDescriptor, typeof SampleCurve> {
    static descriptor = sampleCurveDescriptor
    static uniqueName = "SampleCurve"
    closed!: Inlet<boolean>
    allowScaling!: Inlet<boolean>
    controlPoints!: Inlet<ControlPoint[]>
    segmentLength!: Inlet<number | BoundsData>
    curvePoints!: Outlet<CurvePoints | null>

    /**
     * Rebuilds the curve from the current inputs and emits the sampled item frames.
     * Emits `NotReady` while any input is pending, and `null` when no curve can be built
     * (fewer than two control points).
     */
    run() {
        if (this.closed === NotReady || this.allowScaling === NotReady || this.controlPoints === NotReady || this.segmentLength === NotReady) {
            this.curvePoints.emitIfChanged(NotReady)
            return
        }

        const closed = this.closed
        const allowScaling = this.allowScaling

        // `autoClose` is deliberately left at its default (false). Closed paths already contain the wrap-around
        // segment in their curves; enabling it would make `sampleIndexedVector` append a duplicate of the first
        // sample, leaving points/normals/tangents/segments one entry longer than `scales`.
        const curvePath = new IndexedCurvePath()
        const controlPointsPerCurve = buildCurves(curvePath, this.controlPoints, closed)

        if (curvePath.curves.length === 0) {
            this.curvePoints.emitIfChanged(null)
            return
        }

        const {tSamples, scales} = layoutSamples(curvePath, this.segmentLength, allowScaling)
        const points = curvePath.sampleIndexedVector(tSamples, "position")
        const tangents = curvePath.sampleIndexedVector(tSamples, "tangent")

        // Normals are stored as rotations of +X so they can be slerped between neighbouring control points.
        const xAxis = new THREE.Vector3(1, 0, 0)
        const normals = points.map(({curveId, t}) => {
            const {closed: curveClosed} = curvePath.curves[curveId] as THREE.CatmullRomCurve3
            const {start, end, localT} = locateControlSpan(controlPointsPerCurve[curveId], curveClosed, t)

            const quatStart = new THREE.Quaternion().setFromUnitVectors(xAxis, start.normal)
            const quatEnd = new THREE.Quaternion().setFromUnitVectors(xAxis, end.normal)

            return xAxis.clone().applyQuaternion(quatStart.slerp(quatEnd, localT)).normalize()
        })

        const segments = points.map(({curveId, t}) => {
            const {closed: curveClosed} = curvePath.curves[curveId] as THREE.CatmullRomCurve3
            const {start, localT} = locateControlSpan(controlPointsPerCurve[curveId], curveClosed, t)

            // Fractional index: the span's starting control-point index plus the position within the span.
            return start.index * (1 - localT) + (start.index + 1) * localT
        })

        this.curvePoints.emitIfChanged({
            points: new Float32Array(points.flatMap(({vector}) => [vector.x, vector.y, vector.z])),
            normals: new Float32Array(normals.flatMap((normal) => [normal.x, normal.y, normal.z])),
            tangents: new Float32Array(tangents.flatMap(({vector}) => [vector.x, vector.y, vector.z])),
            segments: new Float32Array(segments),
            scales: new Float32Array(scales),
        })
    }
}
