import * as THREE from "three"
import potpack, {PotpackBox} from "potpack"
import {updateTransform} from "@template-editor/helpers/three-object"
import {DEFAULT_FLOAT_TEXTURE_TYPE} from "@template-editor/helpers/three-utils"
import {Subject} from "rxjs"
import {getJitterVector} from "./jitter"

/** Axis-aligned bounding box of a set of UV coordinates. */
type UVBounds = {
    minU: number
    minV: number
    maxU: number
    maxV: number
}

/**
 * Computes the UV bounding box of `uv`.
 *
 * Coordinates are read through `getX`/`getY`, so normalized (quantized) attributes are measured in the
 * same de-normalized space as `THREE.Vector2.fromBufferAttribute`.
 *
 * @param uv UV attribute with an item size of 2.
 * @returns The bounds; all zeros for an empty attribute.
 * @throws Error when the attribute's item size is not 2.
 */
const getMinMaxUVs = (uv: THREE.BufferAttribute): UVBounds => {
    if (uv.itemSize !== 2) throw new Error("UV item size must be 2")
    if (uv.count === 0) return {minU: 0, minV: 0, maxU: 0, maxV: 0}

    let minU = Infinity
    let minV = Infinity
    let maxU = -Infinity
    let maxV = -Infinity
    for (let i = 0; i < uv.count; i++) {
        const u = uv.getX(i)
        const v = uv.getY(i)
        minU = Math.min(minU, u)
        minV = Math.min(minV, v)
        maxU = Math.max(maxU, u)
        maxV = Math.max(maxV, v)
    }
    return {minU, minV, maxU, maxV}
}

/**
 * Multiplies the first two components (u, v) of every item in `uv` by `scale`, in place.
 * The attribute's `needsUpdate` flag is not touched.
 */
const scaleUVs = (uv: THREE.BufferAttribute, scale: number) => {
    const {array, itemSize} = uv
    for (let offset = 0; offset < array.length; offset += itemSize) {
        array[offset] = array[offset] * scale
        array[offset + 1] = array[offset + 1] * scale
    }
}

/**
 * Makes the UV winding of a mesh match the face side it will be rasterized with.
 *
 * The dominant UV winding is taken from the sign of the summed signed triangle areas, i.e. area-weighted
 * over all faces. Deciding from a single face would let one degenerate or mirrored face flip the verdict
 * for the whole mesh. A negative sum means clockwise (back-facing). If the winding does not match `side`,
 * the U coordinate is negated in place, which mirrors the layout and reverses every triangle's winding.
 *
 * @param uv UV attribute to fix up (modified in place).
 * @param indexBuffer Triangle index buffer.
 * @param side Side the mesh is rendered with; `FrontSide` wants counter-clockwise UVs, `BackSide` clockwise.
 * @throws Error when the index is missing, empty or not a multiple of 3.
 */
const ensureUVSideness = (uv: THREE.BufferAttribute, indexBuffer: THREE.BufferAttribute | null, side: typeof THREE.FrontSide | typeof THREE.BackSide) => {
    if (!indexBuffer) throw new Error("Index is missing")
    if (indexBuffer.array.length === 0) throw new Error("Index is empty")
    if (indexBuffer.array.length % 3 !== 0) throw new Error("Index is not a multiple of 3")

    const indices = indexBuffer.array
    let doubledSignedArea = 0
    for (let i = 0; i < indices.length; i += 3) {
        const a = indices[i]
        const b = indices[i + 1]
        const c = indices[i + 2]
        const ax = uv.getX(a)
        const ay = uv.getY(a)
        const bx = uv.getX(b)
        const by = uv.getY(b)
        const cx = uv.getX(c)
        const cy = uv.getY(c)
        const det = ax * (by - cy) + bx * (cy - ay) + cx * (ay - by)
        // Faces referencing invalid vertices yield NaN; skip them so they cannot poison the sum.
        if (isFinite(det)) doubledSignedArea += det
    }

    const backFacing = doubledSignedArea < 0

    if (side === THREE.FrontSide && !backFacing) return
    if (side === THREE.BackSide && backFacing) return

    for (let i = 0; i < uv.array.length; i += uv.itemSize) uv.array[i] *= -1
}

/**
 * Total (unsigned) area of all indexed triangles in UV space.
 *
 * Each face's area is half the absolute 2D cross product of two edges. This stays accurate for sliver
 * triangles, where Heron's formula suffers cancellation and can even go NaN. Faces whose area is zero
 * or non-finite (e.g. referencing out-of-range vertices) contribute nothing.
 *
 * @param uv UV attribute (read through `getX`/`getY`).
 * @param indexBuffer Triangle index buffer.
 * @returns Sum of the triangle areas.
 * @throws Error when the index is missing, empty or not a multiple of 3.
 */
const getUVSurfaceArea = (uv: THREE.BufferAttribute, indexBuffer: THREE.BufferAttribute | null) => {
    if (!indexBuffer) throw new Error("Index is missing")
    if (indexBuffer.array.length === 0) throw new Error("Index is empty")
    if (indexBuffer.array.length % 3 !== 0) throw new Error("Index is not a multiple of 3")

    const indices = indexBuffer.array
    let area = 0
    for (let i = 0; i < indices.length; i += 3) {
        const a = indices[i]
        const b = indices[i + 1]
        const c = indices[i + 2]
        const ax = uv.getX(a)
        const ay = uv.getY(a)
        const e1x = uv.getX(b) - ax
        const e1y = uv.getY(b) - ay
        const e2x = uv.getX(c) - ax
        const e2y = uv.getY(c) - ay
        const faceArea = Math.abs(e1x * e2y - e2x * e1y) / 2
        if (faceArea <= 0 || !isFinite(faceArea)) continue
        area += faceArea
    }

    return area
}

/**
 * Total area of all indexed triangles in object space.
 *
 * Each face's area is half the magnitude of the cross product of two edges. This is numerically robust
 * for thin triangles, unlike Heron's formula. Faces whose area is zero or non-finite contribute nothing.
 *
 * @param positions Vertex positions (read through `getX`/`getY`/`getZ`, so interleaved attributes work).
 * @param indexBuffer Triangle index buffer.
 * @returns Sum of the triangle areas.
 * @throws Error when the index is missing, empty or not a multiple of 3.
 */
