import {CubicBezierSpline, InterpolationMode, Knot, OutOfBoundsMode, SampledFunction, Vector2, Vector2Like} from "@cm/math"
import {Nodes} from "@cm/template-nodes/legacy/template-nodes"

/**
 * Tone-mapping node settings to apply, or `undefined` for none.
 * `ToneMappingFunctions.createForToneMappingData` maps `undefined` to the identity operator.
 */
export type ToneMappingData = Nodes.ToneMapping | undefined

/**
 * Per-pixel colour operator: maps one RGB triple to another.
 * `buildLUTEntries` calls it with linear values (it sRGB-decodes its grid coordinates first when asked to).
 */
export type ToneMappingFunction = (r: number, g: number, b: number) => readonly [number, number, number]

/**
 * Decodes an sRGB transfer-encoded component to linear light, clamping to [0, 1].
 *
 * Inputs at or below 0 give 0 and inputs at or above 1 give 1. In between, the linear toe (`x / 12.92`) is used
 * below the 0.04045 breakpoint and the 2.4 power segment above it.
 */
export function srgbToLinear(x: number): number {
    if (x >= 1) return 1
    if (x <= 0) return 0
    return x < 0.04045 ? x / 12.92 : Math.pow((x + 0.055) / 1.055, 2.4)
}

/**
 * Decodes an sRGB transfer-encoded component to linear light without an upper clamp.
 *
 * Inputs at or below 0 give 0. Inputs above 1 continue along the power segment, so HDR values extrapolate.
 */
export function srgbToLinearNoClip(x: number): number {
    if (x <= 0) {
        return 0
    }
    // Linear toe below the sRGB breakpoint, 2.4 power segment above it.
    const inToe = x < 0.04045
    return inToe ? x / 12.92 : Math.pow((x + 0.055) / 1.055, 2.4)
}

/**
 * Encodes a linear-light component with the sRGB transfer curve, clamping to [0, 1].
 *
 * Inputs at or below 0 give 0 and inputs at or above 1 give 1. In between, the linear toe (`x * 12.92`) is used
 * below 0.0031308 and the 1/2.4 power segment above it.
 */
export function linearToSrgb(x: number): number {
    if (x >= 1) return 1
    if (x <= 0) return 0
    return x < 0.0031308 ? x * 12.92 : 1.055 * Math.pow(x, 1 / 2.4) - 0.055
}

/**
 * Encodes a linear-light component with the sRGB transfer curve without an upper clamp.
 *
 * Inputs at or below 0 give 0. Inputs above 1 continue along the power segment, so HDR values extrapolate.
 */
export function linearToSrgbNoClip(x: number): number {
    if (x <= 0) {
        return 0
    }
    // Linear toe below the breakpoint, 1/2.4 power segment above it.
    const inToe = x < 0.0031308
    return inToe ? x * 12.92 : 1.055 * Math.pow(x, 1 / 2.4) - 0.055
}

/**
 * Converts RGB to HSV.
 *
 * Components are nominally in [0, 1]. `corona` passes sRGB-encoded values; `hsvCurve` passes its raw inputs.
 *
 * @returns `[h, s, v]`, with hue in degrees [0, 360) and saturation/value in the input's scale.
 *          Hue is 0 for achromatic colours, and saturation is 0 when the maximum component is 0.
 */
export function rgb2hsv(r: number, g: number, b: number): readonly [number, number, number] {
    const max = Math.max(r, g, b)
    const chroma = max - Math.min(r, g, b)
    let sector: number
    if (!chroma) {
        // Achromatic: hue is undefined, report the chroma itself (0, or NaN for NaN input).
        sector = chroma
    } else if (max === r) {
        sector = (g - b) / chroma
    } else if (max === g) {
        sector = 2 + (b - r) / chroma
    } else {
        sector = 4 + (r - g) / chroma
    }
    if (sector < 0) sector += 6
    const saturation = max ? chroma / max : max
    return [sector * 60, saturation, max]
}

/**
 * Converts HSV to RGB.
 *
 * @param h hue in degrees; any finite value is accepted and wrapped into [0, 360), so -120 is the same as 240
 * @param s saturation, nominally in [0, 1]
 * @param v value; not clamped, so values above 1 yield components above 1
 * @returns `[r, g, b]`, or `[0, 0, 0]` when `h` is not finite
 */
