import {DEFAULT_FLOAT_TEXTURE_TYPE} from "@editor/helpers/scene/three-proxies/utils"
import {Three as THREE} from "@cm/material-nodes/three"
import {Pass, FullScreenQuad} from "@cm/material-nodes/three"
import {buildLUTEntries, ToneMappingFunction, ToneMappingFunctions} from "@cm/image-processing/tone-mapping"

/**
 * Builds the transform for a lens sample displaced by (ax, ay) on the z = 0 plane.
 *
 * The result is a translation by (ax, ay, 0) multiplied by the `Matrix4.lookAt` matrix
 * computed from the displaced point towards (0, 0, -focusDist) with +Y as up.
 *
 * @param ax Offset of the lens sample along x.
 * @param ay Offset of the lens sample along y.
 * @param focusDist Distance along -z of the point the displaced lens sample looks at.
 * @returns A new matrix; callers in this file post-multiply it onto the camera's world matrix.
 */
function lensApertureOffsetMatrix(ax: number, ay: number, focusDist: number): THREE.Matrix4 {
    const eye = new THREE.Vector3(ax, ay, 0)
    const focusPoint = new THREE.Vector3(0, 0, -focusDist)
    const up = new THREE.Vector3(0, 1, 0)

    const orientation = new THREE.Matrix4().lookAt(eye, focusPoint, up)
    const translation = new THREE.Matrix4().makeTranslation(eye.x, eye.y, 0)
    return translation.multiply(orientation)
}

// Precomputed lens-aperture sample offsets, stored as interleaved (x, y) pairs.
// sampleAperture() reads pair number `frameIdx` and multiplies both components by the
// aperture diameter, so the values are expressed in units of that diameter.
// The first pair is (0, 0), i.e. no offset for the first sample.
const aperturePts = [
    0, 0, -0.124, 0.241, 0.277, -0.064, -0.272, -0.033, 0.002, -0.221, 0.15, 0.169, 0.147, -0.304, 0.337, 0.122, -0.185, -0.26, 0.036, 0.32, -0.264, 0.171,
    0.18, 0.027, -0.096, -0.123, 0.095, -0.121, -0.114, 0.07, 0.05, 0.184, -0.294, -0.128, 0.241, 0.203, 0.28, -0.161, -0.366, -0.01, 0.373, -0.047, -0.213,
    0.279, -0.079, 0.327, 0.05, -0.306, -0.093, -0.392, 0.181, 0.328, 0.182, -0.168, 0.337, 0.219, -0.362, 0.176, -0.001, -0.098, 0.101, -0.39, -0.027, 0.245,
    0.243, -0.321, -0.278, -0.29, -0.203, -0.164, 0.238, 0.106, 0.376, -0.144, 0.4, 0.047, -0.168, 0.366, 0.101, 0.085, 0.099, -0.219, 0.095, -0.021, -0.096,
    -0.22, -0.016, 0.402, 0.004, -0.402, 0.105, 0.389, -0.185, -0.358, 0.301, 0.031, -0.207, 0.04, -0.309, 0.258, -0.094, -0.025, -0.169, 0.151, 0.273, 0.296,
    0.112, 0.259, -0.072, 0.158, 0.005, 0.097, -0.296, 0.079, -0.389, -0.105, -0.181, -0.069, -0.34, -0.215, -0.048, -0.305, 0.313, -0.253, 0.179, -0.07,
    -0.394, 0.084,
]

/**
 * Returns the (x, y) aperture offset to use for the given frame, scaled by `diameter`.
 *
 * While `frameIdx` addresses a pair inside the `aperturePts` table that pair is used;
 * afterwards a point is drawn by rejection sampling inside the disk of radius 0.5
 * centred on the origin.
 *
 * @param diameter Scale factor applied to both components of the unit-diameter sample.
 * @param frameIdx Index of the sample (pair index into `aperturePts`).
 */
