import {Pass, FullScreenQuad} from "three/examples/jsm/postprocessing/Pass"
import * as THREE from "three"
import {getJitterVector} from "@app/template-editor/helpers/jitter"
import {getApertureVector} from "./aperture"
import {DEFAULT_FLOAT_TEXTURE_TYPE} from "./three-utils"
import {renderHelperObjects} from "@app/template-editor/helpers/helper-objects"

/**
 * Replaces the camera's view offset with a full-size view shifted by the given offsets
 * and refreshes the projection matrix.
 *
 * If the camera already had an enabled view, its offsets are rescaled from its own
 * full size to `fullWidth`/`fullHeight` and added on top of `offsetX`/`offsetY`.
 *
 * @param camera Camera whose `view` is overwritten.
 * @param fullWidth Width used for both `fullWidth` and `width` of the new view.
 * @param fullHeight Height used for both `fullHeight` and `height` of the new view.
 * @param offsetX Horizontal offset of the new view.
 * @param offsetY Vertical offset of the new view.
 * @returns The previously enabled view object, or `null` if there was none (or it was disabled).
 */
function saveAndAdjustCameraView(camera: THREE.PerspectiveCamera, fullWidth: number, fullHeight: number, offsetX: number, offsetY: number) {
    const previousView = camera.view?.enabled ? camera.view : null

    // Offsets inherited from an already active view, rescaled to the new full size.
    let inheritedX = 0
    let inheritedY = 0
    if (previousView) {
        inheritedX = previousView.offsetX * (fullWidth / previousView.fullWidth)
        inheritedY = previousView.offsetY * (fullHeight / previousView.fullHeight)
    }

    // Assign the view object directly instead of going through camera.setViewOffset(...),
    // so the previous view object is left untouched and can be handed back to the caller.
    camera.view = {
        enabled: true,
        fullWidth,
        fullHeight,
        offsetX: offsetX + inheritedX,
        offsetY: offsetY + inheritedY,
        width: fullWidth,
        height: fullHeight,
    }
    camera.updateProjectionMatrix()

    return previousView
}

/**
 * Puts back a view previously returned by `saveAndAdjustCameraView`.
 *
 * @param camera Camera to restore.
 * @param view The saved view object; when falsy the camera's view offset is cleared instead.
 */
function restoreCameraView(camera: THREE.PerspectiveCamera, view: THREE.PerspectiveCamera["view"] | null) {
    if (!view) {
        // Nothing was active before: drop the temporary view offset.
        camera.clearViewOffset()
        return
    }

    camera.view = view
    camera.updateProjectionMatrix()
}

/**
 * Builds the transform for one aperture sample: a rotation that aims from the lens
 * offset `(ax, ay, 0)` at the point `(0, 0, -focalDistance)`, preceded by a translation
 * to that lens offset.
 *
 * @param ax Lens offset along X.
 * @param ay Lens offset along Y.
 * @param focalDistance Distance along -Z of the point the offset lens is aimed at.
 * @returns Translation matrix multiplied by the look-at rotation matrix.
 */
function lensApertureOffsetMatrix(ax: number, ay: number, focalDistance: number): THREE.Matrix4 {
    const lensPosition = new THREE.Vector3(ax, ay, 0)
    const focusPoint = new THREE.Vector3(0, 0, -focalDistance)
    const upAxis = new THREE.Vector3(0, 1, 0)

    const aimAtFocus = new THREE.Matrix4().lookAt(lensPosition, focusPoint, upAxis)
    const moveToLens = new THREE.Matrix4().makeTranslation(lensPosition.x, lensPosition.y, 0)

    return moveToLens.multiply(aimAtFocus)
}

/** Transform snapshot produced by `saveAndMultiplyCameraMatrix` and consumed by `restoreCameraMatrix`. */
type SavedCameraTransform = {
    oldP: THREE.Vector3
    oldQ: THREE.Quaternion
    oldS: THREE.Vector3
}

/**
 * Right-multiplies the camera's world matrix by `m` and writes the decomposed result
 * back into the camera's position, quaternion and scale.
 *
 * NOTE(review): the world matrix is decomposed straight into the camera's own
 * position/quaternion/scale, which presumably assumes the camera has no transformed
 * parent — confirm against callers.
 *
 * @param camera Camera to modify in place.
 * @param m Matrix applied on the right of the camera's world matrix.
 * @returns Copies of position, quaternion and scale taken before the modification.
 */