const getGeometrySurfaceArea = (positions: THREE.BufferAttribute | THREE.InterleavedBufferAttribute, indexBuffer: THREE.BufferAttribute | null) => {
    if (!indexBuffer) throw new Error("Index is missing")
    if (indexBuffer.array.length === 0) throw new Error("Index is empty")
    if (indexBuffer.array.length % 3 !== 0) throw new Error("Index is not a multiple of 3")

    const indices = indexBuffer.array
    let area = 0
    for (let i = 0; i < indices.length; i += 3) {
        const a = indices[i]
        const b = indices[i + 1]
        const c = indices[i + 2]
        const ox = positions.getX(a)
        const oy = positions.getY(a)
        const oz = positions.getZ(a)
        const e1x = positions.getX(b) - ox
        const e1y = positions.getY(b) - oy
        const e1z = positions.getZ(b) - oz
        const e2x = positions.getX(c) - ox
        const e2y = positions.getY(c) - oy
        const e2z = positions.getZ(c) - oz
        const nx = e1y * e2z - e1z * e2y
        const ny = e1z * e2x - e1x * e2z
        const nz = e1x * e2y - e1y * e2x
        const faceArea = 0.5 * Math.sqrt(nx * nx + ny * ny + nz * nz)
        if (faceArea <= 0 || !isFinite(faceArea)) continue
        area += faceArea
    }

    return area
}

/**
 * Estimates how much of a UV layout overlaps itself by rasterizing every triangle into a coarse grid.
 *
 * The UV bounding box is mapped so that its longer side spans `rasterResolution` cells. A cell counts as
 * covered when its centre lies strictly inside a triangle (all barycentric coordinates > 0). It counts as
 * overlapping when another triangle covers it again.
 *
 * @param uv UV attribute (item size 2).
 * @param indexBuffer Triangle index buffer.
 * @param rasterResolution Number of raster cells along the longer side of the UV bounding box.
 * @returns Overlapping cells divided by covered cells; 0 when nothing is covered or the layout is degenerate.
 * @throws Error when the index is missing, empty or not a multiple of 3.
 */
const uvOverlappingRatio = (uv: THREE.BufferAttribute, indexBuffer: THREE.BufferAttribute | null, rasterResolution: number) => {
    if (!indexBuffer) throw new Error("Index is missing")
    if (indexBuffer.array.length === 0) throw new Error("Index is empty")
    if (indexBuffer.array.length % 3 !== 0) throw new Error("Index is not a multiple of 3")

    const {minU, minV, maxU, maxV} = getMinMaxUVs(uv)

    const uSize = maxU - minU
    const vSize = maxV - minV
    const extent = Math.max(uSize, vSize)
    // A layout collapsed onto a point (or containing non-finite UVs) cannot be rasterized; bail out
    // explicitly instead of dividing by zero and relying on NaN-sized loops to do nothing.
    if (!(extent > 0) || !isFinite(extent)) return 0

    const uvToRaster = (1 / extent) * rasterResolution

    const resX = Math.ceil(uSize * uvToRaster)
    const resY = Math.ceil(vSize * uvToRaster)
    if (resX === 0 || resY === 0) return 0

    // 0 = empty, 1 = covered once, 2 = covered more than once.
    const rasterizedAtlas = new Uint8Array(resX * resY)

    // Scratch objects reused across triangles and cells instead of allocating inside the loops.
    const scratchUV = new THREE.Vector2()
    const p0 = new THREE.Vector3()
    const p1 = new THREE.Vector3()
    const p2 = new THREE.Vector3()
    const triangle = new THREE.Triangle()
    const cellCenter = new THREE.Vector3()
    const baryCoords = new THREE.Vector3()

    const toRaster = (vertexIndex: number, target: THREE.Vector3) => {
        scratchUV.fromBufferAttribute(uv, vertexIndex)
        return target.set((scratchUV.x - minU) * uvToRaster, (scratchUV.y - minV) * uvToRaster, 0)
    }

    let totalPixels = 0
    let totalOverlappingPixels = 0
    for (let i = 0; i < indexBuffer.array.length; i += 3) {
        toRaster(indexBuffer.array[i], p0)
        toRaster(indexBuffer.array[i + 1], p1)
        toRaster(indexBuffer.array[i + 2], p2)
        triangle.set(p0, p1, p2)

        const xMin = Math.floor(Math.min(p0.x, p1.x, p2.x))
        const xMax = Math.ceil(Math.max(p0.x, p1.x, p2.x))
        const yMin = Math.floor(Math.min(p0.y, p1.y, p2.y))
        const yMax = Math.ceil(Math.max(p0.y, p1.y, p2.y))

        for (let y = yMin; y <= yMax && y < resY; y++) {
            for (let x = xMin; x <= xMax && x < resX; x++) {
                cellCenter.set(x + 0.5, y + 0.5, 0)

                // Degenerate triangles have no barycentric coordinates.
                if (triangle.getBarycoord(cellCenter, baryCoords) === null) continue
                // Strict interior only, so edges shared by neighbouring triangles do not count as overlap.
                if (!(baryCoords.x > 0 && baryCoords.y > 0 && baryCoords.z > 0)) continue

                const index = y * resX + x
                const value = rasterizedAtlas[index]
                if (value === 0) {
                    totalPixels++
                    rasterizedAtlas[index] = 1
                } else if (value === 1) {
                    totalOverlappingPixels++
                    rasterizedAtlas[index] = 2
                }
            }
        }
    }

    if (totalPixels === 0) return 0
    return totalOverlappingPixels / totalPixels
}

/**
 * GLSL chunk that declares the `pointIsAreaLight` / `areaLightDirections` uniform arrays and defines
 * `float getShadowAvg(vec3 geometryPosition, vec3 geometryNormal)`.
 *
 * The function returns the shadow factor averaged over all shadow-casting point lights. Each light is
 * weighted by `saturate(dot(N, L))` times the mean of its incident colour. For lights flagged in
 * `pointIsAreaLight`, the weight is additionally multiplied by `saturate(dot(areaLightDirections[i], L))`,
 * so surfaces behind the area light do not contribute.
 *
 * It returns 1.0 (unshadowed) in three cases:
 * - `USE_SHADOWMAP` is not defined;
 * - there are no point-light shadows;
 * - the accumulated weight is below 1e-5.
 *
 * Relies on three.js shader chunks being included before it, which provide `pointLights`,
 * `pointLightShadows`, `pointShadowMap`, `vPointShadowCoord`, `getPointLightInfo`, `getPointShadow`,
 * `receiveShadow` and `saturate`.
 *
 * NOTE(review): indexes `pointLights[i]` for `i < NUM_POINT_LIGHT_SHADOWS`. This assumes three.js orders
 * shadow-casting point lights first, in the same order in which `pointIsAreaLight` is filled — confirm.
 */