export function hsv2rgb(h: number, s: number, v: number): readonly [number, number, number] {
    const hSector = h / 60
    const i = Math.floor(hSector)
    const f = hSector - i
    const p = v * (1 - s)
    const q = v * (1 - f * s)
    const t = v * (1 - (1 - f) * s)
    // JS `%` keeps the dividend's sign: wrap negative sectors (negative hues) into 0..5 instead of
    // falling through to black. Non-finite hues still yield NaN here and hit `default`.
    switch (((i % 6) + 6) % 6) {
        case 0:
            return [v, t, p]
        case 1:
            return [q, v, p]
        case 2:
            return [p, v, t]
        case 3:
            return [p, q, v]
        case 4:
            return [t, p, v]
        case 5:
            return [v, p, q]
        default:
            return [0, 0, 0]
    }
}

/**
 * Per-pixel tone-mapping operators and a factory that builds one from tone-mapping node settings.
 *
 * Every operator has the {@link ToneMappingFunction} shape. `buildLUTEntries` calls them with linear values;
 * operators that need another encoding convert internally (see `corona`).
 */
export namespace ToneMappingFunctions {
    /** Blends from `original` (fac = 0) to `adjusted` (fac = 1). */
    function mixByFactor(original: number, adjusted: number, fac: number): number {
        return fac * adjusted + (1.0 - fac) * original
    }

    /** Returns an operator that applies `inner` first and passes its result to `outer`. */
    export function compose(outer: ToneMappingFunction, inner: ToneMappingFunction): ToneMappingFunction {
        return (r, g, b) => {
            const [innerR, innerG, innerB] = inner(r, g, b)
            return outer(innerR, innerG, innerB)
        }
    }

    /** Lifts a scalar curve into an RGB operator that applies it to each channel independently. */
    export function wrapScalarFnAsRGB(fn: (x: number) => number): ToneMappingFunction {
        return (r, g, b) => [fn(r), fn(g), fn(b)] as const
    }

    /**
     * S-shaped contrast curve for one channel.
     *
     * The value is moved to a square-root domain and passed through a `tanh` sigmoid of steepness `c` centred
     * on `1 - b`. The result is renormalised so that 0 -> 0 and 1 -> 1, then squared back.
     *
     * @param x channel value; negative values map to 0
     * @param c contrast (sigmoid steepness); 0 returns `x` unchanged
     * @param b balance; places the sigmoid centre at `1 - b` in the square-root domain
     */
    export function contrastScalar(x: number, c: number, b: number): number {
        if (x < 0) return 0
        // The normalised sigmoid tends to the identity as c -> 0, but evaluating it at c = 0 is 0 / 0 (NaN).
        if (c === 0) return x
        const center = 1 - b
        const min = Math.tanh((0 - center) * c)
        const max = Math.tanh((1 - center) * c)
        const s = (Math.tanh((Math.sqrt(x) - center) * c) - min) / (max - min)
        return s * s
    }

    /**
     * Applies {@link contrastScalar} per channel. `_colorBalance` is added to the balance for red and
     * subtracted from it for blue; green uses `_balance` as is.
     */
    export function contrast(r: number, g: number, b: number, _contrast: number, _balance: number, _colorBalance: number): readonly [number, number, number] {
        return [
            contrastScalar(r, _contrast, _balance + _colorBalance),
            contrastScalar(g, _contrast, _balance),
            contrastScalar(b, _contrast, _balance - _colorBalance),
        ] as const
    }

    /**
     * Filmic rational curve `x(2.51x + 0.03) / (x(2.43x + 0.59) + 0.14)`, with negative input treated as 0
     * and output clamped to [0, 1]. The coefficients are those of Krzysztof Narkowicz's ACES filmic fit.
     */
    export function filmicScalar(x: number): number {
        const c0 = 2.51
        const c1 = 0.03
        const c2 = 2.43
        const c3 = 0.59
        const c4 = 0.14
        const xc = x < 0 ? 0 : x
        const y = (xc * (c0 * xc + c1)) / (xc * (c2 * xc + c3) + c4)
        return y < 0 ? 0 : y > 1 ? 1 : y
    }

    /** {@link filmicScalar} applied to each channel. */
    export const filmic: ToneMappingFunction = wrapScalarFnAsRGB(filmicScalar)

    /** Identity operator. */
    export const linear: ToneMappingFunction = (r, g, b) => [r, g, b] as const