function saveAndMultiplyCameraMatrix(camera: THREE.Camera, m: THREE.Matrix4): SavedCameraTransform {
    camera.updateMatrixWorld()
    const saved: SavedCameraTransform = {
        oldP: camera.position.clone(),
        oldQ: camera.quaternion.clone(),
        // decompose() below overwrites the scale as well, so it must be part of the snapshot.
        oldS: camera.scale.clone(),
    }
    camera.matrixWorld.multiply(m)
    camera.matrixWorld.decompose(camera.position, camera.quaternion, camera.scale)
    return saved
}

/**
 * Restores the transform captured by `saveAndMultiplyCameraMatrix` and refreshes the
 * camera's world matrix.
 *
 * @param camera Camera to restore.
 * @param prev Snapshot returned by `saveAndMultiplyCameraMatrix`.
 */
function restoreCameraMatrix(camera: THREE.Camera, prev: SavedCameraTransform) {
    camera.position.copy(prev.oldP)
    camera.quaternion.copy(prev.oldQ)
    camera.scale.copy(prev.oldS)
    camera.updateMatrixWorld()
}

/** Depth-of-field settings consumed by `SceneViewRenderPass` while accumulating samples. */
export type DepthOfField = {
    // Factor multiplied with the per-iteration aperture sample vector to obtain the lens offset.
    // NOTE(review): presumably expressed in scene units — confirm.
    apertureSize: number
    // Distance along the camera's -Z axis of the point each offset lens position is aimed at.
    focalDistance: number
}
/**
 * Post-processing pass that renders the scene into an internal float render target and
 * copies the result to the pass output.
 *
 * With `taaMode` off every call renders one plain frame. With `taaMode` on each call
 * renders one sample with a jittered view offset (and, when depth of field is set, an
 * aperture-offset camera) and folds it into a running average kept across calls.
 */
export class SceneViewRenderPass extends Pass {
    private depthOfField: DepthOfField | undefined

    /**
     * Shader used for both accumulation and copying. The `iteration` uniform selects the mode:
     * below -1 copies `accumulatedImage` with rgb multiplied by alpha, -1 copies it unchanged,
     * 0 writes the dithered `currentImage`, and values >= 1 blend `currentImage` into the
     * running average stored in `accumulatedImage`.
     */
    static TAAShader = {
        uniforms: {
            currentImage: {value: null as THREE.Texture | null},
            accumulatedImage: {value: null as THREE.Texture | null},
            avgImage: {value: null as THREE.Texture | null},
            iteration: {value: 0},
        },
        vertexShader: /* glsl */ `
            varying vec2 vUv;
            void main() {
                vUv = uv;
                gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
            }`,
        fragmentShader: /* glsl */ `
            #include <common>
            #define DITHERING
            #include <dithering_pars_fragment>
            uniform sampler2D currentImage;
            uniform sampler2D accumulatedImage;
            uniform int iteration;
            varying vec2 vUv;
            void main() {
                if(iteration < -1) {
                    gl_FragColor = texture2D(accumulatedImage, vUv);
                    gl_FragColor.rgb *= gl_FragColor.a;
                    return;
                } else if(iteration < 0) {
                    gl_FragColor = texture2D(accumulatedImage, vUv);
                    return;
                }

                vec4 currentColor = texture2D(currentImage, vUv);
                currentColor.rgb = dithering(currentColor.rgb);

                if(iteration < 1) {
                    gl_FragColor = currentColor;
                    return;
                } else {
                    vec4 previousColor = texture2D(accumulatedImage, vUv) * float(iteration);
                    gl_FragColor = (previousColor + currentColor) / (float(iteration) + 1.0);
                }
            }`,
    }

    private taaUniforms: (typeof SceneViewRenderPass.TAAShader)["uniforms"]
    private taaQuad: FullScreenQuad

    // Scratch target the blended result is written to before being copied back.
    private sampleRenderTargetTemp: THREE.WebGLRenderTarget | null = null
    // Holds the plain frame (non-TAA) or the running average (TAA); owns the depth texture.
    private sampleRenderTarget: THREE.WebGLRenderTarget | null = null
    /** When true, each render call accumulates one more sample instead of rendering a plain frame. */
    taaMode = false
    // Number of samples already folded into the running average.
    private iteration = 0

    /** When true, helper objects are drawn on top of the scene after each scene render. */
    renderHelperObjects = true