const getShadowAvg = /* glsl */ `
#if ( NUM_POINT_LIGHT_SHADOWS > 0 )
    uniform bool pointIsAreaLight[NUM_POINT_LIGHT_SHADOWS];
    uniform vec3 areaLightDirections[NUM_POINT_LIGHT_SHADOWS];
#endif

float getShadowAvg(vec3 geometryPosition, vec3 geometryNormal) {
    float shadowMask = 1.0;

    #ifdef USE_SHADOWMAP

    #if NUM_POINT_LIGHT_SHADOWS > 0

    PointLight pointLight;
    PointLightShadow pointLightShadow;
    IncidentLight directLight;
    float attenuation;
    float dotNL;
    float weight;
    float shadow;
    float totalWeight = 0.0;
    float totalShadow = 0.0;

    #pragma unroll_loop_start
    for (int i = 0; i < NUM_POINT_LIGHT_SHADOWS; i++) {
        pointLight = pointLights[i];

        getPointLightInfo(pointLight, geometryPosition, directLight);
        dotNL = saturate(dot(geometryNormal, directLight.direction));

        if(pointIsAreaLight[i])
            dotNL *= saturate(dot(areaLightDirections[i], directLight.direction));

        attenuation = (directLight.color.r + directLight.color.g + directLight.color.b) / 3.0;
        weight = dotNL * attenuation;

        pointLightShadow = pointLightShadows[i];
        shadow = (directLight.visible && receiveShadow) ? getPointShadow(pointShadowMap[i], pointLightShadow.shadowMapSize, pointLightShadow.shadowBias, pointLightShadow.shadowRadius, vPointShadowCoord[i], pointLightShadow.shadowCameraNear, pointLightShadow.shadowCameraFar) : 1.0;

        totalShadow += shadow * weight;
        totalWeight += weight;
    }
    #pragma unroll_loop_end

    if(totalWeight > 0.00001)
        shadowMask = totalShadow / totalWeight;
    else
        shadowMask = 1.0;

    #endif

    #endif

    return shadowMask;
}
`

/**
 * Point light used as a stand-in for a `THREE.RectAreaLight` while baking shadows.
 * It adds no behaviour of its own. Presumably the dedicated subclass exists so that proxies can be told
 * apart from user point lights via `instanceof` — TODO confirm.
 */
class RectAreaLightProxy extends THREE.PointLight {}

/**
 * Scene-graph placement of an object captured before it is temporarily re-parented for baking.
 * The local transform is stored verbatim so it can be restored bit-exactly.
 */
type OriginalObjectRenderData = {
    parent: THREE.Object3D | null
    position: THREE.Vector3
    quaternion: THREE.Quaternion
    scale: THREE.Vector3
    matrix: THREE.Matrix4
    matrixWorld: THREE.Matrix4
}

/** Captures the parent and an exact copy of the local/world transform of `object`. */
const getOriginalObjectRenderData = (object: THREE.Object3D): OriginalObjectRenderData => {
    return {
        parent: object.parent,
        position: object.position.clone(),
        quaternion: object.quaternion.clone(),
        scale: object.scale.clone(),
        matrix: object.matrix.clone(),
        matrixWorld: object.matrixWorld.clone(),
    }
}

/**
 * Puts `object` back under its original parent and restores its captured transform.
 *
 * `Object3D.attach` alone would re-derive the local transform by inverting and decomposing matrices.
 * Doing that on every bake call accumulates floating-point drift (and cannot represent shear), so the
 * saved components are copied back afterwards. Objects that had no parent are removed from whatever
 * they were attached to. Note: `attach` appends, so the object ends up last among its siblings.
 */
const restoreOriginalObjectRenderData = (object: THREE.Object3D, originalData: OriginalObjectRenderData) => {
    const {parent, position, quaternion, scale, matrix, matrixWorld} = originalData
    if (parent) parent.attach(object)
    else object.removeFromParent()

    object.position.copy(position)
    object.quaternion.copy(quaternion)
    object.scale.copy(scale)
    object.matrix.copy(matrix)
    object.matrixWorld.copy(matrixWorld)
    // Recompute this object's and its descendants' world matrices from the restored local transform.
    object.updateWorldMatrix(false, true)
}

/** Per-mesh render state that baking overrides and must put back afterwards. */
type OriginalMeshRenderData = {
    material: THREE.Material | THREE.Material[]
    visible: boolean
    frustumCulled: boolean
}

/** Captures the material(s), visibility and frustum-culling flag of `mesh`. */
const getOriginalMeshRenderData = (mesh: THREE.Mesh): OriginalMeshRenderData => ({
    material: mesh.material,
    visible: mesh.visible,
    frustumCulled: mesh.frustumCulled,
})

/** Writes state captured by `getOriginalMeshRenderData` back onto `mesh`. */
const restoreOriginalMeshRenderData = (mesh: THREE.Mesh, originalData: OriginalMeshRenderData) => {
    const {material, visible, frustumCulled} = originalData

    mesh.material = material
    mesh.visible = visible
    mesh.frustumCulled = frustumCulled
}

/**
 * Creates the material that writes the weighted point-light shadow term into the UV-shadow atlas.
 *
 * The vertex shader outputs `(uv3 - 0.5) * 2` as the clip-space position, so triangles are rasterized
 * at their `uv3` location in the render target rather than as seen from a camera. Normals, the view
 * position and the shadow coordinates are still computed with the regular three.js chunks.
 *
 * The fragment shader branches on the `iteration` uniform:
 * - `iteration < 0`: copies `previousUVShadowMap` unchanged;
 * - `iteration == 0`: writes the current `getShadowAvg` value;
 * - `iteration > 0`: blends it into `previousUVShadowMap` as a running average of `iteration + 1` samples.
 *
 * @param side Face side assigned to the material. Because positions come from `uv3`, this selects which
 *     UV winding gets rasterized.
 */
