import {Matrix4, Vector3} from "@cm/math"
import {Color, colorEqual, SceneNodes, GlobalRenderConstants} from "@cm/template-nodes"
import {IScene, ThreeObjectBase} from "@editor/helpers/scene/three-proxies/utils"
import {toThreeVector, fromThreeVector} from "@template-editor/helpers/three-utils"
import {Three as THREE} from "@cm/material-nodes/three"

/**
 * Scene proxy that renders a rectangular area light.
 *
 * The rectangle is approximated by a single {@link DirectionalPointLight}. On every sampling frame
 * after the first, that light is moved to a random point on the rectangle (see `updateSamplingFrame`).
 * It is aimed from the group's position towards `_target`, and its intensity is scaled by the
 * rectangle's area. In edit mode a {@link ThreeProxyLightHelper} visualizes the light.
 */
export class ThreeAreaLight extends ThreeObjectBase {
    threeObject: THREE.Group
    override threeHelperObject?: ThreeProxyLightHelper | null
    /** The single light approximating the rectangle; created once in the constructor and never replaced. */
    private readonly light: DirectionalPointLight
    // Last state applied from the scene node. update() compares against these to skip redundant work,
    // and the helper reads them directly.
    _width = 1
    _height = 1
    _color: Color = [1, 1, 1]
    _on = true
    _intensity = 1
    _directionality = 1
    _target = new Vector3(0, 0, 0)

    constructor(scene: IScene) {
        super(scene)

        this.light = new DirectionalPointLight()
        this.light.castShadow = true
        this.light.decay = 2
        this.light.shadow.mapSize.width = 1024
        this.light.shadow.mapSize.height = 1024
        this.light.shadow.camera.near = 10
        this.light.shadow.camera.far = 2000

        // The group's transform is the light's placement (read by updatePositionAndTarget and the helper);
        // the light inside it is only offset locally by updateSamplingFrame.
        this.threeObject = new THREE.Group()
        this.threeObject.add(this.light)

        // Helper geometry is only built in edit mode.
        this.threeHelperObject = scene.config.editMode ? new ThreeProxyLightHelper(this) : null
    }

    /**
     * Synchronizes the proxy with its scene node, only redoing work for properties that changed.
     *
     * @param light Scene-node description of the area light.
     * @throws Error if `light.target` is not a `Vector3` or `light.transform` is not a `Matrix4`.
     */
    update(light: SceneNodes.AreaLight) {
        let didUpdateMaterial = false
        if (light.intensity !== this._intensity || light.on !== this._on || !colorEqual(light.color, this._color)) {
            this._intensity = light.intensity
            this._on = light.on
            this._color = [...light.color]
            // Ray lengths in the helper depend on the intensity.
            this.threeHelperObject?.update()
            didUpdateMaterial = true
        }

        if (light.width !== this._width || light.height !== this._height || light.directionality !== this._directionality) {
            this._width = light.width
            this._height = light.height
            this._directionality = light.directionality
            if (this.threeHelperObject) {
                this.threeHelperObject.updateMesh()
                this.threeHelperObject.update()
                this.scene.update()
            }
            // The light's intensity is scaled by width * height, so the material must be refreshed too.
            didUpdateMaterial = true
        }

        // Validate the target before calling any Vector3 methods on it.
        if (!(light.target instanceof Vector3)) throw new Error("Expected target to be a Vector3")
        if (!light.target.equals(this._target)) {
            this._target = light.target.copy()
            this.updatePositionAndTarget()
        }

        if (didUpdateMaterial) {
            this.updateMaterial()
        }

        if (!(light.transform instanceof Matrix4)) throw new Error("Expected transform to be a Matrix4")

        this.updateTransform(light.transform)
        this.topLevelObjectId = light.topLevelObjectId
    }

    /**
     * Applies a new transform and, if the base class reports a change, re-aims the light.
     *
     * @returns The base class's result (whether the transform changed).
     */
    override updateTransform(transform: Matrix4): boolean {
        if (super.updateTransform(transform)) {
            this.updatePositionAndTarget()
            return true
        } else {
            return false
        }
    }

    /**
     * Moves the light within the rectangle for progressive sampling.
     *
     * Frame 0 places the light at the rectangle's center. Later frames pick a uniformly random point in
     * [-width/2, width/2] x [-height/2, height/2] in the group's local x/y plane. `maxFrame` is unused.
     */
    override updateSamplingFrame(frame: number, maxFrame: number) {
        const position = this.light.position
        if (frame === 0) {
            position.x = 0
            position.y = 0
        } else {
            position.x = (Math.random() - 0.5) * this._width
            position.y = (Math.random() - 0.5) * this._height
        }
        this.light.updateMatrix()
    }