    /** {@link contrast} followed by {@link filmicScalar} on each channel. */
    export function filmicAdvanced(
        r: number,
        g: number,
        b: number,
        _contrast: number,
        _balance: number,
        _colorBalance: number,
    ): readonly [number, number, number] {
        const [cr, cg, cb] = contrast(r, g, b, _contrast, _balance, _colorBalance)
        return [filmicScalar(cr), filmicScalar(cg), filmicScalar(cb)] as const
    }

    /**
     * Highlight roll-off `x(1 + x/c^2) / (1 + x)`. The input `x = c` maps to exactly 1, so `c` acts as a white point.
     */
    export function coronaHighlightCompressionScalar(x: number, c: number = 1): number {
        return (x * (1 + x / (c * c))) / (1 + x)
    }

    /**
     * Per-channel contrast curve intended to reproduce Corona's tone mapping; per the notes below, it matches
     * well for `c >= 1` only.
     *
     * Values are clamped to [0, 1], raised to 1/2.2, and passed through a sigmoid (`tanh` for c >= 1, `tan`
     * for c < 1). The sigmoid is renormalised over [0, 1] and the result is raised back to 2.2.
     * Contrast values strictly within 0.01 of 1 return `x` unchanged.
     */
    export function coronaContrastScalar(x: number, c: number = 1): number {
        if (x < 0) return 0
        else if (x > 1) {
            // yes... corona does actually do this
            return 1
        } else if (c > 0.99 && c < 1.01) {
            // doesn't match near 1
            return x
        }
        x = Math.pow(x, 1 / 2.2)
        let p0: number
        let p1: number
        if (c >= 1) {
            const ic = 1 / c
            p0 = Math.tanh(-0.25 * c) * ic
            p1 = Math.tanh(0.25 * c) * ic
            x = Math.tanh((x - 0.5) * 0.5 * c) * ic
        } else {
            // This doesn't _exactly_ match, but c < 1 is not really used anyway
            const ic = 1 / c
            p0 = Math.tan(-0.5 * ic) * c
            p1 = Math.tan(0.5 * ic) * c
            x = Math.tan((x - 0.5) * ic) * c
        }
        // Renormalise so that the curve's values at 0 and 1 (p0, p1) map to 0 and 1.
        x -= p0
        x *= 1 / (p1 - p0)
        x = Math.pow(x, 2.2)
        return x
    }

    /**
     * Corona-style tone mapping. Highlight compression and contrast are applied per channel. The result is
     * sRGB-encoded, saturation is changed additively in HSV space, and the result is decoded back to linear.
     *
     * @param highlightCompression white point passed to {@link coronaHighlightCompressionScalar}
     * @param contrast passed to {@link coronaContrastScalar}
     * @param saturation added to HSV saturation; the sum is clamped to [0, 1]
     */
    export function corona(
        r: number,
        g: number,
        b: number,
        highlightCompression: number,
        contrast: number,
        saturation: number,
    ): readonly [number, number, number] {
        const er = linearToSrgbNoClip(coronaContrastScalar(coronaHighlightCompressionScalar(r, highlightCompression), contrast))
        const eg = linearToSrgbNoClip(coronaContrastScalar(coronaHighlightCompressionScalar(g, highlightCompression), contrast))
        const eb = linearToSrgbNoClip(coronaContrastScalar(coronaHighlightCompressionScalar(b, highlightCompression), contrast))
        const [h, s, v] = rgb2hsv(er, eg, eb)
        const [sr, sg, sb] = hsv2rgb(h, Math.max(0, Math.min(1, s + saturation)), v)
        return [srgbToLinearNoClip(sr), srgbToLinearNoClip(sg), srgbToLinearNoClip(sb)]
    }

    /**
     * PBR Neutral tone mapper; the constants and steps match the Khronos PBR Neutral reference.
     *
     * A small offset is subtracted to handle the toe. Colours whose peak stays below 0.76 are returned as is.
     * Brighter colours have their peak compressed towards 1 and are partially desaturated towards the new peak.
     */
    export function pbrNeutral(r: number, g: number, b: number): readonly [number, number, number] {
        const startCompression = 0.8 - 0.04
        const desaturation = 0.15
        const x = Math.min(r, Math.min(g, b))
        const ofs = x < 0.08 ? x - 6.25 * x * x : 0.04
        r -= ofs
        g -= ofs
        b -= ofs
        const peak = Math.max(r, Math.max(g, b))
        if (peak < startCompression) {
            return [r, g, b]
        }
        const d = 1 - startCompression
        let newPeak = 1 - (d * d) / (peak + d - startCompression)
        // Desaturation weight k, folded into the scale: result = color * newPeak / peak * (1 - k) + newPeak * k
        const k = 1 - 1 / (desaturation * (peak - newPeak) + 1)
        const sc = (newPeak / peak) * (1 - k)
        r *= sc
        g *= sc
        b *= sc
        newPeak *= k
        r += newPeak
        g += newPeak
        b += newPeak
        return [r, g, b]
    }