function sampleAperture(diameter: number, frameIdx: number): [number, number] {
    const pairStart = frameIdx * 2
    const insideTable = pairStart < aperturePts.length

    if (!insideTable) {
        // Table exhausted: draw from the square [-0.5, 0.5)^2 until the point lies in the disk.
        for (;;) {
            const px = Math.random() - 0.5
            const py = Math.random() - 0.5
            const radiusSq = px * px + py * py
            if (radiusSq < 0.25) {
                return [px * diameter, py * diameter]
            }
        }
    }

    const tableX = aperturePts[pairStart]
    const tableY = aperturePts[pairStart + 1]
    return [tableX * diameter, tableY * diameter]
}

/**
 * Post-multiplies the camera's world matrix by `m` and writes the result back into the
 * camera's position / quaternion / scale.
 *
 * NOTE(review): the *world* matrix is decomposed into the camera's own position and
 * quaternion — presumably this assumes the camera has no transformed parent; confirm.
 * Only position and quaternion are captured for restoring; scale is overwritten by the
 * decomposition and is not part of the returned state.
 *
 * @returns The previous position and quaternion, to be handed to `restoreCameraMatrix`.
 */
function saveAndMultiplyCameraMatrix(camera: THREE.Camera, m: THREE.Matrix4) {
    camera.updateMatrixWorld()

    const saved = {
        oldP: camera.position.clone(),
        oldQ: camera.quaternion.clone(),
    }

    const {matrixWorld, position, quaternion, scale} = camera
    matrixWorld.multiply(m).decompose(position, quaternion, scale)

    return saved
}

/**
 * Puts back the camera position and quaternion captured by `saveAndMultiplyCameraMatrix`
 * and refreshes the camera's world matrix.
 *
 * @param prev Object carrying `oldP` (position) and `oldQ` (quaternion).
 */
function restoreCameraMatrix(camera: THREE.Camera, prev: any) {
    const {position, quaternion} = camera
    position.copy(prev.oldP)
    quaternion.copy(prev.oldQ)
    camera.updateMatrixWorld()
}

/**
 * Replaces the camera's view offset with a full-size view shifted by (offsetX, offsetY)
 * and updates the projection matrix.
 *
 * If the camera already had an enabled view, its offsets are rescaled from its own full
 * size to (fullWidth, fullHeight) and added on top of the requested shift. The previous
 * view's `width`/`height` are not carried over: the new view always spans the full size.
 *
 * @returns The previously enabled view object, or null when none was enabled; pass it to
 *          `restoreCameraView` afterwards.
 */
function saveAndAdjustCameraView(camera: THREE.PerspectiveCamera, fullWidth: number, fullHeight: number, offsetX: number, offsetY: number) {
    const previousView = camera.view?.enabled ? camera.view : null

    // Offset contributed by an already active view, rescaled to the new full size.
    const inheritedX = previousView ? previousView.offsetX * (fullWidth / previousView.fullWidth) : 0
    const inheritedY = previousView ? previousView.offsetY * (fullHeight / previousView.fullHeight) : 0

    // Assigned directly instead of going through camera.setViewOffset(...)
    camera.view = {
        enabled: true,
        fullWidth,
        fullHeight,
        offsetX: offsetX + inheritedX,
        offsetY: offsetY + inheritedY,
        width: fullWidth,
        height: fullHeight,
    }
    camera.updateProjectionMatrix()

    return previousView
}

/**
 * Undoes `saveAndAdjustCameraView`.
 *
 * @param view The view object returned by `saveAndAdjustCameraView`; when null the
 *             camera's view offset is cleared via `clearViewOffset()` instead.
 */
function restoreCameraView(camera: THREE.PerspectiveCamera, view: THREE.PerspectiveCamera["view"] | null) {
    if (!view) {
        camera.clearViewOffset()
        return
    }
    camera.view = view
    camera.updateProjectionMatrix()
}

/**
 * Temporal anti-aliasing pass with optional depth-of-field and LUT based tone mapping.
 *
 * Every `render()` call renders the scene once with a jittered view offset (and, when
 * `depthOfField` is set, a displaced lens sample), adds that image into a float
 * accumulation target using additive blending, and then writes the tone-mapped average
 * of all accumulated samples to the output. The very first sample is additionally kept
 * as a "snapshot"; `fadeInTiming` controls how the output blends from that snapshot to
 * the accumulated average as more samples arrive.
 */
