import {assertNever, castToFloat32Array, castToUint16Array, castToUint8Array, deepCopy, floatToHalfArray, halfToFloatArray} from "@cm/utils"
import {HalImage} from "@common/models/hal/hal-image"
import {
    HalImageChannelLayout,
    HalImageDataType,
    HalImageDescriptor,
    HalImageHTMLCanvasElement,
    HalImageHTMLImageElement,
    HalImageOptions,
    HalImageRawDataBufferType,
    HalImageRawDataType,
    HalImageSource,
    HalImageUrl,
    SomeHalImageRawDataBuffer,
} from "@common/models/hal/hal-image/types"
import {WebGl2Context} from "@common/models/webgl2/webgl2-context"
import {isHalImageDescriptor, isHalImageHTMLCanvasElement, isHalImageHTMLImageElement, isHalImageUrl} from "@common/helpers/hal"
import {getNumChannels} from "@common/models/hal/hal-image/utils"
import {linearToSrgb, srgbToLinear} from "@cm/image-processing/tone-mapping"

// When true, texture creation parameters and GPU upload/creation timings are logged to the console.
const TRACE = false

// Selects how loadFromImage extracts the pixel data of an HTML image:
// - true:  the image is drawn once into a canvas covering the whole shard grid and every shard is read back from it
// - false: a canvas of the size of a single shard is used and the image is re-drawn for every shard
// Observations that led to the current default:
// - setting this to true fails on firefox for a 12000^2 image (but works on chrome)
// - when setting this to true and using a 22000^2 image, it fails on both firefox and chrome
// - on chrome for a 12000^2 image, it takes 6.2s (false) and 1.8s (true) to upload the image to the gpu
const USE_FASTER_BUT_LARGER_CANVAS = false

/**
 * Fills in the defaults for all image options that were not specified.
 *
 * @param options Partially specified options (or nothing at all).
 * @returns A complete options object; every flag that was not provided is `false`.
 */
export function completeHalImageOptions(options?: Partial<HalImageOptions>): HalImageOptions {
    const provided: Partial<HalImageOptions> = options ?? {}
    const completed: HalImageOptions = {
        useMipMaps: provided.useMipMaps ?? false,
        useSRgbFormat: provided.useSRgbFormat ?? false,
    }
    return completed
}

/**
 * WebGL2 implementation of {@link HalImage}.
 *
 * The pixel data lives in a single `TEXTURE_2D_ARRAY`. Images that exceed the maximum texture size are split into a
 * grid of equally sized "shards"; shard `(sx, sy)` is stored in array layer `sy * numShardsX + sx`. Because the shard
 * size is rounded up, the shards of the last column/row may contain padding texels that lie outside of the image.
 */
export class WebGl2Image implements HalImage {
    /**
     * @param context The WebGL2 context wrapper used for all GL calls.
     * @param source Optional image source. When given, {@link create} is started immediately but NOT awaited, so a
     *     URL source is not loaded yet when the constructor returns and failures surface as unhandled rejections.
     *     Call {@link create} explicitly when completion or errors need to be observed.
     */
    constructor(
        readonly context: WebGl2Context,
        source?: HalImageSource,
    ) {
        if (source) {
            void this.create(source)
        }
    }

    // HalImage
    /**
     * (Re-)creates the texture from an HTML image, an HTML canvas, a descriptor (storage only) or a URL.
     *
     * @throws Error if `source` is none of the supported source kinds, or if loading/allocation fails.
     */
    async create(source: HalImageSource) {
        if (isHalImageHTMLImageElement(source)) {
            this.loadFromImage(source)
        } else if (isHalImageHTMLCanvasElement(source)) {
            this.loadFromCanvas(source)
        } else if (isHalImageDescriptor(source)) {
            this.createTexture(source)
        } else if (isHalImageUrl(source)) {
            await this.loadFromUrl(source)
        } else {
            throw Error("Invalid image reference")
        }
    }

    // HalEntity
    /** Deletes the GL texture and resets the image to an empty 0x0 state. */
    dispose(): void {
        this.releaseTexture()
    }