const getShadowMaterial = (side: THREE.Side) => {
    const material = new THREE.ShaderMaterial({
        lights: true,
        uniforms: {
            ...THREE.UniformsLib.lights,
            previousUVShadowMap: {value: null}, // atlas read when copying or accumulating
            iteration: {value: null}, // see the branch description above
            pointIsAreaLight: {value: []}, // per shadow-casting point light: does it stand in for an area light?
            areaLightDirections: {value: []}, // facing direction used to mask the area-light proxies
        },
        vertexShader: /* glsl */ `
            attribute vec2 uv3;

            varying vec3 vViewPosition;
            varying vec2 vShadowMapUv;
            
            #include <common>
            #include <normal_pars_vertex>
            #include <shadowmap_pars_vertex>

            void main() {
                #include <beginnormal_vertex>
                #include <defaultnormal_vertex>
                #include <normal_vertex>

                #include <begin_vertex>
                #include <project_vertex>

                vViewPosition = -mvPosition.xyz;

                #include <worldpos_vertex>
                #include <shadowmap_vertex>

                vShadowMapUv = uv3;

                gl_Position = vec4((vShadowMapUv - 0.5) * 2.0, 1.0, 1.0);
            }`,
        fragmentShader: /* glsl */ `
            #include <common>
            #include <packing>
            #include <lights_pars_begin>
            #include <normal_pars_fragment>
            #include <shadowmap_pars_fragment>

            ${getShadowAvg}

            uniform sampler2D previousUVShadowMap;
            uniform int iteration;
            varying vec3 vViewPosition;
            varying vec2 vShadowMapUv;

            void main() {
                if(iteration < 0) {
                    gl_FragColor = texture2D(previousUVShadowMap, vShadowMapUv);
                    return;
                }

                #include <normal_fragment_begin>
                vec3 shadowMask = vec3(getShadowAvg(-vViewPosition, normal));

                if(iteration < 1) {
                    gl_FragColor = vec4(shadowMask, 1.0);
                } else {
                    vec3 previousValue = texture2D(previousUVShadowMap, vShadowMapUv).rgb * float(iteration);
                    gl_FragColor = vec4((previousValue + shadowMask) / (float(iteration) + 1.0), 1.0);
                }
            }`,
    })

    material.side = side
    return material
}

/**
 * Creates the material that marks atlas texels covered by geometry.
 *
 * Like the shadow material, the vertex shader places each vertex at its `uv3` coordinate in clip space.
 * The fragment shader writes opaque white, so covered texels become 1.0 and uncovered ones keep the
 * clear colour.
 *
 * @param side Face side assigned to the material (selects which UV winding is rasterized).
 */
const getMaskMaterial = (side: THREE.Side) => {
    const material = new THREE.ShaderMaterial({
        vertexShader: /* glsl */ `
            attribute vec2 uv3;

            varying vec2 vShadowMapUv;

            void main() {
                vShadowMapUv = uv3;

                gl_Position = vec4((vShadowMapUv - 0.5) * 2.0, 1.0, 1.0);
            }`,
        fragmentShader: /* glsl */ `
            varying vec2 vShadowMapUv;

            void main() {
                gl_FragColor = vec4(1.0, 1.0, 1.0, 1.0);
            }`,
    })

    material.side = side
    return material
}

/**
 * Returns a plain `MeshBasicMaterial` rendering the given face side.
 * Used as a cheap stand-in material while only the shadow maps need to be refreshed.
 */
const getDummyMaterial = (side: THREE.Side) => new THREE.MeshBasicMaterial({side})

/**
 * Creates the full-screen material that dilates the shadow atlas into uncovered texels.
 *
 * Texels whose `shadowMapMask` value is > 0 pass `previousUVShadowMap` through unchanged. Any other texel
 * gets the average of its valid 8-neighbours, sampled at `±pixelOffsetU` / `±pixelOffsetV`. When no
 * neighbour is valid, it falls back to mid-grey (0.5). Filling these texels is meant to hide seams where
 * bilinear filtering samples just outside a UV island.
 *
 * `pixelOffsetU` / `pixelOffsetV` are presumably one texel in UV units (1 / width, 1 / height) — confirm
 * against the caller.
 */
const getDilateMaterial = () => {
    const material = new THREE.ShaderMaterial({
        uniforms: {
            previousUVShadowMap: {value: null}, // atlas to dilate
            shadowMapMask: {value: null}, // coverage mask: red > 0 marks valid texels
            pixelOffsetU: {value: null}, // neighbour step in U
            pixelOffsetV: {value: null}, // neighbour step in V
        },
        vertexShader: /* glsl */ `
            varying vec2 vUv;
    
            void main() {
                #include <begin_vertex>
                #include <project_vertex>
                #include <worldpos_vertex>

                vUv = uv;
            }`,
        fragmentShader: /* glsl */ `
            varying vec2 vUv;

            uniform sampler2D previousUVShadowMap;
            uniform sampler2D shadowMapMask;
            uniform float pixelOffsetU;
            uniform float pixelOffsetV;

            bool isValid(vec2 uv) {
                return texture2D(shadowMapMask, uv).r > 0.0;
            }
    
            void main() {
                if(isValid(vUv)) {
                    gl_FragColor = texture2D(previousUVShadowMap, vUv);
                    return;
                }

                uint numValidNeighbors = 0u;
                vec4 neighborsSum = vec4(0.0);

                int neighborsRegion = 1;
                for(int j = -neighborsRegion; j <= neighborsRegion; j++) {
                    for(int i = -neighborsRegion; i <= neighborsRegion; i++) {
                        if(i == 0 && j == 0) continue;

                        vec2 uvNeighbor = vUv + vec2(float(i) * pixelOffsetU, float(j) * pixelOffsetV);
                        if(isValid(uvNeighbor)) {
                            neighborsSum += texture2D(previousUVShadowMap, uvNeighbor);
                            numValidNeighbors++;
                        }
                    }
                }

                if(numValidNeighbors > 0u)
                    gl_FragColor = neighborsSum / float(numValidNeighbors);
                else
                    gl_FragColor = vec4(0.5, 0.5, 0.5, 1.0);
            }`,
    })

    return material
}

/** Mesh type handled by the UV shadow map (single material, buffer geometry). */
type UVShadowMesh = THREE.Mesh<THREE.BufferGeometry, THREE.Material>