export class TAARenderPass extends Pass {
    /** When false, accumulation restarts on every `render()` call. */
    accumulate = false
    /** Sample count that `fadeInTiming` is expressed relative to. */
    maxSamples = 64
    /**
     * `[begin, end]` as fractions of `maxSamples`: the output shows the snapshot up to
     * `begin * maxSamples` samples, the accumulated average from `end * maxSamples` on,
     * and a linear blend in between. `null` always shows the accumulated average.
     */
    fadeInTiming: readonly [number, number] | null = [0.2, 0.9]
    /** Number of samples currently summed in `sampleRenderTarget`. */
    private accumulateIndex = 0
    // Stored from the constructor; render() currently always clears to transparent black.
    private clearColor: THREE.Color | string | number
    private clearAlpha: number
    private copyUniforms1: any
    private copyUniforms2: any
    // Kept so the materials can be released in dispose().
    private copyMaterial1: THREE.ShaderMaterial
    private copyMaterial2: THREE.ShaderMaterial
    private fsQuad1: FullScreenQuad
    private fsQuad2: FullScreenQuad
    private sampleRenderTarget: THREE.WebGLRenderTarget | null = null
    private snapshotRenderTarget: THREE.WebGLRenderTarget | null = null
    /** Lookup table applied after exposure and the linear-to-sRGB conversion. */
    toneMapLUT: CubeLUTTexture
    /** Multiplier applied to the averaged colour before tone mapping. */
    exposure = 1.0
    /** When set, each sample is rendered through a displaced lens position. */
    depthOfField?: {apertureSize: number; focusDistance: number}

    /** Interleaved (x, y) view-offset jitter pairs, one pair per accumulated sample. */
    static JitterVectors = [
        0, 0, 0.495, -0.499, -0.007, 0.498, -0.5, 0.002, 0.253, 0.245, 0.247, -0.25, -0.257, -0.251, -0.249, 0.245, 0.245, 0.497, 0.007, 0.248, -0.256, -0.498,
        -0.005, -0.252, 0.254, -0.005, 0.491, -0.249, -0.257, -0.006, -0.495, 0.258, -0.129, -0.122, -0.128, -0.376, -0.374, 0.123, 0.132, 0.12, 0.123, -0.123,
        0.119, -0.38, -0.386, -0.371, 0.371, 0.375, -0.383, -0.13, -0.135, 0.376, -0.119, 0.127, 0.374, -0.372, 0.389, 0.133, 0.118, 0.378, -0.374, 0.379,
        0.386, -0.112, 0.123, -0.255, 0.255, 0.119, 0.246, 0.367, -0.132, -0.245, -0.119, 0.25, 0.25, -0.378, -0.134, 0.002, -0.002, -0.126, -0.011, 0.37,
        0.374, 0.497, 0.257, -0.128, -0.244, 0.122, 0.131, 0.256, -0.38, -0.496, 0.007, 0.126, -0.257, -0.13, -0.371, 0.254, -0.262, -0.376, 0.376, 0.253,
        0.369, -0.247, -0.003, -0.379, 0.13, -0.002, -0.131, -0.497, 0.379, 0.009, -0.254, 0.364, 0.119, -0.499, -0.381, -0.25, -0.377, -0.006, 0.496, 0.379,
        -0.498, -0.124, -0.492, 0.12, 0.489, -0.361,
    ]

    /** Shader that adds one dithered, alpha-premultiplied sample into the bound target. */
    static CopyShader1 = {
        uniforms: {
            tImage: {value: null as any},
        },
        vertexShader: [
            "varying vec2 vUv;",
            "void main() {",
            "    vUv = uv;",
            "    gl_Position = projectionMatrix * modelViewMatrix * vec4( position, 1.0 );",
            "}",
        ].join("\n"),
        fragmentShader: [
            "uniform sampler2D tImage;",
            "varying vec2 vUv;",
            "#include <common>", // for rand()
            "void main() {",
            "    vec4 texel = texture2D( tImage, vUv );",
            "    texel.rgb += vec3(1.96e-3, -1.96e-3, 1.96e-3) * (rand(gl_FragCoord.xy) * 2.0 - 1.0);", // dither
            "    gl_FragColor = vec4(texel.rgb * texel.a, texel.a);",
            "}",
        ].join("\n"),
    }