    /** Pushes on/off state, color and area-scaled intensity to the light and the helper, then requests a scene update. */
    private updateMaterial() {
        //TODO: coalesce updates
        this.light.visible = this._on
        this.light.color.setRGB(...this._color)
        // Total emitted power grows with the rectangle's area.
        this.light.intensity = GlobalRenderConstants.lightIntensityScale * (this._intensity * this._width * this._height)

        if (this.threeHelperObject) {
            this.threeHelperObject.updateMaterial()
        }

        this.scene.update()
    }

    /** Outline tokens are the helper's local meshes (or nothing outside edit mode); `materialSlot` is ignored. */
    override getOutlineTokens(materialSlot: number | null): any[] {
        if (this.threeHelperObject) {
            return [this.threeHelperObject.localMesh]
        } else {
            return []
        }
    }

    override showEditHelpers(show: boolean) {
        this.threeHelperObject?.showEditHelpers(show)
    }

    /** Re-aims the light from the group's current position towards `_target` and refreshes the helper. */
    private updatePositionAndTarget() {
        const direction = this._target.sub(fromThreeVector(this.threeObject.position)).normalized()
        this.light.direction.set(direction.x, direction.y, direction.z)
        this.light.needsUpdate = true
        if (this.threeHelperObject) {
            this.threeHelperObject.update()
        }
        this.scene.update()
    }
}

/**
 * A `THREE.PointLight` that additionally carries an emission direction.
 *
 * three.js itself ignores `direction`. It is forwarded through the `pointLightDirections` uniform to the
 * patched `lights_fragment_begin` chunk, which scales the light's contribution by
 * `clamp(-dot(direction, directLight.direction), 0, 1)`.
 */
export class DirectionalPointLight extends THREE.PointLight {
    /** Emission direction in the light's parent frame; defaults to +z. */
    direction: THREE.Vector3 = new THREE.Vector3(0, 0, 1)

    /** Dirty flag raised by the owner whenever `direction` changes. NOTE(review): confirm who consumes/resets it. */
    needsUpdate: boolean = true

    /** Builds the light with three.js's default PointLight parameters; the owner configures it afterwards. */
    constructor() {
        super()
    }
}

// Camera-view-space emission directions of the scene's DirectionalPointLights, filled by
// updateLightDirectionsUniform. The same array object is bound as the `pointLightDirections` uniform
// value of every material, so it must be mutated in place, never reassigned.
const lightDirectionList = new Array<THREE.Vector3>()

/**
 * Patches three.js's shared shader chunks so that point lights honour a per-light emission direction.
 *
 * - Declares `uniform vec3 pointLightDirections[NUM_POINT_LIGHTS]` at the end of `lights_physical_pars_fragment`.
 * - In `lights_fragment_begin`, right after each point light's `getPointLightInfo(...)` call, scales
 *   `directLight.color` by `clamp(-dot(pointLightDirections[i], directLight.direction), 0, 1)`.
 * - Makes every material bind that uniform (via `Material.prototype.onBeforeCompile`) to the shared
 *   `lightDirectionList`, which `updateLightDirectionsUniform` keeps current.
 *
 * The patch is idempotent. If the chunks are already patched (e.g. this module is evaluated again against
 * the same three.js instance), it does nothing.
 *
 * @throws Error if `lights_fragment_begin` does not contain the expected `getPointLightInfo` call
 *         (i.e. the three.js version differs from the one this patch was written for). The chunks are
 *         left untouched in that case.
 */
