import {Vector2, Vector2Like} from "#math/vector2"

/** Determines how a knot's handle position is obtained when a spline (re)computes its handles. */
export enum HandleType {
    /**
     * manually set handle position; the automatic handle computation generally leaves it in place,
     * although it may still be shortened so that the two handles of a segment don't cross each other
     */
    Free = "free",
    /** automatically set handle position to yield a smooth curve (derived from the directions to the neighboring knots) */
    Smooth = "smooth",
    /**
     * automatically set handle position to yield a smooth curve that doesn't overshoot (useful for animation curves):
     * the handle is kept horizontal when the knot is a local extremum and is not moved beyond the neighboring knot's y coordinate
     */
    SmoothNoOvershoot = "smooth-no-overshoot",
    /** automatically set handle position to yield a corner (the handle points a third of the way towards the neighboring knot) */
    Corner = "corner",
}

/**
 * A control point of a cubic Bézier spline together with the two handles (control points)
 * that shape the segments before and after it.
 */
export class Knot {
    /** Anchor point the spline passes through. */
    position: Vector2 = new Vector2(0, 0)
    /** How `handleToPrevPosition` is determined. */
    handleToPrevType: HandleType
    /** Handle shaping the segment between the previous knot and this one. */
    handleToPrevPosition: Vector2 = new Vector2(0, 0)
    /** How `handleToNextPosition` is determined. */
    handleToNextType: HandleType
    /** Handle shaping the segment between this knot and the next one. */
    handleToNextPosition: Vector2 = new Vector2(0, 0)

    /**
     * @param position anchor point; copied into a new Vector2
     * @param handleType handle type used for both handles (defaults to Smooth)
     */
    constructor(position: Vector2Like, handleType: HandleType = HandleType.Smooth) {
        const anchor = Vector2.fromVector2Like(position)
        this.position = anchor
        this.handleToPrevType = handleType
        this.handleToNextType = handleType
    }
}

/**
 * A 2D cubic Bézier spline resembling the curves in Blender's RGB Curves node.
 *
 * Handles are derived from the knots' handle types using the same algorithm as Blender's
 * curve maps. They are then corrected using the x span between consecutive knots, so the
 * knots are assumed to be ordered by ascending x.
 *
 * Possible extensions:
 *   - a hint whether the spline is meant as a function approximation (e.g. for animation curves)
 *     or as a general purpose spline; currently it is always the former, which needs some
 *     additional fixing of knot handles to make it a valid function
 *   - 3D support
 *   - additional handle types
 *   - other interpolation types (e.g. quadratic)
 */
export class CubicBezierSpline {
    /**
     * @param knots control points of the spline, used directly (not copied); their handle
     *   positions are rewritten whenever the spline recomputes its handles
     * @throws Error if fewer than 2 knots are given
     */
    constructor(public knots: Knot[]) {
        if (knots.length < 2) {
            throw Error("BezierSpline must have at least 2 knots")
        }
    }

    /**
     * Marks the handles as stale. Call this after changing knot positions or handle types;
     * the handles are recomputed lazily on the next evaluation.
     */
    signalKnotsChanged() {
        this.knotsNeedUpdate = true
    }

    /**
     * Evaluates a single point on the spline.
     *
     * @param t position along the spline; 0..1 spans all `knots.length - 1` segments, each segment
     *   occupying an equal share of the range. Values below 0 yield the first knot position, values
     *   at or above 1 the last one.
     * @returns a new vector; mutating it does not affect the spline
     */
    evaluatePoint(t: number): Vector2 {
        // handles may not have been computed yet, or may be stale after signalKnotsChanged()
        this.updateKnots()
        const numSegments = this.knots.length - 1
        const scaledT = t * numSegments
        const segmentIndex = Math.floor(scaledT)
        if (segmentIndex < 0) {
            return this.knots[0].position.clone()
        }
        if (segmentIndex >= numSegments) {
            return this.knots[numSegments].position.clone()
        }
        // local parameter within the segment, in [0, 1)
        const segmentT = scaledT - segmentIndex
        const k0 = this.knots[segmentIndex]
        const k1 = this.knots[segmentIndex + 1]
        const p0 = k0.position
        const p1 = k0.handleToNextPosition
        const p2 = k1.handleToPrevPosition
        const p3 = k1.position
        const x = this.evaluateCubicBezier(p0.x, p1.x, p2.x, p3.x, segmentT)
        const y = this.evaluateCubicBezier(p0.y, p1.y, p2.y, p3.y, segmentT)
        return new Vector2(x, y)
    }