/** One mesh's entry in the potpack atlas; `w`/`h` are in world-scaled UV units (padding added before packing). */
type UVShadowBox = PotpackBox & {
    mesh: UVShadowMesh
    uv3: THREE.BufferAttribute
    uvBounds: UVBounds
}

/**
 * Progressively bakes point-light and rect-area-light shadowing into a texture atlas that is addressed
 * through a dedicated UV channel (`uv3`), which this class generates for every attached mesh.
 *
 * - `attachToMeshes` builds `uv3` for each usable mesh and sizes the render targets. Each mesh's `uv` set
 *   is rescaled so its UV area matches its surface area, then all meshes are packed into one atlas.
 * - `update` renders one or more iterations into a running average; rect-area-light proxies are jittered
 *   per iteration. It then dilates the result into `getUVShadowMap()` to fill texels outside the UV islands.
 *
 * `attachedUVShadowMap$` / `detachedUVShadowMap$` emit meshes whose `uv3` was created / removed.
 */
export class ThreeProgressiveUVShadowMap {
    private uvShadowMapTemp: THREE.WebGLRenderTarget
    private uvShadowMap: THREE.WebGLRenderTarget
    private shadowMapMask: THREE.WebGLRenderTarget
    private shadowMapMaskDirty = false
    private dummyRenderTarget: THREE.WebGLRenderTarget
    private shadowMaterials: {[key in THREE.Side]: THREE.ShaderMaterial}
    private maskMaterials: {[key in THREE.Side]: THREE.ShaderMaterial}
    private dummyMaterials: {[key in THREE.Side]: THREE.MeshBasicMaterial}
    private dilateMaterial = getDilateMaterial()
    private meshes: Set<UVShadowMesh> = new Set()
    private rectAreaLightProxyMap = new Map<THREE.RectAreaLight, RectAreaLightProxy>()
    private camera = new THREE.PerspectiveCamera()
    private debugObject: THREE.Mesh<THREE.PlaneGeometry, THREE.MeshBasicMaterial> | undefined = undefined
    private resolutionX: number | undefined
    private resolutionY: number | undefined
    private planeGeometry = new THREE.PlaneGeometry(1.0, 1.0)

    private detachedUVShadowMap = new Subject<UVShadowMesh>()
    detachedUVShadowMap$ = this.detachedUVShadowMap.asObservable()

    private attachedUVShadowMap = new Subject<UVShadowMesh>()
    attachedUVShadowMap$ = this.attachedUVShadowMap.asObservable()

    /**
     * @param renderer Renderer used for all baking passes; its render target, clear colour and shadow-map
     *     settings are restored after each `update`.
     * @param resolution Target texel count along one side of a square atlas. The actual atlas keeps
     *     roughly `resolution²` texels but follows the packed aspect ratio.
     */
    constructor(
        private renderer: THREE.WebGLRenderer,
        private resolution: number,
    ) {
        this.uvShadowMapTemp = new THREE.WebGLRenderTarget(1, 1, {
            type: DEFAULT_FLOAT_TEXTURE_TYPE,
            generateMipmaps: false,
            minFilter: THREE.NearestFilter,
            magFilter: THREE.NearestFilter,
            format: THREE.RedFormat,
            colorSpace: THREE.NoColorSpace,
        })
        this.uvShadowMap = new THREE.WebGLRenderTarget(1, 1, {
            type: DEFAULT_FLOAT_TEXTURE_TYPE,
            generateMipmaps: false,
            minFilter: THREE.LinearFilter,
            magFilter: THREE.LinearFilter,
            format: THREE.RedFormat,
            colorSpace: THREE.NoColorSpace,
        })
        // The final atlas is sampled through the generated `uv3` attribute.
        this.uvShadowMap.texture.channel = 3

        this.shadowMapMask = new THREE.WebGLRenderTarget(1, 1, {
            generateMipmaps: false,
            minFilter: THREE.NearestFilter,
            magFilter: THREE.NearestFilter,
            format: THREE.RedFormat,
            colorSpace: THREE.NoColorSpace,
        })

        this.dummyRenderTarget = new THREE.WebGLRenderTarget(1, 1)

        this.shadowMaterials = {
            [THREE.FrontSide]: getShadowMaterial(THREE.FrontSide),
            [THREE.BackSide]: getShadowMaterial(THREE.BackSide),
            [THREE.DoubleSide]: getShadowMaterial(THREE.DoubleSide),
        }

        this.maskMaterials = {
            [THREE.FrontSide]: getMaskMaterial(THREE.FrontSide),
            [THREE.BackSide]: getMaskMaterial(THREE.BackSide),
            [THREE.DoubleSide]: getMaskMaterial(THREE.DoubleSide),
        }

        this.dummyMaterials = {
            [THREE.FrontSide]: getDummyMaterial(THREE.FrontSide),
            [THREE.BackSide]: getDummyMaterial(THREE.BackSide),
            [THREE.DoubleSide]: getDummyMaterial(THREE.DoubleSide),
        }

        this.reset()
    }

    /** Whether `mesh` currently has a generated `uv3` layout in this atlas. */
    has(mesh: UVShadowMesh) {
        return this.meshes.has(mesh)
    }

    /** Removes `uv3` from all attached meshes (emitting detach events) and frees every GPU resource owned here. */
    dispose() {
        for (const oldMesh of this.meshes) {
            oldMesh.geometry.deleteAttribute("uv3")
            this.detachedUVShadowMap.next(oldMesh)
        }
        this.meshes = new Set()

        for (const light of this.rectAreaLightProxyMap.values()) light.dispose()
        this.rectAreaLightProxyMap.clear()

        this.dilateMaterial.dispose()
        for (const dummyMaterial of Object.values(this.dummyMaterials)) dummyMaterial.dispose()
        for (const maskMaterial of Object.values(this.maskMaterials)) maskMaterial.dispose()
        for (const shadowMaterial of Object.values(this.shadowMaterials)) shadowMaterial.dispose()

        if (this.debugObject) {
            this.debugObject.material.dispose()
            this.debugObject.geometry.dispose()
            this.debugObject = undefined
        }

        this.dummyRenderTarget.dispose()
        this.shadowMapMask.dispose()
        this.uvShadowMap.dispose()
        this.uvShadowMapTemp.dispose()
        this.planeGeometry.dispose()
    }