    /**
     * Applies RGB curves. `sampledFunctions[3]` is evaluated first on every channel, then the per-channel curve
     * (`[0]` red, `[1]` green, `[2]` blue). The result is blended with the input by `fac`.
     *
     * @param sampledFunctions four curves, expected to extrapolate so that `evaluate` never returns undefined
     */
    export function rgbCurve(r: number, g: number, b: number, sampledFunctions: SampledFunction[], fac: number): readonly [number, number, number] {
        const combined = sampledFunctions[3]
        // Curves built by createForToneMappingData use OutOfBoundsMode.Extrapolate, so evaluate() is never undefined.
        const rOut = sampledFunctions[0].evaluate(combined.evaluate(r)!)!
        const gOut = sampledFunctions[1].evaluate(combined.evaluate(g)!)!
        const bOut = sampledFunctions[2].evaluate(combined.evaluate(b)!)!
        return [mixByFactor(r, rOut, fac), mixByFactor(g, gOut, fac), mixByFactor(b, bOut, fac)]
    }

    /**
     * Hue/saturation/value adjustment.
     *
     * `hue` is a normalised offset where 0.5 means "no shift". `saturation` and `value` scale S and V, with S
     * clamped to [0, 1]. The result is blended with the input by `fac` and clamped to be non-negative.
     */
    export function hsvCurve(r: number, g: number, b: number, parameters: Nodes.HueSaturationMapping["parameters"]): readonly [number, number, number] {
        const {hue, saturation, value, fac} = parameters
        const [h, s, v] = rgb2hsv(r, g, b)
        let shiftedHue = ((h / 360.0 + hue + 0.5) * 360.0) % 360.0
        // `%` keeps the sign of its left operand, so hue < -0.5 can yield a negative angle; wrap it into [0, 360).
        if (shiftedHue < 0) shiftedHue += 360.0
        const [ar, ag, ab] = hsv2rgb(shiftedHue, Math.max(0.0, Math.min(1.0, s * saturation)), v * value)
        return [Math.max(mixByFactor(r, ar, fac), 0.0), Math.max(mixByFactor(g, ag, fac), 0.0), Math.max(mixByFactor(b, ab, fac), 0.0)] as const
    }

    /**
     * Samples the cubic Bézier spline built from `controlPoints` into a lookup function.
     * It uses 256 samples with linear interpolation and extrapolates outside the sampled range.
     */
    function sampleRgbCurve(controlPoints: Vector2Like[]): SampledFunction {
        const numCurveSamples = 256
        const knots = controlPoints.map((point) => new Knot(point))
        const spline = new CubicBezierSpline(knots)
        const points = spline.evaluatePoints(numCurveSamples)
        return new SampledFunction(points, InterpolationMode.Linear, OutOfBoundsMode.Extrapolate)
    }

    /**
     * Builds the operator described by `data`. `undefined` data and unrecognised modes yield {@link linear}.
     *
     * For `rgbCurve`, control points are read from the flat parameters
     * `internal.mapping.curves[<curve>].points[<i>].location` for curves 0-3; the first missing index ends a curve.
     *
     * @throws Error if an `rgbCurve` control-point parameter is not a 2-element array
     */
    export function createForToneMappingData(data: ToneMappingData): ToneMappingFunction {
        if (!data) return linear
        switch (data.mode) {
            case "filmic":
                return filmic
            case "filmic-advanced":
                return (r, g, b) => filmicAdvanced(r, g, b, data.contrast, data.balance, data.colorBalance)
            case "linear":
                return linear
            case "contrast":
                return (r, g, b) => contrast(r, g, b, data.contrast, data.balance, data.colorBalance)
            case "corona":
                return (r, g, b) => corona(r, g, b, data.highlightCompression, data.contrast, data.saturation)
            case "pbr-neutral":
                return pbrNeutral
            case "rgbCurve": {
                const sampledFunctions: SampledFunction[] = []
                for (let curveIndex = 0; curveIndex < 4; curveIndex++) {
                    const controlPoints: Vector2[] = []
                    for (let i = 0; ; i++) {
                        const key = `internal.mapping.curves[${curveIndex}].points[${i}].location`
                        const pointLocation = data.parameters[key]
                        if (!pointLocation) break
                        if (typeof pointLocation === "number" || pointLocation.length !== 2) throw new Error(`Expected ${key} parameter to be an array of length 2`)
                        controlPoints.push(new Vector2(pointLocation[0], pointLocation[1]))
                    }
                    sampledFunctions.push(sampleRgbCurve(controlPoints))
                }
                return (r, g, b) => rgbCurve(r, g, b, sampledFunctions, data.parameters.fac)
            }
            case "hueSaturation":
                return (r, g, b) => hsvCurve(r, g, b, data.parameters)
            default:
                // Unknown modes fall back to the identity mapping.
                return linear
        }
    }
}