function applyDirectionalPointLightShaderFixes(): void {
    // Patching twice would redeclare the uniform (a GLSL compile error) and apply the attenuation twice.
    if (THREE.ShaderChunk.lights_fragment_begin.includes("pointLightDirections")) return

    const searchString = "getPointLightInfo( pointLight, geometryPosition, directLight );"

    const originalShader = THREE.ShaderChunk.lights_fragment_begin
    const index = originalShader.indexOf(searchString)
    if (index === -1) throw new Error("Could not find string to replace in lights_fragment_begin")
    const patchPosition = index + searchString.length

    // `i` is the point-light loop index of the surrounding chunk, so the uniform array must be ordered like pointLights[].
    const patchedShader =
        originalShader.slice(0, patchPosition) +
        `
    directLight.color *= clamp( -dot(pointLightDirections[ i ], directLight.direction), 0., 1. );
    ` +
        originalShader.slice(patchPosition)

    // Only mutate the chunks after validation succeeded, so a failure leaves three.js unpatched.
    // NOTE(review): the declaration only goes into the physical (Standard/Physical) pars chunk; confirm no other lit
    // material that includes lights_fragment_begin is used with point lights, or it will reference an undeclared uniform.
    THREE.ShaderChunk.lights_physical_pars_fragment += `
    #if ( NUM_POINT_LIGHTS > 0 )
        uniform vec3 pointLightDirections[NUM_POINT_LIGHTS];
        #endif`
    THREE.ShaderChunk.lights_fragment_begin = patchedShader

    // NOTE(review): this is a prototype-level default; a material that assigns its own onBeforeCompile
    // shadows it and will not receive the uniform.
    THREE.Material.prototype.onBeforeCompile = function (shader: THREE.WebGLProgramParametersWithUniforms, renderer: THREE.WebGLRenderer) {
        shader.uniforms.pointLightDirections = {value: lightDirectionList}
    }
}

applyDirectionalPointLightShaderFixes()

/**
 * Refreshes the `pointLightDirections` uniform data (`lightDirectionList`) for the next render.
 *
 * Each DirectionalPointLight's `direction` is transformed into `threeCamera`'s view space, the space in
 * which three.js evaluates lighting. Presumably this must run before every render with the camera that
 * is used for that render.
 *
 * Lights are visited with the same filter and order three.js uses when it collects lights for rendering:
 * depth-first, skipping invisible subtrees, and counting only objects whose layers pass the camera's
 * layer test. This keeps index `i` here aligned with `pointLights[i]` in the shader. For example, an area
 * light switched off via `visible = false` is not uploaded by three.js and must not consume a slot here
 * either.
 * NOTE(review): three.js also sorts lights with shadow-casting ones first. This stays aligned only as long
 * as all DirectionalPointLights agree on `castShadow`.
 *
 * @throws Error if a rendered point light is not a DirectionalPointLight.
 */
export function updateLightDirectionsUniform(threeScene: THREE.Scene, threeCamera: THREE.Camera): void {
    let lightIdx = 0
    threeScene.traverseVisible((obj) => {
        // Not rendered by this camera; three.js skips the object itself but still visits its children.
        if (!obj.layers.test(threeCamera.layers)) return

        if (obj instanceof DirectionalPointLight) {
            // Reuse the previous call's vector instead of allocating a new one every frame.
            const direction = lightDirectionList[lightIdx] ?? new THREE.Vector3()
            lightDirectionList[lightIdx] = direction.copy(obj.direction).transformDirection(threeCamera.matrixWorldInverse)
            ++lightIdx
        } else if (obj instanceof THREE.PointLight) {
            throw Error("Only DirectionalPointLight supported")
        }
    })
    // Drop entries left over from lights that are no longer rendered.
    lightDirectionList.length = lightIdx
}

/**
 * Edit-mode visualization for a {@link ThreeAreaLight}.
 *
 * It is made of two groups:
 * - `localMesh` mirrors the light's transform. It holds the emitting rectangle (front and back planes),
 *   a "laser" line along local -z towards the target, and a grid of rays sketching the emission spread.
 * - `globalMesh` lives in the helper's own frame. It holds the target cube and a "stand" line dropped
 *   from the light's position down to y = 0.
 *
 * Pickable parts are put on layer 1 and tagged with `userData.threeSceneObject = light`, presumably so
 * picking/outlining resolves them to the light.
 */
class ThreeProxyLightHelper extends THREE.Object3D {
    private readonly light: ThreeAreaLight

    private globalMesh: THREE.Group
    localMesh: THREE.Group

    private targetCubeGeometry: THREE.BufferGeometry
    targetCube: THREE.Object3D
    // Opaque cube plus a faint copy without depth test, so the target stays visible behind geometry.
    private materialTargetCube1 = new THREE.MeshStandardMaterial({color: 0xff0000})
    private materialTargetCube2 = new THREE.MeshStandardMaterial({
        color: 0xff0000,
        transparent: true,
        opacity: 0.2,
        depthTest: false,
        depthWrite: false,
    })

    private bufferPositionStand: THREE.Float32BufferAttribute
    private geometryStand: THREE.BufferGeometry
    private linesStand: THREE.Line
    private materialStand = new THREE.LineBasicMaterial({color: 0xbbbbbb})