    /** Shader that averages the accumulation, blends with the snapshot and tone-maps. */
    static CopyShader2 = {
        uniforms: {
            tAccum: {value: null as any},
            tSnapshot: {value: null as any},
            toneMapExposure: {value: null as any},
            tToneMapLUT: {value: null as any},
            toneMapLUTSize: {value: 1.0},
            toneMapLUTRangeScale: {value: 1.0},
            accumScale: {value: 1.0},
            mixFactor: {value: 1.0},
        },
        vertexShader: [
            "varying vec2 vUv;",
            "void main() {",
            "    vUv = uv;",
            "    gl_Position = projectionMatrix * modelViewMatrix * vec4( position, 1.0 );",
            "}",
        ].join("\n"),
        fragmentShader: [
            "uniform float accumScale;",
            "uniform float mixFactor;",
            "uniform sampler2D tAccum;",
            "uniform sampler2D tSnapshot;",
            "uniform sampler2D tToneMapLUT;",
            "uniform float toneMapLUTSize;",
            "uniform float toneMapLUTRangeScale;",
            "uniform float toneMapExposure;",
            "varying vec2 vUv;",
            // WARNING: texture interpolation precision can be limited by the hardware! Make sure the lookup tables have a decent number of points to ensure smooth gradients.
            "vec4 sampleCube(sampler2D tex, vec3 coord, float sz) {",
            "  float sz_1 = sz - 1.0;",
            "  float rsz = 1.0 / sz;",
            "  float rsz2 = rsz * rsz;",
            "  float tx = (coord.x * sz_1 + 0.5) * rsz;",
            "  float ty = (clamp(coord.y, 0.0, 1.0) * sz_1 + 0.5) * rsz2;",
            "  float _iz = coord.z * sz_1;",
            "  float iz = floor(_iz);",
            "  float zFrac = _iz - iz;",
            "  float tyz0 = ty + clamp(iz, 0.0, sz_1) * rsz;",
            "  float tyz1 = ty + clamp(iz + 1.0, 0.0, sz_1) * rsz;",
            "  vec4 z0c = texture2D(tex, vec2(tx, tyz0));",
            "  vec4 z1c = texture2D(tex, vec2(tx, tyz1));",
            "  return mix(z0c, z1c, zFrac);",
            "}",
            // "vec3 toneMap(vec3 x) {",
            // "    const float a = 2.51;",
            // "    const float b = 0.03;",
            // "    const float c = 2.43;",
            // "    const float d = 0.59;",
            // "    const float e = 0.14;",
            // "    return clamp((x * (a * x + b)) / (x * (c * x + d) + e), 0.0, 1.0);",
            // "}",
            "void main() {",
            "    vec4 texel = texture2D(tAccum, vUv);",
            "    texel = mix(texture2D(tSnapshot, vUv), texel * accumScale, mixFactor);",
            "    texel.rgb /= max(texel.a, 1e-3);", // un-premultiply alpha
            "    texel.rgb *= toneMapExposure;",
            // "    texel.rgb = toneMap(texel.rgb);",
            "    texel = LinearTosRGB(max(texel, 0.0));",
            "    texel.rgb = sampleCube(tToneMapLUT, texel.rgb * toneMapLUTRangeScale, toneMapLUTSize).rgb;",
            "    texel.rgb *= texel.a;", // re-premultiply alpha
            "    gl_FragColor = texel;",
            "}",
        ].join("\n"),
    }