/**
 * Samples `fn` on a regular `size`^3 grid covering [0, range] on every axis and returns the flattened results.
 *
 * Entries are 3 floats each, written with the red index varying fastest, then green, then blue.
 *
 * @param size grid points per axis; a size of 1 samples only the origin
 * @param range upper bound of the sampled domain on every axis
 * @param fn operator to bake
 * @param inputSrgb treat grid coordinates as sRGB-encoded and decode them to linear before calling `fn`
 * @param outputSrgb sRGB-encode `fn`'s output before storing it
 * @returns a Float32Array of length `size^3 * 3`
 */
export function buildLUTEntries(size: number, range: number, fn: ToneMappingFunction, inputSrgb: boolean, outputSrgb: boolean): Float32Array {
    const entries = new Float32Array(size * size * size * 3)
    // A single grid point has no spacing; `range / 0` would turn every coordinate into 0 * Infinity = NaN.
    const step = size > 1 ? range / (size - 1) : 0
    // The grid coordinates are the same on all three axes, so compute (and optionally decode) them once
    // instead of once per LUT cell.
    const coords = new Float64Array(size)
    for (let i = 0; i < size; i++) {
        const x = i * step
        coords[i] = inputSrgb ? srgbToLinearNoClip(x) : x
    }
    let idx = 0
    for (let ib = 0; ib < size; ib++) {
        for (let ig = 0; ig < size; ig++) {
            for (let ir = 0; ir < size; ir++) {
                const [r, g, b] = fn(coords[ir], coords[ig], coords[ib])
                if (outputSrgb) {
                    entries[idx++] = linearToSrgbNoClip(r)
                    entries[idx++] = linearToSrgbNoClip(g)
                    entries[idx++] = linearToSrgbNoClip(b)
                } else {
                    entries[idx++] = r
                    entries[idx++] = g
                    entries[idx++] = b
                }
            }
        }
    }
    return entries
}

/**
 * Default settings for a tone-mapping mode.
 *
 * Parameterless modes get just their `mode`. The contrast-based modes default to contrast 1, balance 0.5 and
 * colour balance 0; `corona` defaults to highlight compression 1, contrast 1 and saturation 0.
 * Any other mode yields `undefined`.
 */
export function defaultsForToneMapping(mode: Nodes.ToneMapping["mode"]): ToneMappingData {
    switch (mode) {
        case "filmic":
            return {mode: "filmic"}
        case "filmic-advanced":
            return {mode: "filmic-advanced", contrast: 1, balance: 0.5, colorBalance: 0}
        case "linear":
            return {mode: "linear"}
        case "contrast":
            return {mode: "contrast", contrast: 1, balance: 0.5, colorBalance: 0}
        case "corona":
            return {mode: "corona", highlightCompression: 1, contrast: 1, saturation: 0}
        case "pbr-neutral":
            return {mode: "pbr-neutral"}
        default:
            return undefined
    }
}

/**
 * Bakes the tone mapping described by `data` into a 3-channel 3D LUT.
 *
 * `size`, `range`, `inputSrgb` and `outputSrgb` are forwarded to `buildLUTEntries`, and `size`/`range` are
 * recorded in the result.
 */
export function compileLUT(data: ToneMappingData, size: number, range: number, inputSrgb: boolean, outputSrgb: boolean): LUTData {
    const toneMap = ToneMappingFunctions.createForToneMappingData(data)
    const entries = buildLUTEntries(size, range, toneMap, inputSrgb, outputSrgb)
    return {size, channels: 3, data: entries, range}
}