    /** Final (dilated) shadow atlas, meant to be sampled through UV channel 3. */
    getUVShadowMap() {
        return this.uvShadowMap.texture
    }

    /**
     * Replaces the set of baked meshes with `newMeshes` and regenerates the atlas layout.
     *
     * Meshes are skipped (with a warning) when they:
     * - have no `uv` or a `uv` whose item size is not 2;
     * - already carry a `uv3`;
     * - lack a valid triangle index;
     * - have an invalid surface;
     * - have UVs that overlap more than 10%.
     *
     * Previously attached meshes that did not make it into the new atlas emit `detachedUVShadowMap$`.
     * Every mesh that did emits `attachedUVShadowMap$` once its final `uv3` is in place. The accumulated
     * shadow map is reset.
     */
    attachToMeshes(newMeshes: Set<UVShadowMesh>) {
        const previousMeshes = this.meshes
        // uv3 is regenerated from scratch for every mesh, so drop the previous layout first.
        for (const oldMesh of previousMeshes) oldMesh.geometry.deleteAttribute("uv3")

        const uvBoxes: UVShadowBox[] = []
        for (const newMesh of newMeshes) {
            const uv3 = this.createScaledShadowMapUVs(newMesh)
            if (!uv3) continue

            // Set immediately so a second mesh sharing this geometry is detected (and skipped) above.
            newMesh.geometry.setAttribute("uv3", uv3)
            const uvBounds = getMinMaxUVs(uv3)

            uvBoxes.push({
                w: uvBounds.maxU - uvBounds.minU,
                h: uvBounds.maxV - uvBounds.minV,
                mesh: newMesh,
                uv3,
                uvBounds,
            })
        }
        this.meshes = new Set(uvBoxes.map((box) => box.mesh))

        // Includes old meshes that were passed again but rejected: their uv3 is gone as well.
        for (const oldMesh of previousMeshes) {
            if (!this.meshes.has(oldMesh)) this.detachedUVShadowMap.next(oldMesh)
        }

        if (uvBoxes.length === 0) {
            // potpack reports a 0x0 atlas for no boxes, which would turn the sizing below into NaN.
            this.resolutionX = undefined
            this.resolutionY = undefined
            this.reset()
            return
        }

        const totalBoxArea = uvBoxes.reduce((acc, box) => acc + box.w * box.h, 0)
        const padding = Math.sqrt(totalBoxArea) * 0.01 //1% padding
        for (const uvBox of uvBoxes) {
            uvBox.w += padding * 2
            uvBox.h += padding * 2
        }

        const dimensions = potpack(uvBoxes)

        // Move each island to its packed position and normalize the whole atlas to [0, 1].
        for (const {uv3, x, y, uvBounds} of uvBoxes) {
            if (x === undefined || y === undefined) throw new Error("x or y is undefined")

            const {minU, minV} = uvBounds
            const array = uv3.array

            for (let i = 0; i < array.length; i += uv3.itemSize) {
                array[i] = (array[i] - minU + x + padding) / dimensions.w
                array[i + 1] = (array[i + 1] - minV + y + padding) / dimensions.h
                if (!(array[i] >= 0 && array[i] <= 1 && array[i + 1] >= 0 && array[i + 1] <= 1)) {
                    console.log("uv3 out of bounds", array[i], array[i + 1])
                }
            }

            uv3.needsUpdate = true
        }

        // Keep ~resolution² texels while following the packed aspect ratio. Round to whole texels so the
        // stored resolution matches the allocated texture, which the dilation pass uses as its texel step.
        const scale = Math.sqrt((this.resolution * this.resolution) / (dimensions.w * dimensions.h))
        this.resolutionX = Math.max(1, Math.round(dimensions.w * scale))
        this.resolutionY = Math.max(1, Math.round(dimensions.h * scale))

        this.uvShadowMapTemp.setSize(this.resolutionX, this.resolutionY)
        this.uvShadowMap.setSize(this.resolutionX, this.resolutionY)
        this.shadowMapMask.setSize(this.resolutionX, this.resolutionY)
        this.shadowMapMaskDirty = true

        this.reset()

        if (this.debugObject) this.debugObject.scale.set(1, this.resolutionY / this.resolutionX, 1)

        for (const {mesh} of uvBoxes) this.attachedUVShadowMap.next(mesh)
    }