    /**
     * @param scene Scene rendered for every sample.
     * @param camera Camera used for rendering; must be non-null by the time `render()` runs.
     * @param clearColor Stored on the pass (defaults to 0x000000).
     * @param clearAlpha Stored on the pass (defaults to 0).
     */
    constructor(
        private scene: THREE.Scene,
        public camera: THREE.PerspectiveCamera | null,
        clearColor?: THREE.Color | string | number,
        clearAlpha?: number,
    ) {
        super()

        this.clearColor = clearColor !== undefined ? clearColor : 0x000000
        this.clearAlpha = clearAlpha !== undefined ? clearAlpha : 0

        // Additive (One, One) blending: each sample is summed into the target.
        this.copyUniforms1 = THREE.UniformsUtils.clone(TAARenderPass.CopyShader1.uniforms)
        this.copyMaterial1 = new THREE.ShaderMaterial({
            uniforms: this.copyUniforms1,
            vertexShader: TAARenderPass.CopyShader1.vertexShader,
            fragmentShader: TAARenderPass.CopyShader1.fragmentShader,
            premultipliedAlpha: false,
            transparent: true,
            blending: THREE.CustomBlending,
            blendEquation: THREE.AddEquation,
            blendSrc: THREE.OneFactor,
            blendDst: THREE.OneFactor,
            depthTest: false,
            depthWrite: false,
        })
        this.fsQuad1 = new FullScreenQuad(this.copyMaterial1)

        this.copyUniforms2 = THREE.UniformsUtils.clone(TAARenderPass.CopyShader2.uniforms)
        this.copyMaterial2 = new THREE.ShaderMaterial({
            uniforms: this.copyUniforms2,
            vertexShader: TAARenderPass.CopyShader2.vertexShader,
            fragmentShader: TAARenderPass.CopyShader2.fragmentShader,
            premultipliedAlpha: false,
            transparent: true,
            blending: THREE.CustomBlending,
            blendEquation: THREE.AddEquation,
            blendSrc: THREE.OneFactor,
            blendDst: THREE.OneFactor,
            depthTest: false,
            depthWrite: false,
        })
        this.fsQuad2 = new FullScreenQuad(this.copyMaterial2)

        this.toneMapLUT = new CubeLUTTexture(32, 2.0, DEFAULT_FLOAT_TEXTURE_TYPE)
        this.toneMapLUT.updateWithFunction(ToneMappingFunctions.filmic)
    }

    /**
     * Releases the GPU resources owned by the pass: both render targets, both copy
     * materials and the tone-map LUT texture. The render targets are re-created lazily
     * by the next `render()` call.
     */
    override dispose() {
        if (this.sampleRenderTarget) {
            this.sampleRenderTarget.dispose()
            this.sampleRenderTarget = null
        }
        if (this.snapshotRenderTarget) {
            this.snapshotRenderTarget.dispose()
            this.snapshotRenderTarget = null
        }
        this.copyMaterial1.dispose()
        this.copyMaterial2.dispose()
        this.toneMapLUT.dispose()
    }

    /**
     * Resizes the internal render targets. When a target's dimensions actually change
     * the accumulation is restarted, because the summed samples no longer match the new
     * buffer and would otherwise be normalised by a stale sample count.
     */
    override setSize(width: number, height: number) {
        for (const target of [this.sampleRenderTarget, this.snapshotRenderTarget]) {
            if (!target) continue
            if (target.width !== width || target.height !== height) {
                this.accumulateIndex = 0
            }
            target.setSize(width, height)
        }
    }