/**
 * Parses a 3D LUT in the `.cube` text format.
 *
 * Understood lines:
 * - blank lines and `#` comments
 * - `TITLE`
 * - `LUT_3D_SIZE n`
 * - `DOMAIN_MIN` (must be 0 on every axis)
 * - `DOMAIN_MAX` (must be equal on every axis)
 * - `LUT_3D_INPUT_RANGE min max` (min must be 0)
 * - `r g b` entry lines
 *
 * Other keyword lines are ignored. Tokens may be separated by any whitespace, and `\r\n` line endings are accepted.
 *
 * @returns a 3-channel LUT whose `range` is the upper domain bound (1 if unspecified)
 * @throws Error for 1D LUTs, a missing or invalid size, malformed header or entry lines, or a wrong number of entries
 */
export function loadCubeLUT(lutString: string): LUTData {
    let size: number | undefined
    let range = 1
    let data: Float32Array | undefined
    let offset = 0
    for (const rawLine of lutString.split("\n")) {
        const line = rawLine.trim()
        if (line.length === 0 || line.startsWith("#")) continue
        // Split on any whitespace run: tabs and repeated spaces are common in .cube files.
        const tokens = line.split(/\s+/)
        const keyword = tokens[0]
        if (keyword === "TITLE") {
            // Free-text title (may contain spaces); not needed.
            continue
        } else if (keyword === "LUT_1D_SIZE") {
            throw new Error("1D LUTs are not supported")
        } else if (keyword === "LUT_3D_SIZE") {
            if (tokens.length !== 2) throw new Error("Expected LUT size")
            if (offset > 0) throw new Error("LUT_3D_SIZE must precede the LUT entries")
            size = parseInt(tokens[1])
            if (!Number.isInteger(size) || size < 1) throw new Error(`Invalid LUT size: ${tokens[1]}`)
            data = new Float32Array(size * size * size * 3)
        } else if (keyword === "DOMAIN_MIN") {
            if (tokens.length !== 4) throw new Error("Expected LUT range")
            if (parseFloat(tokens[1]) !== 0 || parseFloat(tokens[2]) !== 0 || parseFloat(tokens[3]) !== 0) throw new Error("LUT domain min must be 0")
        } else if (keyword === "DOMAIN_MAX") {
            if (tokens.length !== 4) throw new Error("Expected LUT range")
            range = parseFloat(tokens[1])
            if (parseFloat(tokens[2]) !== range || parseFloat(tokens[3]) !== range) throw new Error("LUT domain max must be equal for all components")
        } else if (keyword === "LUT_3D_INPUT_RANGE") {
            // Alternative range header: "LUT_3D_INPUT_RANGE min max", applying to all axes.
            if (tokens.length !== 3) throw new Error("Expected LUT range")
            if (parseFloat(tokens[1]) !== 0) throw new Error("LUT domain min must be 0")
            range = parseFloat(tokens[2])
        } else if (/^[A-Za-z_]/.test(keyword)) {
            // Unsupported keyword: ignore it rather than misreading it as an entry.
        } else {
            if (!data) throw new Error("LUT entry found before LUT_3D_SIZE")
            if (tokens.length !== 3) throw new Error(`Expected 3 values per LUT entry, got "${line}"`)
            for (const token of tokens) {
                const value = parseFloat(token)
                if (!Number.isFinite(value)) throw new Error(`Invalid LUT entry: "${line}"`)
                // Writes past the end are ignored by Float32Array; the count check below reports them.
                data[offset++] = value
            }
        }
    }
    if (!data || size === undefined) throw new Error("No data in LUT")
    if (offset !== data.length) throw new Error(`Invalid number of entries in LUT: Expected ${data.length}, got ${offset}`)
    return {
        size,
        channels: 3,
        data,
        range,
    }
}

/**
 * Colour encoding identifier: linear light or sRGB transfer-encoded.
 * NOTE(review): not referenced elsewhere in this file; confirm its consumers before changing it.
 */
export type ColorSpace = "linear" | "srgb"

/** A baked 3D lookup table, as produced by `compileLUT` and `loadCubeLUT`. */
export type LUTData = {
    /** Number of grid points per axis. */
    size: number
    /** Values stored per grid point; the producers in this file always emit 3. */
    channels: 1 | 3
    // length = size^3 * channels; buildLUTEntries writes the red index fastest, then green, then blue
    // (loadCubeLUT stores entries in file order)
    data: Float32Array
    /** Upper bound of the input domain on every axis; the lower bound is 0. */
    range: number
}