    /**
     * @param scene Scene to render.
     * @param camera Camera to render with; temporarily modified during accumulation.
     * @param threeDepthTexture Depth texture attached to the internal render target.
     */
    constructor(
        private scene: THREE.Scene,
        private camera: THREE.PerspectiveCamera,
        private threeDepthTexture: THREE.DepthTexture,
    ) {
        super()

        this.taaUniforms = THREE.UniformsUtils.clone(SceneViewRenderPass.TAAShader.uniforms)
        const material = new THREE.ShaderMaterial({
            uniforms: this.taaUniforms,
            vertexShader: SceneViewRenderPass.TAAShader.vertexShader,
            fragmentShader: SceneViewRenderPass.TAAShader.fragmentShader,
            depthTest: false,
            depthWrite: false,
        })
        this.taaQuad = new FullScreenQuad(material)
    }

    /** Releases the internal render targets and the full-screen quad with its material. */
    override dispose() {
        if (this.sampleRenderTargetTemp) {
            this.sampleRenderTargetTemp.dispose()
            this.sampleRenderTargetTemp = null
        }
        if (this.sampleRenderTarget) {
            this.sampleRenderTarget.dispose()
            this.sampleRenderTarget = null
        }

        this.taaQuad.material.dispose()
        this.taaQuad.dispose()
    }

    /**
     * Resizes the internal render targets that already exist.
     *
     * When a target's size actually changes, accumulation restarts from the first sample so
     * that later samples are not blended with the contents of a resized buffer.
     */
    override setSize(width: number, height: number) {
        for (const target of [this.sampleRenderTargetTemp, this.sampleRenderTarget]) {
            if (!target) continue
            if (target.width !== width || target.height !== height) this.iteration = 0
            target.setSize(width, height)
        }
    }

    /**
     * Renders one frame (or one accumulation sample in `taaMode`) and writes the result to the
     * screen when `renderToScreen` is set, otherwise to `writeBuffer`.
     *
     * The renderer's active render target is restored before returning, also when rendering throws.
     *
     * @param renderer Renderer used for all draws.
     * @param writeBuffer Target receiving the output; also used as scratch for the current sample.
     * @param readBuffer Only its dimensions are used (target allocation and view offset size).
     * @param deltaTime Unused.
     * @param maskActive Unused.
     */
    override render(
        renderer: THREE.WebGLRenderer,
        writeBuffer: THREE.WebGLRenderTarget,
        readBuffer: THREE.WebGLRenderTarget,
        deltaTime: number,
        maskActive: boolean,
    ) {
        const accumulation = this.ensureAccumulationTarget(readBuffer)
        const oldTarget = renderer.getRenderTarget()

        try {
            if (this.taaMode) {
                this.renderAccumulatedSample(renderer, writeBuffer, readBuffer, accumulation)
            } else {
                this.renderSingleFrame(renderer, writeBuffer, accumulation)
            }
        } finally {
            renderer.setRenderTarget(oldTarget)
        }
    }

    /**
     * Sets or clears the depth-of-field settings used while accumulating.
     *
     * Accumulation restarts when the aperture size or focal distance differs from the current
     * settings, so samples taken with different lens settings are not averaged together.
     */
    setDepthOfField(depthOfField: DepthOfField | undefined) {
        const current = this.depthOfField
        const changed = current?.apertureSize !== depthOfField?.apertureSize || current?.focalDistance !== depthOfField?.focalDistance
        if (changed) this.iteration = 0
        this.depthOfField = depthOfField
    }

    /** Lazily creates the target that holds the frame / running average, sized like `readBuffer`. */
    private ensureAccumulationTarget(readBuffer: THREE.WebGLRenderTarget): THREE.WebGLRenderTarget {
        if (!this.sampleRenderTarget) {
            this.sampleRenderTarget = new THREE.WebGLRenderTarget(readBuffer.width, readBuffer.height, {
                minFilter: THREE.NearestFilter,
                magFilter: THREE.NearestFilter,
                format: THREE.RGBAFormat,
                type: DEFAULT_FLOAT_TEXTURE_TYPE,
                depthTexture: this.threeDepthTexture,
            })
            this.sampleRenderTarget.texture.name = "TAARenderPass.snapshot"
        }
        return this.sampleRenderTarget
    }