    /**
     * Renders one more jittered sample, accumulates it and writes the resolved image to
     * the screen (when `renderToScreen` is set) or to `writeBuffer`.
     *
     * NOTE(review): the renderer's clear colour/alpha are left at transparent black
     * after this call, as before — confirm whether later passes rely on that.
     *
     * @throws Error when `camera` is null.
     */
    override render(
        renderer: THREE.WebGLRenderer,
        writeBuffer: THREE.WebGLRenderTarget,
        readBuffer: THREE.WebGLRenderTarget,
        deltaTime: number,
        maskActive: boolean,
    ) {
        // Fail before touching any renderer state instead of crashing half-way through.
        const camera = this.camera
        if (!camera) {
            throw new Error("TAARenderPass.render: camera is not set")
        }

        if (!this.sampleRenderTarget || !this.snapshotRenderTarget) {
            // Freshly created targets hold no samples, so any previous count is stale.
            this.accumulateIndex = 0
        }
        if (!this.sampleRenderTarget) {
            this.sampleRenderTarget = TAARenderPass.createFloatTarget(readBuffer.width, readBuffer.height, "TAARenderPass.sample")
        }
        if (!this.snapshotRenderTarget) {
            this.snapshotRenderTarget = TAARenderPass.createFloatTarget(readBuffer.width, readBuffer.height, "TAARenderPass.snapshot")
        }
        const sampleTarget = this.sampleRenderTarget
        const snapshotTarget = this.snapshotRenderTarget

        if (!this.accumulate) {
            this.accumulateIndex = 0
        }

        const autoClear = renderer.autoClear
        renderer.autoClear = false
        try {
            this.renderJitteredScene(renderer, camera, writeBuffer, readBuffer.width, readBuffer.height)
            this.accumulateSample(renderer, writeBuffer, sampleTarget, snapshotTarget)
            ++this.accumulateIndex
            this.resolveToOutput(renderer, writeBuffer, sampleTarget, snapshotTarget)
        } finally {
            renderer.autoClear = autoClear
        }
    }

    /** Creates a nearest-filtered RGBA float render target with the given texture name. */
    private static createFloatTarget(width: number, height: number, name: string): THREE.WebGLRenderTarget {
        const target = new THREE.WebGLRenderTarget(width, height, {
            minFilter: THREE.NearestFilter,
            magFilter: THREE.NearestFilter,
            format: THREE.RGBAFormat,
            type: DEFAULT_FLOAT_TEXTURE_TYPE,
        })
        target.texture.name = name
        return target
    }

    /** Returns the view-offset jitter for the current sample: table first, then random. */
    private nextJitterOffset(): [number, number] {
        const jitterOffsets = TAARenderPass.JitterVectors
        if (this.accumulateIndex < jitterOffsets.length / 2) {
            const idx = this.accumulateIndex * 2
            return [jitterOffsets[idx], jitterOffsets[idx + 1]]
        }
        return [Math.random() - 0.5, Math.random() - 0.5]
    }

    /**
     * Renders the scene into `writeBuffer` with the jittered (and optionally
     * lens-displaced) camera. The camera is restored even if rendering throws.
     */
    private renderJitteredScene(
        renderer: THREE.WebGLRenderer,
        camera: THREE.PerspectiveCamera,
        writeBuffer: THREE.WebGLRenderTarget,
        width: number,
        height: number,
    ) {
        const [offsetX, offsetY] = this.nextJitterOffset()
        const oldView = saveAndAdjustCameraView(camera, width, height, offsetX, offsetY)
        let oldCam: ReturnType<typeof saveAndMultiplyCameraMatrix> | undefined
        try {
            if (this.depthOfField) {
                // The first sample always goes through the lens centre.
                const [ax, ay] = this.accumulateIndex > 0 ? sampleAperture(this.depthOfField.apertureSize, this.accumulateIndex) : [0, 0]
                const apertureM = lensApertureOffsetMatrix(ax, ay, this.depthOfField.focusDistance)
                oldCam = saveAndMultiplyCameraMatrix(camera, apertureM)
            }

            renderer.setRenderTarget(writeBuffer)
            renderer.setClearColor(0x000000, 0.0)
            renderer.clear()
            renderer.render(this.scene, camera)
        } finally {
            restoreCameraView(camera, oldView)
            if (oldCam) {
                restoreCameraMatrix(camera, oldCam)
            }
        }
    }

