import {Pass, FullScreenQuad} from "three/examples/jsm/postprocessing/Pass"
import * as THREE from "three"

/**
 * Raw pixel data that {@link ImageDataRenderPass} uploads as a `THREE.DataTexture` and draws
 * over the previous pass.
 */
export type ThreeImageData = {
    /** Image width in pixels. */
    width: number
    /** Image height in pixels. */
    height: number
    /** Pixel storage handed to `THREE.DataTexture`; the pass draws its first row at the top of the image. */
    buffer: BufferSource
    /** Channel layout of `buffer` (e.g. RGBA). */
    format: THREE.PixelFormat
    /** Component type of `buffer` (e.g. unsigned byte). */
    type: THREE.TextureDataType
    /** True when the color channels in `buffer` are already multiplied by alpha. */
    premultipliedAlpha: boolean
}

/**
 * Post-processing pass that composites a CPU-side image (a raw pixel buffer) over the output of
 * the previous pass.
 *
 * Compositing is premultiplied-alpha "over": `out = img + bg * (1 - img.a)`. Image texels that
 * fall outside the destination rectangle are masked to transparent, so the background shows
 * through there unchanged.
 */
export class ImageDataRenderPass extends Pass {
    static Shader = {
        uniforms: {
            tBackground: {value: null as THREE.Texture | null},
            tImage: {value: null as THREE.Texture | null},
            uvTransform: {value: new THREE.Matrix3()},
        },
        vertexShader: /* glsl */ `
            varying vec2 vBgUv;
            varying vec2 vImageUv;
            uniform mat3 uvTransform;
            void main() {
                vBgUv = uv;
                vImageUv = (uvTransform * vec3(uv, 1.0)).xy;
                gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
            }`,
        fragmentShader: /* glsl */ `
            uniform sampler2D tBackground;
            uniform sampler2D tImage;
            varying vec2 vBgUv;
            varying vec2 vImageUv;
            void main() {
                vec4 bg = texture2D(tBackground, vBgUv);
                vec4 img = texture2D(tImage, vImageUv);
                img *= step(0.0, vImageUv.x) * step(vImageUv.x, 1.0) * step(0.0, vImageUv.y) * step(vImageUv.y, 1.0); // mask out the image outside of the UV bounds
                gl_FragColor = img + bg * (1.0 - img.a);
            }`,
    }

    private uniforms: (typeof ImageDataRenderPass.Shader)["uniforms"]
    private quad: FullScreenQuad

    /** Texture currently composited; a transparent 1x1 placeholder while no image is set. */
    private imageTexture: THREE.DataTexture = ImageDataRenderPass.createEmptyTexture()
    /** Where the image is drawn, in canvas pixels: [x, y, width, height] measured from the top-left corner. */
    private destinationRect: [number, number, number, number] = [0, 0, 1, 1]
    /** Canvas size in pixels: [width, height]. */
    private canvasSize: [number, number] = [1, 1]

    constructor() {
        super()

        // Each pass gets its own uniform objects so instances do not share state through `Shader`.
        this.uniforms = THREE.UniformsUtils.clone(ImageDataRenderPass.Shader.uniforms)
        const material = new THREE.ShaderMaterial({
            uniforms: this.uniforms,
            vertexShader: ImageDataRenderPass.Shader.vertexShader,
            fragmentShader: ImageDataRenderPass.Shader.fragmentShader,
            depthTest: false,
            depthWrite: false,
        })

        this.quad = new FullScreenQuad(material)

        this.updateUVTransform()
    }

    /** Creates a 1x1 fully transparent RGBA texture, used while no image is set. */
    private static createEmptyTexture(): THREE.DataTexture {
        const texture = new THREE.DataTexture(new Uint8Array(4), 1, 1, THREE.RGBAFormat, THREE.UnsignedByteType)
        // Mark for upload so the zeroed texel is actually sent to the GPU on first use.
        texture.needsUpdate = true
        return texture
    }

    /**
     * Replaces the image composited over the background. The previous texture is disposed.
     *
     * @param data Pixel data to show, or `null` to clear the image (the background is then
     *   passed through unchanged). The first row of `data.buffer` is drawn at the top of the
     *   destination rectangle. When `data.premultipliedAlpha` is false the colors are
     *   premultiplied on upload, because the compositing shader expects premultiplied alpha.
     */
    setImageData(data: ThreeImageData | null) {
        this.imageTexture.dispose()
        if (data == null) {
            this.imageTexture = ImageDataRenderPass.createEmptyTexture()
            return
        }
        const texture = new THREE.DataTexture(data.buffer, data.width, data.height, data.format, data.type)
        texture.minFilter = THREE.LinearFilter
        texture.magFilter = THREE.LinearFilter
        texture.wrapS = THREE.ClampToEdgeWrapping
        texture.wrapT = THREE.ClampToEdgeWrapping
        texture.premultiplyAlpha = !data.premultipliedAlpha // convert to premultiplied alpha for the shader
        texture.needsUpdate = true
        this.imageTexture = texture
    }

    /**
     * Recomputes the matrix mapping screen UVs to image UVs.
     *
     * Screen UVs have their origin at the bottom-left of the canvas, while `destinationRect` and
     * the image rows are measured from the top-left. For a canvas (W, H) and destination rect
     * (x, y, w, h), all in pixels:
     *
     *   u' = (u * W - x) / w
     *   v' = ((1 - v) * H - y) / h
     *
     * so (u', v') spans [0, 1] exactly inside the destination rectangle. The fragment shader masks
     * everything outside that range.
     */
    private updateUVTransform() {
        const [canvasW, canvasH] = this.canvasSize
        const [destX, destY, destW, destH] = this.destinationRect
        const scaleX = canvasW / destW
        const scaleY = canvasH / destH
        // Matrix3.set takes its arguments in row-major order.
        this.uniforms.uvTransform.value.set(
            scaleX, 0, -destX / destW,
            0, -scaleY, scaleY - destY / destH,
            0, 0, 1,
        )
    }

    /** Releases the resources owned by this pass: the image texture, the shader material and the full-screen quad. */
    override dispose() {
        this.imageTexture.dispose()
        this.quad.material.dispose()
        this.quad.dispose()
    }

    /** Records the new canvas size and resets the destination rectangle to cover the whole canvas. */
    override setSize(width: number, height: number) {
        this.canvasSize = [width, height]
        this.destinationRect = [0, 0, width, height]
        this.updateUVTransform()
    }

    /**
     * Draws the image over `readBuffer` into `writeBuffer`, or to the screen when `renderToScreen`
     * is set. The renderer's previous render target is restored afterwards.
     */
    override render(
        renderer: THREE.WebGLRenderer,
        writeBuffer: THREE.WebGLRenderTarget,
        readBuffer: THREE.WebGLRenderTarget,
        deltaTime: number,
        maskActive: boolean,
    ) {
        const oldTarget = renderer.getRenderTarget()

        renderer.setRenderTarget(this.renderToScreen ? null : writeBuffer)
        this.uniforms.tBackground.value = readBuffer.texture
        this.uniforms.tImage.value = this.imageTexture
        this.quad.render(renderer)

        renderer.setRenderTarget(oldTarget)
    }
}