    /** Lazily creates the scratch target the blend step writes to, sized like `readBuffer`. */
    private ensureBlendTarget(readBuffer: THREE.WebGLRenderTarget): THREE.WebGLRenderTarget {
        if (!this.sampleRenderTargetTemp) {
            this.sampleRenderTargetTemp = new THREE.WebGLRenderTarget(readBuffer.width, readBuffer.height, {
                minFilter: THREE.NearestFilter,
                magFilter: THREE.NearestFilter,
                format: THREE.RGBAFormat,
                type: DEFAULT_FLOAT_TEXTURE_TYPE,
            })
            this.sampleRenderTargetTemp.texture.name = "TAARenderPass.sample"
        }
        return this.sampleRenderTargetTemp
    }

    /** Draws the scene, plus helper objects when enabled, into the currently bound target. */
    private renderScene(renderer: THREE.WebGLRenderer) {
        renderer.render(this.scene, this.camera)
        if (this.renderHelperObjects) renderHelperObjects(renderer, this.scene, this.camera)
    }

    /** Copies the accumulation target to the pass output, multiplying rgb by alpha. */
    private presentAccumulation(renderer: THREE.WebGLRenderer, writeBuffer: THREE.WebGLRenderTarget, accumulation: THREE.WebGLRenderTarget) {
        renderer.setRenderTarget(this.renderToScreen ? null : writeBuffer)
        this.taaUniforms.iteration.value = -2
        this.taaUniforms.currentImage.value = null
        this.taaUniforms.accumulatedImage.value = accumulation.texture
        this.taaQuad.render(renderer)
    }

    /** Non-TAA path: resets accumulation, renders the scene once and presents it. */
    private renderSingleFrame(renderer: THREE.WebGLRenderer, writeBuffer: THREE.WebGLRenderTarget, accumulation: THREE.WebGLRenderTarget) {
        this.iteration = 0

        renderer.setRenderTarget(accumulation)
        this.renderScene(renderer)

        this.presentAccumulation(renderer, writeBuffer, accumulation)
    }

    /** Offsets the camera for the current aperture sample; returns the saved transform, if any. */
    private applyApertureOffset(): ReturnType<typeof saveAndMultiplyCameraMatrix> | undefined {
        if (!this.depthOfField) return undefined
        const {apertureSize, focalDistance} = this.depthOfField
        const [apertureX, apertureY] = getApertureVector(this.iteration)
        const apertureM = lensApertureOffsetMatrix(apertureX * apertureSize, apertureY * apertureSize, focalDistance)
        return saveAndMultiplyCameraMatrix(this.camera, apertureM)
    }

    /** TAA path: renders one jittered sample, folds it into the running average and presents it. */
    private renderAccumulatedSample(
        renderer: THREE.WebGLRenderer,
        writeBuffer: THREE.WebGLRenderTarget,
        readBuffer: THREE.WebGLRenderTarget,
        accumulation: THREE.WebGLRenderTarget,
    ) {
        const blendTarget = this.ensureBlendTarget(readBuffer)
        const [offsetX, offsetY] = getJitterVector(this.iteration)

        // Step 1: render the current sample into the write buffer with a temporarily modified camera.
        const oldCameraView = saveAndAdjustCameraView(this.camera, readBuffer.width, readBuffer.height, offsetX, offsetY)
        let oldCameraMatrix: ReturnType<typeof saveAndMultiplyCameraMatrix> | undefined
        try {
            oldCameraMatrix = this.applyApertureOffset()
            renderer.setRenderTarget(writeBuffer)
            this.renderScene(renderer)
        } finally {
            // Undo in reverse order so the camera is never left jittered or offset.
            if (oldCameraMatrix) restoreCameraMatrix(this.camera, oldCameraMatrix)
            restoreCameraView(this.camera, oldCameraView)
        }

        // Step 2: blend the sample with the running average into the scratch target, then copy
        // the blended result back. Auto-clear is disabled while drawing into these targets.
        const oldAutoClear = renderer.autoClear
        renderer.autoClear = false
        try {
            renderer.setRenderTarget(blendTarget)
            this.taaUniforms.currentImage.value = writeBuffer.texture
            this.taaUniforms.accumulatedImage.value = accumulation.texture
            this.taaUniforms.iteration.value = this.iteration
            this.taaQuad.render(renderer)

            renderer.setRenderTarget(accumulation)
            this.taaUniforms.iteration.value = -1
            this.taaUniforms.currentImage.value = null
            this.taaUniforms.accumulatedImage.value = blendTarget.texture
            this.taaQuad.render(renderer)
        } finally {
            renderer.autoClear = oldAutoClear
        }

        // Step 3: copy the running average to the output and apply premultiplied alpha.
        this.presentAccumulation(renderer, writeBuffer, accumulation)

        this.iteration++
    }
}