    /**
     * Adds the image in `writeBuffer` to the accumulation target. For the first sample
     * the accumulation is cleared beforehand and the image is also stored as snapshot.
     */
    private accumulateSample(
        renderer: THREE.WebGLRenderer,
        writeBuffer: THREE.WebGLRenderTarget,
        sampleTarget: THREE.WebGLRenderTarget,
        snapshotTarget: THREE.WebGLRenderTarget,
    ) {
        const isFirstSample = this.accumulateIndex === 0

        renderer.setRenderTarget(sampleTarget)
        if (isFirstSample) {
            renderer.setClearColor(0x000000, 0.0)
            renderer.clear()
        }
        this.copyUniforms1["tImage"].value = writeBuffer.texture
        this.fsQuad1.render(renderer)

        if (isFirstSample) {
            renderer.setRenderTarget(snapshotTarget)
            renderer.setClearColor(0x000000, 0.0)
            renderer.clear()
            this.copyUniforms1["tImage"].value = writeBuffer.texture
            this.fsQuad1.render(renderer)
        }
    }

    /**
     * Blend weight between snapshot (0) and accumulated average (1) for the current
     * sample count. A non-positive fade span degenerates to a hard switch at the fade
     * start instead of producing NaN from a division by zero.
     */
    private computeMixFactor(): number {
        if (!this.fadeInTiming) {
            return 1.0
        }
        const [beginFade, endFade] = this.fadeInTiming
        const fadeStart = this.maxSamples * beginFade
        const fadeSpan = this.maxSamples * (endFade - beginFade)
        if (!(fadeSpan > 0)) {
            return this.accumulateIndex >= fadeStart ? 1.0 : 0.0
        }
        return Math.min(Math.max(0.0, (this.accumulateIndex - fadeStart) / fadeSpan), 1.0)
    }

    /** Writes the averaged, exposure-scaled and tone-mapped image to the pass output. */
    private resolveToOutput(
        renderer: THREE.WebGLRenderer,
        writeBuffer: THREE.WebGLRenderTarget,
        sampleTarget: THREE.WebGLRenderTarget,
        snapshotTarget: THREE.WebGLRenderTarget,
    ) {
        const uniforms = this.copyUniforms2
        uniforms["accumScale"].value = 1.0 / this.accumulateIndex
        uniforms["mixFactor"].value = this.computeMixFactor()
        uniforms["tAccum"].value = sampleTarget.texture
        uniforms["tSnapshot"].value = snapshotTarget.texture
        uniforms["tToneMapLUT"].value = this.toneMapLUT
        uniforms["toneMapLUTSize"].value = this.toneMapLUT.size
        uniforms["toneMapLUTRangeScale"].value = 1.0 / this.toneMapLUT.range
        uniforms["toneMapExposure"].value = this.exposure

        renderer.setRenderTarget(this.renderToScreen ? null : writeBuffer)
        renderer.setClearColor(0x000000, 0.0)
        renderer.clear()
        this.fsQuad2.render(renderer)
    }
}

/**
 * A cubic colour lookup table stored in a 2D RGBA data texture that is `size` texels
 * wide and `size * size` texels tall, with linear filtering and clamp-to-edge wrapping.
 *
 * The backing storage is half-float or float depending on the texture data type passed
 * to the constructor.
 */
export class CubeLUTTexture extends THREE.DataTexture {
    private data: Float16ArrayBuilder | Float32ArrayBuilder
    /** Number of entries along each axis of the cube. */
    readonly size: number
    /** Range value stored as given; also forwarded to `buildLUTEntries`. */
    readonly range: number

    /**
     * @param size Edge length of the LUT cube.
     * @param range Range value stored on the texture.
     * @param type Texture data type; `THREE.HalfFloatType` selects 16-bit storage,
     *             anything else 32-bit float storage.
     */
    constructor(size: number, range: number, type: THREE.TextureDataType) {
        // Four channels (RGBA) per cube entry.
        const channelCount = size * size * size * 4
        const storage = type === THREE.HalfFloatType ? new Float16ArrayBuilder(channelCount) : new Float32ArrayBuilder(channelCount)
        super(storage.array, size, size * size, THREE.RGBAFormat, type)
        this.data = storage
        this.size = size
        this.range = range
        this.wrapS = THREE.ClampToEdgeWrapping
        this.wrapT = THREE.ClampToEdgeWrapping
        this.minFilter = THREE.LinearFilter
        this.magFilter = THREE.LinearFilter
    }