    /**
     * Renders `iterations` baking iterations into the atlas, starting at sample index `currentIteration`,
     * then dilates the result into `getUVShadowMap()`.
     *
     * Shadow-relevant lights and meshes of `scene` are temporarily moved into a private scene and their
     * materials are swapped. Everything, including the renderer's render target, clear colour and
     * shadow-map flags, is restored afterwards, even if rendering throws.
     */
    update(scene: THREE.Scene, currentIteration: number, iterations = 1) {
        if (this.meshes.size === 0) return
        const {resolutionX, resolutionY} = this
        if (resolutionX === undefined || resolutionY === undefined) return

        const lightObjects: (THREE.Light | THREE.Mesh)[] = []
        scene.traverseVisible((object) => {
            if (object instanceof THREE.Light || object instanceof THREE.Mesh) lightObjects.push(object)
        })
        const sceneMeshes = lightObjects.filter((object): object is THREE.Mesh => object instanceof THREE.Mesh)

        // Snapshot everything that is modified below so it can be restored in `finally`.
        const oldTarget = this.renderer.getRenderTarget()
        const oldShadowMapEnabled = this.renderer.shadowMap.enabled
        const oldShadowMapAutoUpdate = this.renderer.shadowMap.autoUpdate
        const oldClearColor = new THREE.Color()
        this.renderer.getClearColor(oldClearColor)
        const oldClearAlpha = this.renderer.getClearAlpha()

        const originalMeshData = new Map<THREE.Mesh, OriginalMeshRenderData>()
        for (const mesh of sceneMeshes) originalMeshData.set(mesh, getOriginalMeshRenderData(mesh))
        const originalObjectData = new Map<THREE.Object3D, OriginalObjectRenderData>()

        const tempScene = new THREE.Scene()

        try {
            const {pointIsAreaLight, areaLightDirections} = this.buildShadowScene(lightObjects, tempScene, originalObjectData)

            this.renderer.shadowMap.enabled = true
            this.renderer.setClearColor(0xffffff, 1)

            for (let i = 0; i < iterations; i++) {
                const iteration = currentIteration + i

                //1. Pass updates the internal Three.JS shadow maps using all meshes
                this.renderer.shadowMap.autoUpdate = false
                this.renderer.shadowMap.needsUpdate = true
                this.renderer.setRenderTarget(this.dummyRenderTarget)

                for (const mesh of sceneMeshes) {
                    const originalData = originalMeshData.get(mesh)
                    if (!originalData) throw new Error("Original mesh data not found")

                    restoreOriginalMeshRenderData(mesh, originalData)

                    if (Array.isArray(mesh.material)) mesh.material = mesh.material.map((material) => this.dummyMaterials[material.side])
                    else mesh.material = this.dummyMaterials[mesh.material.side]
                }

                // Jitter each area-light proxy across its light's rectangle to sample soft shadows.
                const [offsetX, offsetY] = getJitterVector(iteration)
                for (const [light, lightProxy] of this.rectAreaLightProxyMap) {
                    lightProxy.position.x = offsetX * light.width
                    lightProxy.position.y = offsetY * light.height
                }

                this.renderer.render(tempScene, this.camera)

                //2. Pass updates the uv shadow map with only the meshes that receive shadows
                for (const mesh of sceneMeshes) {
                    if (this.isShadowReceiver(mesh)) {
                        mesh.frustumCulled = false
                        mesh.material = this.shadowMaterials[mesh.material.side]
                    } else mesh.visible = false
                }

                this.renderer.setRenderTarget(this.uvShadowMapTemp)

                for (const shadowMaterial of Object.values(this.shadowMaterials)) {
                    shadowMaterial.uniforms.iteration.value = iteration
                    shadowMaterial.uniforms.previousUVShadowMap.value = this.uvShadowMap.texture
                    shadowMaterial.uniforms.pointIsAreaLight.value = pointIsAreaLight
                    shadowMaterial.uniforms.areaLightDirections.value = areaLightDirections
                }

                this.renderer.render(tempScene, this.camera)

                //The last iteration is not needed, as this is done in the final filter pass
                if (i !== iterations - 1) {
                    this.renderer.setRenderTarget(this.uvShadowMap)

                    for (const shadowMaterial of Object.values(this.shadowMaterials)) {
                        shadowMaterial.uniforms.iteration.value = -1
                        shadowMaterial.uniforms.previousUVShadowMap.value = this.uvShadowMapTemp.texture
                    }

                    this.renderer.render(tempScene, this.camera)
                }
            }

            if (this.shadowMapMaskDirty) {
                this.renderer.setClearColor(0x000000, 1)
                this.renderer.setRenderTarget(this.shadowMapMask)

                // Self-contained so the mask is correct even when no iteration pass ran before it.
                for (const mesh of sceneMeshes) {
                    if (this.isShadowReceiver(mesh)) {
                        mesh.frustumCulled = false
                        mesh.material = this.maskMaterials[mesh.material.side]
                    } else mesh.visible = false
                }

                this.renderer.render(tempScene, this.camera)
                this.shadowMapMaskDirty = false
            }
        } finally {
            originalMeshData.forEach((originalData, mesh) => restoreOriginalMeshRenderData(mesh, originalData))
            originalObjectData.forEach((originalData, object) => restoreOriginalObjectRenderData(object, originalData))

            this.renderer.setRenderTarget(oldTarget)
            this.renderer.setClearColor(oldClearColor, oldClearAlpha)
            this.renderer.shadowMap.autoUpdate = oldShadowMapAutoUpdate
            this.renderer.shadowMap.enabled = oldShadowMapEnabled
        }

        //Dilate to fill in gaps
        this.dilateMaterial.uniforms.previousUVShadowMap.value = this.uvShadowMapTemp.texture
        this.dilateMaterial.uniforms.shadowMapMask.value = this.shadowMapMask.texture
        this.dilateMaterial.uniforms.pixelOffsetU.value = 1.0 / resolutionX
        this.dilateMaterial.uniforms.pixelOffsetV.value = 1.0 / resolutionY
        this.fillRenderTargetWithMaterial(this.renderer, this.uvShadowMap, this.dilateMaterial)
    }

    /**
     * Lazily creates a 100x100 plane (scaled to the atlas aspect ratio) showing the raw, undilated
     * accumulation target.
     */
    getDebugObject() {
        if (this.debugObject) return this.debugObject

        const labelMaterial = new THREE.MeshBasicMaterial({map: this.uvShadowMapTemp.texture, side: THREE.DoubleSide})
        const labelPlane = new THREE.PlaneGeometry(100, 100)
        this.debugObject = new THREE.Mesh(labelPlane, labelMaterial)
        this.debugObject.position.y = 250
        if (this.resolutionX && this.resolutionY) this.debugObject.scale.set(1, this.resolutionY / this.resolutionX, 1)
        return this.debugObject
    }

    /** Clears both shadow atlases to white (unshadowed) and the coverage mask to black. */
    reset() {
        this.fillRenderTargetWithColor(this.renderer, this.uvShadowMapTemp, new THREE.Color(0xffffff))
        this.fillRenderTargetWithColor(this.renderer, this.uvShadowMap, new THREE.Color(0xffffff))
        this.fillRenderTargetWithColor(this.renderer, this.shadowMapMask, new THREE.Color(0x000000))
    }

    /** Creates the shadow-casting point light that stands in for a rect area light while baking. */
    private static createRectAreaLightProxy() {
        const threeLight = new RectAreaLightProxy()
        threeLight.castShadow = true
        threeLight.shadow.mapSize.width = 1024
        threeLight.shadow.mapSize.height = 1024
        threeLight.shadow.camera.near = 10
        threeLight.shadow.camera.far = 2000
        //threeLight.shadow.bias = 1.0 / 1024.0
        return threeLight
    }