    private bufferPositionLaser: THREE.Float32BufferAttribute
    private geometryLaser: THREE.BufferGeometry
    private linesLaser: THREE.Line
    private materialLaser = new THREE.LineBasicMaterial({color: 0xee8888})

    private readonly numRays = 5 // total number is a grid numRays * numRays
    private bufferPositionRays: THREE.Float32BufferAttribute
    private geometryRays: THREE.BufferGeometry
    private linesRays: THREE.LineSegments
    private materialRays = new THREE.LineBasicMaterial({color: 0x666666})

    private frontMesh: THREE.Mesh
    private backMesh: THREE.Mesh

    private frontMaterial: THREE.MeshBasicMaterial
    private backMaterial: THREE.MeshBasicMaterial

    constructor(light: ThreeAreaLight) {
        super()

        this.light = light
        this.localMesh = new THREE.Group()
        this.globalMesh = new THREE.Group()

        // Front plane is flipped to face local -z, the direction the rays and laser point to.
        this.frontMaterial = new THREE.MeshBasicMaterial()
        this.frontMesh = new THREE.Mesh(new THREE.PlaneGeometry(), this.frontMaterial)
        this.frontMesh.rotation.x = Math.PI
        this.frontMesh.userData.threeSceneObject = light
        this.frontMesh.layers.set(1)

        this.backMaterial = new THREE.MeshBasicMaterial({color: 0x080808})
        this.backMesh = new THREE.Mesh(new THREE.PlaneGeometry(), this.backMaterial)
        this.backMesh.userData.threeSceneObject = light
        this.backMesh.layers.set(1)

        this.targetCubeGeometry = new THREE.BoxGeometry(5, 5, 5)
        const targetMesh1 = new THREE.Mesh(this.targetCubeGeometry, this.materialTargetCube1)
        const targetMesh2 = new THREE.Mesh(this.targetCubeGeometry, this.materialTargetCube2)
        this.targetCube = new THREE.Group()
        this.targetCube.add(targetMesh1)
        this.targetCube.add(targetMesh2)
        targetMesh1.userData.threeSceneObject = light
        targetMesh1.userData.materialSlot = 1
        targetMesh1.layers.set(1)
        targetMesh2.userData.threeSceneObject = light
        targetMesh2.userData.materialSlot = 1
        targetMesh2.layers.set(1)
        this.globalMesh.add(this.targetCube)

        // Two vertices (6 floats): ground point and light position.
        this.geometryStand = new THREE.BufferGeometry()
        this.bufferPositionStand = new THREE.Float32BufferAttribute(6, 3)
        this.geometryStand.setAttribute("position", this.bufferPositionStand)
        this.linesStand = new THREE.Line(this.geometryStand, this.materialStand)
        this.linesStand.layers.set(1)
        this.linesStand.userData.threeSceneObject = light
        this.globalMesh.add(this.linesStand)

        // Two vertices: local origin and a point on local -z at the target's distance.
        this.geometryLaser = new THREE.BufferGeometry()
        this.bufferPositionLaser = new THREE.Float32BufferAttribute(6, 3)
        this.geometryLaser.setAttribute("position", this.bufferPositionLaser)
        this.linesLaser = new THREE.Line(this.geometryLaser, this.materialLaser)
        this.linesLaser.layers.set(1)
        this.linesLaser.userData.threeSceneObject = light
        this.localMesh.add(this.linesLaser)

        // One segment (two vertices, 3 floats each) per ray of the numRays x numRays grid.
        this.geometryRays = new THREE.BufferGeometry()
        this.bufferPositionRays = new THREE.Float32BufferAttribute(2 * 3 * this.numRays * this.numRays, 3)
        this.geometryRays.setAttribute("position", this.bufferPositionRays)
        this.linesRays = new THREE.LineSegments(this.geometryRays, this.materialRays)
        this.linesRays.userData.threeSceneObject = light
        this.linesRays.layers.set(1)
        this.localMesh.add(this.linesRays)

        this.globalMesh.layers.set(1)

        this.add(this.globalMesh)
        this.add(this.localMesh)
        this.localMesh.add(this.frontMesh)
        this.localMesh.add(this.backMesh)

        this.showEditHelpers(false)
        this.update()
    }