    /**
     * Copies tightly packed RGB triples into the RGBA storage, writing 1.0 as alpha for
     * every entry, and flags the texture for re-upload.
     */
    updateWithArray(array: Float32Array) {
        const entryCount = array.length / 3
        const storage = this.data
        for (let entry = 0; entry < entryCount; entry++) {
            const src = entry * 3
            const dst = entry * 4
            storage.set(dst + 0, array[src + 0])
            storage.set(dst + 1, array[src + 1])
            storage.set(dst + 2, array[src + 2])
            storage.set(dst + 3, 1.0)
        }
        this.needsUpdate = true
    }

    /**
     * Fills the table with entries produced by `buildLUTEntries` for `fn`, using this
     * texture's `size` and `range`.
     */
    updateWithFunction(fn: ToneMappingFunction) {
        // NOTE(review): the meaning of the two trailing `true` flags is defined by buildLUTEntries — confirm there.
        const entries = buildLUTEntries(this.size, this.range, fn, true, true)
        this.updateWithArray(entries)
    }
}

//TODO: move to utils:

/**
 * Fixed-length buffer of IEEE 754 half-precision values, stored as raw 16-bit patterns
 * in a `Uint16Array`. `set()` converts a JS number (via float32) to its half encoding.
 */
class Float16ArrayBuilder {
    // Scratch float32 and an integer view onto the same bytes, used to read the bit pattern.
    private tmpFloat32 = new Float32Array(1)
    private tmpUint32View = new Uint32Array(this.tmpFloat32.buffer)
    /** Raw half-float bit patterns. */
    readonly array: Uint16Array

    /** @param length Number of half-float elements to allocate. */
    constructor(length: number) {
        this.array = new Uint16Array(length)
    }

    /**
     * Stores `value` at `idx` as a half-precision float.
     *
     * Values too small for a half denormal become signed zero, values too large become
     * signed infinity, and NaN stays NaN.
     */
    set(idx: number, value: number): void {
        this.tmpFloat32[0] = value
        const x = this.tmpUint32View[0]
        let y = (x >> 16) & 0x8000 // Get the sign
        let m = (x >> 12) & 0x07ff // Keep one extra bit for rounding
        const e = (x >> 23) & 0xff // Using int is faster here
        if (e < 103) {
            // If zero, or denormal, or exponent underflows too much for a denormal half, return signed zero.
        } else if (e > 142) {
            // Inf, NaN, or exponent overflow: all-ones half exponent.
            y |= 0x7c00
            // A float32 NaN (exponent 0xff with a non-zero mantissa) must keep a mantissa bit,
            // otherwise it would turn into Inf. Finite overflow and Inf keep a zero mantissa;
            // no float32 mantissa bits may leak into the result here.
            if (e === 255 && (x & 0x007fffff) !== 0) {
                y |= 0x0200
            }
        } else if (e < 113) {
            // If exponent underflows but not too much, return a denormal
            m |= 0x0800
            // Extra rounding may overflow and set mantissa to 0 and exponent to 1, which is OK.
            y |= (m >> (114 - e)) + ((m >> (113 - e)) & 1)
        } else {
            y |= ((e - 112) << 10) | (m >> 1)
            // Extra rounding. An overflow will set mantissa to 0 and increment the exponent, which is OK.
            y += m & 1
        }
        this.array[idx] = y
    }
}

/**
 * Float32 counterpart of the half-float builder: a fixed-length `Float32Array` with the
 * same `set(idx, value)` interface, so both can be used interchangeably.
 */
class Float32ArrayBuilder {
    /** Backing storage. */
    readonly array: Float32Array

    /** @param length Number of float elements to allocate. */
    constructor(length: number) {
        const storage = new Float32Array(length)
        this.array = storage
    }

    /** Stores `value` at `idx`. */
    set(idx: number, value: number): void {
        const {array} = this
        array[idx] = value
    }
}