    /**
     * Samples the whole spline with `numPoints` points in total, distributed as evenly as possible
     * over the segments. Each segment is sampled including both of its end knots, so an interior
     * knot appears twice when both adjacent segments receive at least two samples.
     *
     * @param numPoints total number of samples; non-positive (or NaN) counts yield an empty array
     */
    evaluatePoints(numPoints: number): Vector2[] {
        this.updateKnots()
        const points: Vector2[] = []
        // also catches NaN; negative counts would otherwise end in `new Array(negative)`
        if (!(numPoints > 0)) {
            return points
        }
        const numSegments = this.knots.length - 1
        for (let i = 0; i < numSegments; i++) {
            const k0 = this.knots[i]
            const k1 = this.knots[i + 1]
            // difference of cumulative floors: the per-segment counts sum up to floor(numPoints)
            const numSegmentPoints = Math.floor(((i + 1) * numPoints) / numSegments) - Math.floor((i * numPoints) / numSegments)
            const xs = this.computeCurve(k0.position.x, k0.handleToNextPosition.x, k1.handleToPrevPosition.x, k1.position.x, numSegmentPoints)
            const ys = this.computeCurve(k0.position.y, k0.handleToNextPosition.y, k1.handleToPrevPosition.y, k1.position.y, numSegmentPoints)
            xs.forEach((x, j) => points.push(new Vector2(x, ys[j])))
        }
        return points
    }

    // recomputes all handles if the knots were flagged as changed
    private updateKnots() {
        if (this.knotsNeedUpdate) {
            this.fixKnots(this.knots)
            this.knotsNeedUpdate = false
        }
    }

    // Bernstein form of a 1D cubic Bézier with control values p0..p3 at parameter t
    private evaluateCubicBezier(p0: number, p1: number, p2: number, p3: number, t: number): number {
        const t2 = t * t
        const t3 = t2 * t
        const mt = 1 - t
        const mt2 = mt * mt
        const mt3 = mt2 * mt
        return mt3 * p0 + 3 * mt2 * t * p1 + 3 * mt * t2 * p2 + t3 * p3
    }

    // whether a handle of this type is placed by the smooth ("auto") algorithm
    private static isSmoothType(type: HandleType): boolean {
        return type === HandleType.Smooth || type === HandleType.SmoothNoOvershoot
    }

    // computes all handle positions, then corrects the end handles and any overlapping segment handles
    private fixKnots(knots: Knot[]) {
        knots.forEach((k, i) => this.computeHandlePosition(k, i > 0 ? knots[i - 1] : undefined, i < knots.length - 1 ? knots[i + 1] : undefined))
        this.correctFirstAndLastHandle(knots)
        for (let i = 0; i < knots.length - 1; i++) {
            const k0 = knots[i]
            const k1 = knots[i + 1]
            const [ch0, ch1] = this.correctBezierSegmentHandles(k0.position, k0.handleToNextPosition, k1.handleToPrevPosition, k1.position)
            k0.handleToNextPosition = ch0
            k1.handleToPrevPosition = ch1
        }
    }