    /** Re-syncs the helper with the light's transform, target, size, directionality and intensity. */
    update(): void {
        // Copy instead of aliasing the light's matrix object. When matrixAutoUpdate is on (three.js default),
        // localMesh.updateMatrix() recomposes localMesh.matrix from position/quaternion/scale, and with an
        // alias it would write straight into the light group's own matrix.
        this.localMesh.matrix.copy(this.light.threeObject.matrix)
        this.localMesh.matrix.decompose(this.localMesh.position, this.localMesh.quaternion, this.localMesh.scale)
        this.localMesh.matrixWorldNeedsUpdate = true

        const position = this.localMesh.position
        this.bufferPositionStand.set([position.x, 0, position.z, position.x, position.y, position.z])
        this.bufferPositionStand.needsUpdate = true

        this.targetCube.matrix.makeTranslation(this.light._target.x, this.light._target.y, this.light._target.z)
        this.targetCube.matrix.decompose(this.targetCube.position, this.targetCube.quaternion, this.targetCube.scale)
        this.targetCube.matrixWorldNeedsUpdate = true

        // The laser's far end sits on local -z at the distance between light and target.
        this.bufferPositionLaser.setZ(1, -toThreeVector(this.light._target).sub(position).length())
        this.bufferPositionLaser.needsUpdate = true

        this.updateRays()
    }

    /**
     * Rebuilds the ray grid in the light's local frame.
     *
     * Ray origins form a numRays x numRays grid spanning 90% of the rectangle. Every ray has length
     * 2 * intensity and points towards local -z. The center ray is straight. Other rays tilt outwards
     * more the closer they start to the rectangle's edge. directionality = 1 keeps all rays parallel,
     * and directionality = 0 tilts edge rays fully sideways.
     */
    private updateRays(): void {
        const directionality = (Math.min(1.0, Math.max(0.0, this.light._directionality)) * Math.PI) / 2.0
        const edgefxy = Math.cos(directionality)
        const edgefz = Math.sin(directionality)
        const centerfz = 1.0
        const width = this.light._width
        const height = this.light._height
        const n = this.numRays
        // Only an odd grid has a ray exactly at the rectangle's center.
        const centerRayIdx = n % 2 === 1 ? Math.floor(n / 2) : -1

        for (let ix = 0; ix < n; ix++) {
            for (let iy = 0; iy < n; iy++) {
                const tx = (ix / (n - 1) - 0.5) * width * 0.9
                const ty = (iy / (n - 1) - 0.5) * height * 0.9

                let fx = 0.0
                let fy = 0.0
                let fz = 1.0

                if (!(iy === centerRayIdx && ix === centerRayIdx)) {
                    const norm = Math.sqrt(tx ** 2 + ty ** 2)
                    // Equivalent to max(2|tx|/width, 2|ty|/height): 0 at the center, 1 on the rectangle's border.
                    const scaling = (width / 2.0 / Math.abs(tx)) * Math.abs(ty) < height / 2.0 ? (2 * Math.abs(tx)) / width : (2.0 * Math.abs(ty)) / height
                    fx = (tx / norm) * scaling * edgefxy
                    fy = (ty / norm) * scaling * edgefxy
                    fz = scaling * edgefz + (1.0 - scaling) * centerfz
                }

                // Dividing by fac normalizes the direction and sets the ray length to 2 * intensity.
                const fac = Math.sqrt(fx ** 2 + fy ** 2 + fz ** 2) / (2 * this.light._intensity)
                const vertex = 2 * (iy * n + ix)
                this.bufferPositionRays.setXYZ(vertex, tx, ty, 0)
                this.bufferPositionRays.setXYZ(vertex + 1, tx + fx / fac, ty + fy / fac, -fz / fac)
            }
        }

        this.bufferPositionRays.needsUpdate = true
    }

    /** Tints the front plane with the light's color; both planes are rendered opaque. */
    updateMaterial() {
        this.frontMaterial.opacity = 1.0
        this.frontMaterial.transparent = false
        this.backMaterial.opacity = 1.0
        this.backMaterial.transparent = false

        const color = this.light._color
        this.frontMaterial.color.setRGB(color[0], color[1], color[2])
    }

    /** Resizes the front and back planes to the light's width and height. */
    updateMesh() {
        //TODO: coalesce updates
        const width = this.light._width
        const height = this.light._height
        this.frontMesh.scale.x = width
        this.frontMesh.scale.y = height
        this.backMesh.scale.x = width
        this.backMesh.scale.y = height
        this.frontMesh.updateMatrix()
        this.backMesh.updateMatrix()
    }

    /** Shows or hides the target cube and the laser line; the planes, rays and stand stay visible. */
    showEditHelpers(enable: boolean) {
        this.targetCube.visible = enable
        this.linesLaser.visible = enable
    }
}