    /**
     * Validates `mesh` and builds its shadow-map UVs from `uv`, scaled so the UV area equals the surface area.
     * Returns `undefined` (after logging a warning) when the mesh cannot take part in the atlas.
     */
    private createScaledShadowMapUVs(mesh: UVShadowMesh): THREE.BufferAttribute | undefined {
        const {geometry} = mesh

        const existingUvs = geometry.getAttribute("uv")
        if (!existingUvs || existingUvs.count === 0) {
            console.warn("No UVs found on mesh, skipping", mesh)
            return undefined
        }

        if (geometry.getAttribute("uv3")) {
            console.warn("UV with index 3 already exist on mesh, skipping", mesh)
            return undefined
        }

        if (existingUvs.itemSize !== 2) {
            console.warn("UV item size must be 2, skipping", mesh)
            return undefined
        }

        const index = geometry.getIndex()
        if (!index || index.array.length === 0 || index.array.length % 3 !== 0) {
            console.warn("Mesh has no valid triangle index, skipping", mesh)
            return undefined
        }

        const surfaceArea = getGeometrySurfaceArea(geometry.getAttribute("position"), index)
        if (surfaceArea <= 0 || !isFinite(surfaceArea)) {
            console.warn("Mesh surface is invalid, skipping", mesh)
            return undefined
        }

        // Copy into a plain float attribute through getX/getY. This de-interleaves and de-normalizes
        // quantized UVs, which would otherwise be truncated when rescaled and repacked in place.
        const uvArray = new Float32Array(existingUvs.count * 2)
        for (let i = 0; i < existingUvs.count; i++) {
            uvArray[i * 2] = existingUvs.getX(i)
            uvArray[i * 2 + 1] = existingUvs.getY(i)
        }
        const uv3 = new THREE.BufferAttribute(uvArray, 2)

        const overlappingRatio = uvOverlappingRatio(uv3, index, 64)
        if (overlappingRatio > 0.1) {
            console.warn(`Mesh uv suface is overlapping (ratio: ${overlappingRatio}), skipping`, mesh)
            return undefined
        }

        // Positions are generated from uv3, so the UV winding must match the side that gets rendered.
        const {side} = mesh.material
        if (side !== THREE.DoubleSide) ensureUVSideness(uv3, index, side)

        const uvArea = getUVSurfaceArea(uv3, index)
        if (uvArea <= 0 || !isFinite(uvArea)) {
            console.warn("Mesh uv suface is invalid, skipping", mesh)
            return undefined
        }

        const scale = Math.sqrt(surfaceArea) / Math.sqrt(uvArea)
        if (scale <= 0 || !isFinite(scale)) {
            console.warn("Mesh scaling is invalid, skipping", mesh)
            return undefined
        }

        scaleUVs(uv3, scale)
        return uv3
    }

    /**
     * Moves shadow-relevant lights and meshes into `tempScene`, recording how to restore them in
     * `originalObjectData`. Rect area lights are replaced by point-light proxies, and proxies whose area
     * light is gone are disposed.
     *
     * Returns one entry per shadow-casting point light, in scene traversal order: whether it is an
     * area-light proxy, and the area light's facing direction (a zero vector for plain point lights).
     */
    private buildShadowScene(lightObjects: (THREE.Light | THREE.Mesh)[], tempScene: THREE.Scene, originalObjectData: Map<THREE.Object3D, OriginalObjectRenderData>) {
        const pointIsAreaLight: boolean[] = []
        const areaLightDirections: THREE.Vector3[] = []
        const staleAreaLights = new Set(this.rectAreaLightProxyMap.keys())

        for (const object of lightObjects) {
            if (object instanceof THREE.Light) {
                if (!object.castShadow) continue

                if (object instanceof THREE.RectAreaLight) {
                    const lightGroup = new THREE.Group()
                    updateTransform(object.matrixWorld, lightGroup)

                    const lightProxy = this.rectAreaLightProxyMap.get(object) ?? ThreeProgressiveUVShadowMap.createRectAreaLightProxy()
                    this.rectAreaLightProxyMap.set(object, lightProxy)
                    staleAreaLights.delete(object)

                    const {width, height, visible, intensity, color} = object

                    lightProxy.visible = visible
                    lightProxy.color.set(color)
                    lightProxy.intensity = intensity * width * height

                    const direction = new THREE.Vector3(0, 0, 1).transformDirection(object.matrixWorld)

                    lightGroup.add(lightProxy)

                    pointIsAreaLight.push(true)
                    areaLightDirections.push(direction.transformDirection(this.camera.matrixWorldInverse))

                    tempScene.add(lightGroup)
                    continue
                }

                if (object instanceof THREE.PointLight) {
                    pointIsAreaLight.push(false)
                    areaLightDirections.push(new THREE.Vector3())
                }
            } else if (!object.castShadow && !object.receiveShadow) continue

            originalObjectData.set(object, getOriginalObjectRenderData(object))
            tempScene.attach(object)
        }

        for (const light of staleAreaLights) {
            const lightProxy = this.rectAreaLightProxyMap.get(light)
            if (!lightProxy) throw new Error("Reference light not found")

            lightProxy.dispose()
            this.rectAreaLightProxyMap.delete(light)
        }

        return {pointIsAreaLight, areaLightDirections}
    }

    /** Whether `mesh` should be written into the atlas: it receives shadows and has a generated `uv3`. */
    private isShadowReceiver(mesh: THREE.Mesh): mesh is UVShadowMesh {
        return mesh.receiveShadow && this.meshes.has(mesh as UVShadowMesh)
    }

    /** Fills `renderTarget` with a solid colour by drawing a full-screen quad. */
    private fillRenderTargetWithColor(renderer: THREE.WebGLRenderer, renderTarget: THREE.WebGLRenderTarget, color: THREE.Color) {
        const material = new THREE.MeshBasicMaterial({color: color})
        this.fillRenderTargetWithMaterial(renderer, renderTarget, material)
        material.dispose()
    }

    /** Draws a unit quad covering the whole of `renderTarget` with `material`, then restores the previous target. */
    private fillRenderTargetWithMaterial(renderer: THREE.WebGLRenderer, renderTarget: THREE.WebGLRenderTarget, material: THREE.Material) {
        const renderScene = new THREE.Scene()

        const planeMesh = new THREE.Mesh(this.planeGeometry, material)
        renderScene.add(planeMesh)

        // Orthographic frustum exactly matching the 1x1 plane.
        const renderCamera = new THREE.OrthographicCamera(-0.5, 0.5, 0.5, -0.5, 0.1, 2)
        renderCamera.position.z = 1

        const previousRenderTarget = renderer.getRenderTarget()
        renderer.setRenderTarget(renderTarget)
        renderer.render(renderScene, renderCamera)

        renderer.setRenderTarget(previousRenderTarget)
    }
}