    // analog to blender: calchandle_curvemap in colortools.c
    // Places the knot's non-free handles based on the positions of its neighbors. A missing
    // neighbor (first/last knot) is replaced by mirroring the existing neighbor through the knot.
    private computeHandlePosition(knot: Knot, prev: Knot | undefined, next: Knot | undefined) {
        if (knot.handleToPrevType === HandleType.Free && knot.handleToNextType === HandleType.Free) {
            return
        }
        const p2 = knot.position
        let p1: Vector2
        let p3: Vector2
        if (prev && next) {
            p1 = prev.position
            p3 = next.position
        } else if (prev) {
            p1 = prev.position
            p3 = p2.mul(2).sub(p1)
        } else if (next) {
            p3 = next.position
            p1 = p2.mul(2).sub(p3)
        } else {
            throw Error("knot must have at least one neighbor")
        }

        const dvec_a = p2.sub(p1)
        const dvec_b = p3.sub(p2)
        let len_a = dvec_a.norm()
        let len_b = dvec_b.norm()
        // coincident knots: avoid dividing by zero below
        if (len_a === 0) {
            len_a = 1
        }
        if (len_b === 0) {
            len_b = 1
        }

        const prevSmooth = CubicBezierSpline.isSmoothType(knot.handleToPrevType)
        const nextSmooth = CubicBezierSpline.isSmoothType(knot.handleToNextType)
        if (prevSmooth || nextSmooth) {
            // tangent direction: sum of the unit vectors towards/away from the neighbors
            const tvec = dvec_b.div(len_b).add(dvec_a.div(len_a))
            const len = tvec.norm() * 2.5614 // TODO I have no idea what this number is; taken from blender
            if (len !== 0) {
                if (prevSmooth) {
                    len_a /= len
                    knot.handleToPrevPosition = p2.add(tvec.mul(-len_a))
                    if (knot.handleToPrevType === HandleType.SmoothNoOvershoot && next && prev) {
                        // keep horizontal if extrema
                        const ydiff1 = prev.position.y - knot.position.y
                        const ydiff2 = next.position.y - knot.position.y
                        if ((ydiff1 <= 0 && ydiff2 <= 0) || (ydiff1 >= 0 && ydiff2 >= 0)) {
                            knot.handleToPrevPosition.y = knot.position.y
                        } else {
                            // handles should not be beyond y coord of two others
                            if (ydiff1 <= 0) {
                                if (prev.position.y > knot.handleToPrevPosition.y) {
                                    knot.handleToPrevPosition.y = prev.position.y
                                }
                            } else {
                                if (prev.position.y < knot.handleToPrevPosition.y) {
                                    knot.handleToPrevPosition.y = prev.position.y
                                }
                            }
                        }
                    }
                }
                if (nextSmooth) {
                    len_b /= len
                    knot.handleToNextPosition = p2.add(tvec.mul(len_b))
                    if (knot.handleToNextType === HandleType.SmoothNoOvershoot && next && prev) {
                        // keep horizontal if extrema
                        const ydiff1 = prev.position.y - knot.position.y
                        const ydiff2 = next.position.y - knot.position.y
                        if ((ydiff1 <= 0 && ydiff2 <= 0) || (ydiff1 >= 0 && ydiff2 >= 0)) {
                            knot.handleToNextPosition.y = knot.position.y
                        } else {
                            // handles should not be beyond y coord of two others
                            if (ydiff1 <= 0) {
                                if (next.position.y < knot.handleToNextPosition.y) {
                                    knot.handleToNextPosition.y = next.position.y
                                }
                            } else {
                                if (next.position.y > knot.handleToNextPosition.y) {
                                    knot.handleToNextPosition.y = next.position.y
                                }
                            }
                        }
                    }
                }
            }
        }

        // corner handles point a third of the way towards the neighbor
        if (knot.handleToPrevType === HandleType.Corner) {
            knot.handleToPrevPosition = p2.add(dvec_a.mul(-1 / 3))
        }
        if (knot.handleToNextType === HandleType.Corner) {
            knot.handleToNextPosition = p2.add(dvec_b.mul(1 / 3))
        }
    }