    // HalImage
    /** Descriptor of the currently allocated image (0x0 if nothing is allocated). */
    get descriptor(): HalImageDescriptor {
        return this._descriptor
    }

    /**
     * The underlying texture array.
     *
     * @throws Error if no texture is allocated (nothing loaded yet, or the image is empty).
     */
    get texture(): WebGLTexture {
        if (!this._texture) {
            throw Error("Texture is null. Did you forget to load the image ?")
        }
        return this._texture
    }

    /** Width of a single shard (array layer) in texels. */
    get shardWidth(): number {
        return this._shardWidth
    }

    /** Height of a single shard (array layer) in texels. */
    get shardHeight(): number {
        return this._shardHeight
    }

    /** Number of shard columns. */
    get numShardsX(): number {
        return this._numShardsX
    }

    /** Number of shard rows. */
    get numShardsY(): number {
        return this._numShardsY
    }

    /** Total number of shards, i.e. the number of layers of the texture array. */
    get numShards(): number {
        return this._numShardsX * this._numShardsY
    }

    // HalImage
    /**
     * Reads the whole image back from the GPU into one tightly packed, row-major buffer.
     *
     * @param rawDataType Element type of the returned buffer; the data is converted from the texture's data type.
     * @returns A buffer holding `width * height * numChannels` elements.
     * @throws Error if reading a shard raises a GL error.
     */
    async readRawImageData<T extends HalImageRawDataType>(rawDataType: T): Promise<HalImageRawDataBufferType<T>> {
        const {width, height, channelLayout, dataType} = this.descriptor
        const numChannels = getNumChannels(channelLayout)
        const rawImageData = this.createBufferForFormat(dataType, width * height * numChannels)
        const format = this.getGlFormat()
        const type = this.getGlType()
        const gl = this.context.gl
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._texture)
        // remember the current binding so that the temporary framebuffer does not leak into the caller's GL state
        const previousReadFramebuffer: ReturnType<typeof gl.createFramebuffer> = gl.getParameter(gl.READ_FRAMEBUFFER_BINDING)
        const readFbo = gl.createFramebuffer()
        gl.bindFramebuffer(gl.READ_FRAMEBUFFER, readFbo)
        try {
            gl.pixelStorei(gl.PACK_ALIGNMENT, 1) // make sure to tightly pack the data
            // with a single shard the pixels can be read directly into the result buffer
            const shardImageData =
                this.numShards === 1 ? rawImageData : this.createBufferForFormat(dataType, this._shardWidth * this._shardHeight * numChannels)
            for (let sy = 0; sy < this._numShardsY; sy++) {
                for (let sx = 0; sx < this._numShardsX; sx++) {
                    const shardIndex = sy * this._numShardsX + sx
                    gl.framebufferTextureLayer(gl.READ_FRAMEBUFFER, gl.COLOR_ATTACHMENT0, this._texture, 0, shardIndex)
                    gl.readPixels(0, 0, this.shardWidth, this.shardHeight, format, type, shardImageData)
                    const lastError = gl.getError()
                    if (lastError !== gl.NO_ERROR) {
                        throw Error(`Failed to read pixels (${lastError}).`)
                    }
                    if (shardImageData !== rawImageData) {
                        // copy to final image data; only the part of the shard that lies inside the image is copied,
                        // otherwise padding texels of the last shard column would overwrite the start of the next row
                        const originX = sx * this._shardWidth
                        const originY = sy * this._shardHeight
                        const validWidth = Math.min(this._shardWidth, width - originX)
                        const validHeight = Math.min(this._shardHeight, height - originY)
                        for (let y = 0; y < validHeight; y++) {
                            for (let x = 0; x < validWidth; x++) {
                                const shardPixelIndex = y * this._shardWidth + x
                                const rawPixelIndex = (originY + y) * width + originX + x
                                for (let channel = 0; channel < numChannels; channel++) {
                                    rawImageData[rawPixelIndex * numChannels + channel] = shardImageData[shardPixelIndex * numChannels + channel]
                                }
                            }
                        }
                    }
                }
            }
        } finally {
            // always restore the binding and free the temporary framebuffer, also when reading failed
            gl.bindFramebuffer(gl.READ_FRAMEBUFFER, previousReadFramebuffer)
            gl.deleteFramebuffer(readFbo)
        }
        const isSrgbFormat = this.descriptor.options?.useSRgbFormat ?? false
        return this.convertRawData(dataType, isSrgbFormat, rawDataType, isSrgbFormat, rawImageData) as HalImageRawDataBufferType<T>
    }

    // HalImage
    /**
     * Uploads a tightly packed, row-major buffer covering the whole image and regenerates the mipmaps.
     *
     * Padding texels of the last shard column/row (which lie outside of the image) are filled by wrapping around to
     * the opposite image edge, so every source index stays inside `rawImageData`.
     *
     * @param rawDataType Element type of `rawImageData`; the data is converted to the texture's data type.
     * @param rawImageData Buffer holding `width * height * numChannels` elements.
     */
    async writeRawImageData<T extends HalImageRawDataType>(rawDataType: T, rawImageData: HalImageRawDataBufferType<T>): Promise<void> {
        const isSrgbFormat = this.descriptor.options?.useSRgbFormat ?? false
        const convertedImageData = this.convertRawData(
            rawDataType,
            isSrgbFormat,
            this.descriptor.dataType,
            isSrgbFormat,
            rawImageData,
        ) as HalImageRawDataBufferType<T>
        const gl = this.context.gl
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._texture)
        const format = this.getGlFormat()
        const type = this.getGlType()
        gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1) // make sure to tightly pack the data
        const {width, height} = this.descriptor
        const numChannels = getNumChannels(this.descriptor.channelLayout)
        // with a single shard the converted data can be uploaded as is
        const shardImageData =
            this.numShards === 1 ? convertedImageData : this.createBufferForFormat(this.descriptor.dataType, this._shardWidth * this._shardHeight * numChannels)
        for (let sy = 0; sy < this._numShardsY; sy++) {
            for (let sx = 0; sx < this._numShardsX; sx++) {
                const shardIndex = sy * this._numShardsX + sx
                if (shardImageData !== convertedImageData) {
                    // copy to shard image data
                    for (let y = 0; y < this._shardHeight; y++) {
                        // coordinates beyond the image (padding) wrap around to the start of the image
                        const sourceY = (sy * this._shardHeight + y) % height
                        for (let x = 0; x < this._shardWidth; x++) {
                            const sourceX = (sx * this._shardWidth + x) % width
                            const shardPixelIndex = y * this._shardWidth + x
                            const rawPixelIndex = sourceY * width + sourceX
                            for (let channel = 0; channel < numChannels; channel++) {
                                shardImageData[shardPixelIndex * numChannels + channel] = convertedImageData[rawPixelIndex * numChannels + channel]
                            }
                        }
                    }
                }
                gl.texSubImage3D(gl.TEXTURE_2D_ARRAY, 0, 0, 0, shardIndex, this._shardWidth, this._shardHeight, 1, format, type, shardImageData)
            }
        }
        this.generateMipmaps()
    }

    /**
     * Converts a raw pixel buffer between element types.
     *
     * @param from Element type of `data`.
     * @param fromSrgb Whether `data` is sRGB encoded; only consulted when `from` is "uint8".
     * @param to Requested element type.
     * @param toSrgb Whether the result should be sRGB encoded; only consulted when `to` is "uint8".
     * @param data The buffer to convert.
     * @returns The converted buffer; for identical types the input is passed through the matching cast helper.
     * @throws Error for uint8 -> uint8 with differing sRGB flags (not implemented) and for unexpected types.
     */
    private convertRawData(from: HalImageDataType, fromSrgb: boolean, to: HalImageDataType, toSrgb: boolean, data: SomeHalImageRawDataBuffer) {
        switch (from) {
            case "uint8":
                switch (to) {
                    case "uint8":
                        if (fromSrgb !== toSrgb) {
                            throw new Error("Cannot convert between sRGB and linear formats. Not implemented yet.")
                        }
                        return castToUint8Array(data)
                    case "float16":
                        return floatToHalfArray(this.uint8ToFloatArray(data as Uint8Array, fromSrgb))
                    case "float32":
                        return this.uint8ToFloatArray(data as Uint8Array, fromSrgb)
                    default:
                        throw new Error(`Unexpected "to" for uint8 conversion: ${to}`)
                }
            case "float16":
                switch (to) {
                    case "uint8":
                        return this.floatToUint8Array(halfToFloatArray(data as Uint16Array), toSrgb)
                    case "float16":
                        return castToUint16Array(data)
                    case "float32":
                        return halfToFloatArray(data)
                    default:
                        throw new Error(`Unexpected "to" for float16 conversion: ${to}`)
                }
            case "float32":
                switch (to) {
                    case "uint8":
                        return this.floatToUint8Array(data as Float32Array, toSrgb)
                    case "float16":
                        return floatToHalfArray(data)
                    case "float32":
                        return castToFloat32Array(data)
                    default:
                        throw new Error(`Unexpected "to" for float32 conversion: ${to}`)
                }
            default:
                assertNever(from)
        }
    }

    /**
     * Quantizes float values to bytes: the value is scaled by 255, rounded and clamped to [0, 255].
     * When `isSrgbFormat` is set, `linearToSrgb` is applied to every value before quantization.
     */
    private floatToUint8Array(data: Float32Array, isSrgbFormat: boolean): Uint8Array {
        const uint8Data = new Uint8Array(data.length)
        for (let i = 0; i < data.length; i++) {
            let value = data[i]
            if (isSrgbFormat) {
                value = linearToSrgb(value)
            }
            uint8Data[i] = Math.max(0, Math.min(255, Math.round(value * 255)))
        }
        return uint8Data
    }

    /**
     * Expands bytes to floats by dividing by 255.
     * When `isSrgbFormat` is set, `srgbToLinear` is applied to every normalized value.
     */
    private uint8ToFloatArray(data: Uint8Array, isSrgbFormat: boolean): Float32Array {
        const floatData = new Float32Array(data.length)
        for (let i = 0; i < data.length; i++) {
            let value = data[i] / 255
            if (isSrgbFormat) {
                value = srgbToLinear(value)
            }
            floatData[i] = value
        }
        return floatData
    }

    /**
     * Regenerates the mipmap chain; does nothing when the texture was created without mipmaps.
     * The texture is bound to TEXTURE_2D_ARRAY by this method (and stays bound afterwards).
     */
    generateMipmaps() {
        if (this._numLevels > 1) {
            const gl = this.context.gl
            gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._texture)
            gl.generateMipmap(gl.TEXTURE_2D_ARRAY)
        }
    }

    /**
     * Loads the image behind `source.url` (with `crossOrigin = "anonymous"`) and uploads it to the GPU.
     *
     * @returns A promise that resolves once the upload is done and rejects if the image cannot be loaded or uploaded.
     */
    private async loadFromUrl(source: HalImageUrl): Promise<void> {
        const htmlImageElement = new Image()
        htmlImageElement.crossOrigin = "anonymous"
        return new Promise((resolve, reject) => {
            htmlImageElement.onload = () => {
                // an exception thrown inside an event handler would never reach the awaiting caller, so forward it
                try {
                    this.loadFromImage({htmlImageElement, options: source.options})
                    resolve()
                } catch (error) {
                    reject(error)
                }
            }
            htmlImageElement.onerror = (error) => {
                // the error event carries no useful description, so report the url unless a string reason is given
                const reason = typeof error === "string" ? error : source.url
                reject(Error("Failed to load image: " + reason))
            }
            htmlImageElement.src = source.url
        })
    }

    /**
     * Allocates an RGBA/uint8 texture matching the size of the HTML image and uploads its pixels shard by shard.
     * A 2d canvas is used to obtain the pixel data.
     *
     * @throws Error if no 2d canvas context can be obtained or the texture cannot be created.
     */
    private loadFromImage(source: HalImageHTMLImageElement) {
        // create texture
        const descriptor: HalImageDescriptor = {
            width: source.htmlImageElement.width,
            height: source.htmlImageElement.height,
            channelLayout: "RGBA",
            dataType: "uint8",
            options: source.options,
        }
        this.createTexture(descriptor)

        const start = performance.now()

        // use canvas to get the pixel data array of the image
        const canvas = document.createElement("canvas")
        if (USE_FASTER_BUT_LARGER_CANVAS) {
            canvas.width = this._shardWidth * this._numShardsX
            canvas.height = this._shardHeight * this._numShardsY
        } else {
            canvas.width = this._shardWidth
            canvas.height = this._shardHeight
        }
        const ctx = canvas.getContext("2d", {willReadFrequently: true})
        if (!ctx) {
            throw Error("Failed to get 2d context from canvas.")
        }
        if (USE_FASTER_BUT_LARGER_CANVAS) {
            ctx.drawImage(source.htmlImageElement, 0, 0)
            // TODO consider copying the image next to each other (in both dimensions) to avoid interpolation issues when sampling linearly
        }

        // copy to texture
        const gl = this.context.gl
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._texture)
        gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1) // make sure to tightly pack the data
        for (let sy = 0; sy < this._numShardsY; sy++) {
            for (let sx = 0; sx < this._numShardsX; sx++) {
                const shardIndex = sy * this._numShardsX + sx
                let imageData: ImageData
                if (USE_FASTER_BUT_LARGER_CANVAS) {
                    imageData = ctx.getImageData(sx * this._shardWidth, sy * this._shardHeight, this._shardWidth, this._shardHeight)
                } else {
                    // The canvas is reused for every shard: clear it first, otherwise (semi-)transparent pixels would be
                    // blended over the previous shard and uncovered padding would keep the previous shard's pixels.
                    ctx.clearRect(0, 0, this._shardWidth, this._shardHeight)
                    ctx.drawImage(
                        source.htmlImageElement,
                        sx * this._shardWidth,
                        sy * this._shardHeight,
                        this._shardWidth,
                        this._shardHeight,
                        0,
                        0,
                        this._shardWidth,
                        this._shardHeight,
                    )
                    // TODO for the last shards (sx === numShardsX - 1 || sy === numShardsY - 1), copy the start of the image
                    // to the end of the shard to avoid interpolation issues when sampling linearly
                    imageData = ctx.getImageData(0, 0, this._shardWidth, this._shardHeight)
                }
                gl.texSubImage3D(gl.TEXTURE_2D_ARRAY, 0, 0, 0, shardIndex, this._shardWidth, this._shardHeight, 1, gl.RGBA, gl.UNSIGNED_BYTE, imageData.data)
            }
        }
        this.generateMipmaps()

        const end = performance.now()
        if (TRACE) {
            console.log("Uploaded image to GPU in " + (end - start) + "ms")
        }
    }

    /**
     * Allocates an RGBA/uint8 texture matching the size of the HTML canvas and uploads its pixels shard by shard.
     *
     * @throws Error if the canvas does not provide a 2d context or the texture cannot be created.
     */
    private loadFromCanvas(source: HalImageHTMLCanvasElement): void {
        // create texture
        const descriptor: HalImageDescriptor = {
            width: source.htmlCanvasElement.width,
            height: source.htmlCanvasElement.height,
            channelLayout: "RGBA",
            dataType: "uint8",
            options: source.options,
        }
        this.createTexture(descriptor)

        const start = performance.now()

        const ctx = source.htmlCanvasElement.getContext("2d", {willReadFrequently: true})
        if (!ctx) {
            throw Error("Failed to get 2d context from canvas.")
        }

        // copy to texture
        const gl = this.context.gl
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._texture)
        gl.pixelStorei(gl.UNPACK_ALIGNMENT, 1) // make sure to tightly pack the data
        for (let sy = 0; sy < this._numShardsY; sy++) {
            for (let sx = 0; sx < this._numShardsX; sx++) {
                const shardIndex = sy * this._numShardsX + sx
                const imageData = ctx.getImageData(sx * this._shardWidth, sy * this._shardHeight, this._shardWidth, this._shardHeight)
                gl.texSubImage3D(gl.TEXTURE_2D_ARRAY, 0, 0, 0, shardIndex, this._shardWidth, this._shardHeight, 1, gl.RGBA, gl.UNSIGNED_BYTE, imageData.data)
            }
        }
        this.generateMipmaps()

        const end = performance.now()
        if (TRACE) {
            console.log("Uploaded image to GPU in " + (end - start) + "ms")
        }
    }

    /**
     * Makes sure a texture matching `descriptor` is allocated.
     *
     * Nothing happens if the current texture already matches. Otherwise the current texture is released and a new one
     * is allocated; a descriptor with a zero width or height only releases the texture. Invalid dimensions are
     * rejected before the current texture is touched, and a failed allocation leaves the image in the released state.
     *
     * @throws Error if the dimensions are negative or not integers, or if the allocation fails.
     * @throws OutOfMemoryError if the GPU reports an out-of-memory condition while allocating the storage.
     */
    private createTexture(descriptor: HalImageDescriptor): void {
        const descriptorOptions = completeHalImageOptions(descriptor.options)
        const currentOptions = completeHalImageOptions(this._descriptor.options)
        if (
            descriptor.width === this.descriptor.width &&
            descriptor.height === this.descriptor.height &&
            descriptor.channelLayout === this.descriptor.channelLayout &&
            descriptor.dataType === this.descriptor.dataType &&
            descriptorOptions.useSRgbFormat === currentOptions.useSRgbFormat &&
            descriptorOptions.useMipMaps === currentOptions.useMipMaps
        ) {
            return
        }

        // validate first, so that an invalid request does not destroy the existing texture
        if (descriptor.width < 0 || descriptor.height < 0) {
            throw Error("Image dimensions must be positive.")
        }
        if (!Number.isInteger(descriptor.width) || !Number.isInteger(descriptor.height)) {
            throw Error("Image dimensions must be integers.")
        }

        this.releaseTexture()

        this._descriptor = deepCopy(descriptor)
        if (this._descriptor.width === 0 || this._descriptor.height === 0) {
            return
        }

        try {
            this.allocateTexture(descriptorOptions)
        } catch (error) {
            // do not keep a half-initialized texture or a descriptor without storage around
            this.releaseTexture()
            throw error
        }
    }

    /**
     * Computes the shard layout for the current (non-empty) descriptor and allocates the texture array storage.
     * If the requested float type cannot be used but the other float type can, the descriptor's data type is changed
     * to the supported one.
     */
    private allocateTexture(options: HalImageOptions): void {
        switch (this._descriptor.dataType) {
            case "float16":
                if (!this.context.EXT_color_buffer_half_float) {
                    if (this.context.EXT_color_buffer_float) {
                        console.warn("Device does not support float16 format. Falling back to float32.")
                        this._descriptor.dataType = "float32"
                    } else {
                        throw new Error("Float format not supported by device")
                    }
                }
                break
            case "float32":
                if (!this.context.EXT_color_buffer_float) {
                    if (this.context.EXT_color_buffer_half_float) {
                        console.warn("Device does not support float32 format. Falling back to float16.")
                        this._descriptor.dataType = "float16"
                    } else {
                        throw new Error("Float format not supported by device")
                    }
                }
                break
        }

        const start = performance.now()

        this._shardWidth = this.computeOptimalShardSize(this._descriptor.width)
        this._shardHeight = this.computeOptimalShardSize(this._descriptor.height)
        this._numLevels = options.useMipMaps ? Math.floor(Math.log2(Math.max(this._shardWidth, this._shardHeight))) + 1 : 1
        this._numShardsX = Math.ceil(this._descriptor.width / this._shardWidth)
        this._numShardsY = Math.ceil(this._descriptor.height / this._shardHeight)
        const numShards = this._numShardsX * this._numShardsY
        // every shard occupies one array layer, so the limit applies to the whole grid and not only to each axis
        if (numShards > this.context.maxTextureLayers) {
            throw Error(`Image would require more shards (${numShards}) than the GPU allows (${this.context.maxTextureLayers}).`)
        }

        if (TRACE) {
            console.log(
                `Creating sharded WebGL texture with ${numShards} shards of size ${this._shardWidth}x${this._shardHeight} containing ${this._numLevels} mipmap levels for image of size ${this._descriptor.width}x${this._descriptor.height}.`,
            )
        }

        const gl = this.context.gl
        const texture = gl.createTexture()
        if (!texture) {
            throw Error("Failed to create texture object.")
        }
        this._texture = texture
        gl.bindTexture(gl.TEXTURE_2D_ARRAY, this._texture)
        gl.texStorage3D(
            gl.TEXTURE_2D_ARRAY,
            this._numLevels,
            this.getInternalFormat(this._descriptor.channelLayout, this._descriptor.dataType, options.useSRgbFormat),
            this._shardWidth,
            this._shardHeight,
            numShards,
        )
        const lastError = gl.getError()
        if (lastError !== gl.NO_ERROR) {
            if (lastError === gl.OUT_OF_MEMORY) {
                throw new OutOfMemoryError("Failed to create texture: Out of memory.")
            } else {
                throw Error(`Failed to create texture (${lastError}).`)
            }
        }

        const end = performance.now()
        if (TRACE) {
            console.log("Created GPU image in " + (end - start) + "ms")
        }
    }

    /** GL pixel format constant matching the channel layout of the current descriptor. */
    private getGlFormat(): number {
        const gl = this.context.gl
        switch (this.descriptor.channelLayout) {
            case "RGBA":
                return gl.RGBA
            case "RGB":
                return gl.RGB
            case "R":
                return gl.RED
            default:
                assertNever(this.descriptor.channelLayout)
        }
    }

    /** GL pixel type constant matching the data type of the current descriptor. */
    private getGlType(): number {
        const gl = this.context.gl
        switch (this.descriptor.dataType) {
            case "uint8":
                return gl.UNSIGNED_BYTE
            case "float16":
                return gl.HALF_FLOAT
            case "float32":
                return gl.FLOAT
            default:
                assertNever(this.descriptor.dataType)
        }
    }

    /** Allocates a zero-initialized typed array for `numElements` values of a texture data type (float16 uses Uint16Array). */
    private createBufferForFormat<T extends HalImageDataType>(rawDataType: T, numElements: number): FormatBufferType<T> {
        switch (rawDataType) {
            case "uint8":
                return new Uint8Array(numElements) as FormatBufferType<T> // TODO Why is the type assertion necessary here?
            case "float16":
                return new Uint16Array(numElements) as FormatBufferType<T> // TODO Why is the type assertion necessary here?
            case "float32":
                return new Float32Array(numElements) as FormatBufferType<T> // TODO Why is the type assertion necessary here?
            default:
                assertNever(rawDataType)
        }
    }

    /** Allocates a zero-initialized typed array for `numElements` values of a raw data type (float16 uses Uint16Array). */
    private createBufferForRawData<T extends HalImageRawDataType>(rawDataType: T, numElements: number): HalImageRawDataBufferType<T> {
        switch (rawDataType) {
            case "uint8":
                return new Uint8Array(numElements) as HalImageRawDataBufferType<T> // TODO Why is the type assertion necessary here?
            case "float16":
                return new Uint16Array(numElements) as HalImageRawDataBufferType<T> // TODO Why is the type assertion necessary here?
            case "float32":
                return new Float32Array(numElements) as HalImageRawDataBufferType<T> // TODO Why is the type assertion necessary here?
            default:
                assertNever(rawDataType)
        }
    }

    /**
     * Computes the shard size along one axis: the axis is split into the minimum number of shards that respect the
     * maximum texture size, and the length is distributed evenly among them (rounded up).
     *
     * @param length Image extent along the axis; must be greater than zero.
     * @throws Error if the axis alone would need more shards than there are texture array layers.
     */
    private computeOptimalShardSize(length: number): number {
        // we want as little shards as possible, but we also don't want to waste too much unused space in the last shard
        const maxShardSize = this.context.maxTextureSize
        const maxNumShards = this.context.maxTextureLayers
        const numRequiredShards = Math.ceil(length / maxShardSize)
        if (numRequiredShards > maxNumShards) {
            throw Error(`Image would require more shards (${numRequiredShards}) than the GPU allows (${maxNumShards}).`)
        }
        return Math.ceil(length / numRequiredShards)
    }

    /**
     * Maps channel layout, data type and sRGB flag to the sized GL internal format.
     * The sRGB flag is only honored for uint8 data.
     *
     * @throws Error for sRGB with the "R" layout and for unexpected layouts/data types.
     */
    private getInternalFormat(channelLayout: HalImageChannelLayout, format: HalImageDataType, sRGB: boolean): number {
        const gl = this.context.gl
        switch (format) {
            case "uint8":
                switch (channelLayout) {
                    case "RGBA":
                        return sRGB ? gl.SRGB8_ALPHA8 : gl.RGBA8
                    case "RGB":
                        return sRGB ? gl.SRGB8 : gl.RGB8
                    case "R":
                        if (sRGB) {
                            throw Error("sRGB is not supported for R layout.")
                        }
                        return gl.R8
                    default:
                        throw new Error(`Unexpected channel layout for uint8: ${channelLayout}`)
                }
            case "float32":
                switch (channelLayout) {
                    case "RGBA":
                        return gl.RGBA32F
                    case "RGB":
                        return gl.RGB32F
                    case "R":
                        return gl.R32F
                    default:
                        throw new Error(`Unexpected channel layout for float32: ${channelLayout}`)
                }
            case "float16":
                switch (channelLayout) {
                    case "RGBA":
                        return gl.RGBA16F
                    case "RGB":
                        return gl.RGB16F
                    case "R":
                        return gl.R16F
                    default:
                        throw new Error(`Unexpected channel layout for float16: ${channelLayout}`)
                }
            default:
                throw new Error(`Unexpected format: ${format}`)
        }
    }

    /** Deletes the GL texture (if any) and resets descriptor and shard layout to the empty 0x0 state. */
    private releaseTexture() {
        if (this._texture) {
            const gl = this.context.gl
            gl.deleteTexture(this._texture)
            this._texture = null
        }
        this._descriptor = {
            width: 0,
            height: 0,
            channelLayout: "RGBA",
            dataType: "uint8",
        }
        this._numLevels = 0
        this._shardWidth = 0
        this._shardHeight = 0
        this._numShardsX = 0
        this._numShardsY = 0
    }

    // descriptor of the allocated texture; 0x0 while nothing is allocated
    private _descriptor: HalImageDescriptor = {
        width: 0,
        height: 0,
        channelLayout: "RGBA",
        dataType: "uint8",
    }
    // number of mipmap levels of the texture (1 = no mipmaps, 0 = no texture)
    private _numLevels = 0
    // size of one shard (array layer) in texels
    private _shardWidth = 0
    private _shardHeight = 0
    // dimensions of the shard grid
    private _numShardsX = 0
    private _numShardsY = 0
    private _texture: WebGLTexture | null = null
}

/** Typed array used to hold pixel data of texture data type `T`: uint8 -> Uint8Array, float16 -> Uint16Array, otherwise Float32Array. */
export type FormatBufferType<T extends HalImageDataType> = T extends "uint8" ? Uint8Array : T extends "float16" ? Uint16Array : Float32Array

/** Error thrown by WebGl2Image when allocating the texture storage fails with the GL error OUT_OF_MEMORY. */
export class OutOfMemoryError extends Error {
    /** @param message Human readable description of the failure. */
    constructor(message: string) {
        super(message)
        // set the name explicitly so that the error type shows up in logs and can be checked via `error.name`
        this.name = "OutOfMemoryError"
    }
}