    // analog to blender: middle part of curvemap_make_table in colortools.c (commented with "first and last handle need correction, instead of pointing to center of next/prev, we let it point to the closest handle")
    // Each end knot's inner handle keeps its length but is re-aimed at the neighbor's facing handle
    // (clipped in x to not point outside the spline); the outer handle is mirrored accordingly.
    // Only applies when the inner handle is smooth: handleToNext for the first knot,
    // handleToPrev for the last knot.
    private correctFirstAndLastHandle(knots: Knot[]) {
        if (knots.length <= 2) {
            return
        }
        const eps = 1e-8
        const first = knots[0]
        if (CubicBezierSpline.isSmoothType(first.handleToNextType)) {
            const hlen = first.position.distance(first.handleToNextPosition) // original handle length
            // clip handle point
            const vec = knots[1].handleToPrevPosition.clone()
            if (vec.x < first.position.x) {
                vec.x = first.position.x
            }
            vec.subInPlace(first.position)
            const nlen = vec.norm()
            if (nlen > eps) {
                vec.mulInPlace(hlen / nlen)
                first.handleToNextPosition = vec.add(first.position)
                first.handleToPrevPosition = first.position.sub(vec)
            }
        }
        const last = knots[knots.length - 1]
        // gate on the handle that actually shapes the last segment (the incoming one)
        if (CubicBezierSpline.isSmoothType(last.handleToPrevType)) {
            const hlen = last.position.distance(last.handleToPrevPosition) // original handle length
            // clip handle point
            const vec = knots[knots.length - 2].handleToNextPosition.clone()
            if (vec.x > last.position.x) {
                vec.x = last.position.x
            }
            vec.subInPlace(last.position)
            const nlen = vec.norm()
            if (nlen > eps) {
                vec.mulInPlace(hlen / nlen)
                last.handleToPrevPosition = vec.add(last.position)
                last.handleToNextPosition = last.position.sub(vec)
            }
        }
    }

    // analog to blender: BKE_curve_correct_bezpart in curve.cc
    // If the x extents of the two handles of segment p0..p3 together exceed the segment's x span,
    // both handles are scaled down by the same factor so the segment stays a function of x.
    // Returns the (possibly) corrected handles [p1, p2].
    private correctBezierSegmentHandles(p0: Vector2, p1: Vector2, p2: Vector2, p3: Vector2): [Vector2, Vector2] {
        // Calculate handle deltas.
        const h1 = p0.sub(p1)
        const h2 = p3.sub(p2)

        /* Calculate distances:
         * - len  = span of time between keyframes
         * - len1 = length of handle of start key
         * - len2 = length of handle of end key
         */
        const len = p3.x - p0.x
        const len1 = Math.abs(h1.x)
        const len2 = Math.abs(h2.x)

        // If the handles have no length, no need to do any corrections.
        if (len1 + len2 === 0) {
            return [p1, p2]
        }

        /* the two handles cross over each other, so force them
         * apart using the proportion they overlap
         */
        if (len1 + len2 > len) {
            const fac = len / (len1 + len2)
            return [p0.sub(h1.mul(fac)), p3.sub(h2.mul(fac))]
        }
        return [p1, p2]
    }

    // analog to blender: BKE_curve_forward_diff_bezier in curve.cc
    // Samples a 1D cubic Bézier at numPoints evenly spaced parameters from 0 to 1 (both inclusive)
    // using forward differencing.
    private computeCurve(q0: number, q1: number, q2: number, q3: number, numPoints: number): number[] {
        if (numPoints <= 0) {
            return []
        }
        if (numPoints === 1) {
            // a single sample is the segment start; the step size below would divide by zero
            return [q0]
        }
        const steps = numPoints - 1
        // polynomial coefficients rescaled to a parameter step of 1 / steps
        const c1 = (3 * (q1 - q0)) / steps
        const c2 = (3 * (q0 - 2 * q1 + q2)) / (steps * steps)
        const c3 = (q3 - q0 + 3 * (q1 - q2)) / (steps * steps * steps)
        // initial value and first/second/third forward differences
        let value = q0
        let d1 = c1 + c2 + c3
        let d2 = 2 * c2 + 6 * c3
        const d3 = 6 * c3
        const result = new Array<number>(numPoints)
        for (let a = 0; a < numPoints; a++) {
            result[a] = value
            value += d1
            d1 += d2
            d2 += d3
        }
        return result
    }

    // handles start out uncomputed, so the first evaluation must compute them
    private knotsNeedUpdate = true
}
