import {HalImageChannelLayout} from "@common/models/hal/hal-image/types"
import {HalPainterParameterValueType} from "@common/models/hal/hal-painter/types"
import {Matrix4} from "@common/helpers/vector-math"
import {Matrix3x2, Matrix3x2Like} from "@cm/lib/math/matrix3x2"
import {assertNever} from "@cm/lib/utils/utils"
import {WebGl2Context} from "@common/models/webgl2/webgl2-context"
import {WebGl2Image} from "@common/models/webgl2/webgl2-image"
import * as WebGl2ShaderUtils from "@common/helpers/webgl2/webgl2-shader-utils"
import {ADDRESS_MODE_BORDER, ADDRESS_MODE_CLAMP_TO_EDGE, ADDRESS_MODE_REPEAT, MAX_TEXTURE_UNITS} from "@common/helpers/webgl2/constants"
import {ParameterValue} from "@common/helpers/webgl2/types"
import {HalPaintable} from "@common/models/hal/hal-paintable"

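/**
 * Compiles and owns a WebGL2 program built around a user-supplied GLSL shading function.
 * This class is shared between multiple WebGlLayerGeometry instances that use the same WebGlCanvas and shading function.
 * The shading function is a GLSL function with the signature "vec4 computeColor(vec2 worldPosition, vec2 uv, vec4 color)".
 * It can sample the bound textures through generated helpers such as "vec4 texelFetchN(ivec2 texelIndex)" (texelIndex in
 * pixels), "vec4 textureNormalizedN(vec2 uv)" and "vec4 texelFetchInterpolatedN(vec2 texelIndex)", where N ranges from 0
 * to MAX_TEXTURE_UNITS - 1. Most helpers also have an overload taking a trailing "int addressMode"
 * (ADDRESS_MODE_CLAMP_TO_EDGE, ADDRESS_MODE_REPEAT or ADDRESS_MODE_BORDER). Without it, the texelFetch* helpers
 * default to ADDRESS_MODE_REPEAT, and textureN / textureNormalizedN default to ADDRESS_MODE_CLAMP_TO_EDGE.
 */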

export class WebGl2Shader {
    /**
     * Builds the vertex and fragment shader sources around `shadingFunction`, compiles and links them into a program,
     * and caches the uniform locations used later by `setUniforms` and `setViewTransform`.
     *
     * `u_worldTransform`, `u_viewTransform` and `u_modulationColor` are looked up with the strict
     * `WebGl2ShaderUtils.getUniformLocation`. All other built-in uniforms are looked up as optional, because the GLSL
     * linker may strip uniforms that the shading function never (indirectly) reads.
     *
     * @param context - WebGL2 context whose `gl` is used to create and later drive the program.
     * @param shadingFunction - GLSL source defining `vec4 computeColor(vec2 worldPosition, vec2 uv, vec4 color)`. It is
     *   inserted after the texture-sampling helpers, so it may call any of them.
     */
    constructor(
        readonly context: WebGl2Context,
        shadingFunction: string,
    ) {
        const vertexSrc = `#version 300 es
            precision highp float;
            precision highp int;

            layout(location = ${this.LOC_POSITION}) in vec2 a_position;
            layout(location = ${this.LOC_UV}) in vec2 a_uv;
            layout(location = ${this.LOC_COLOR}) in vec4 a_color;

            uniform uvec2 u_targetSize;
            uniform mat3x2 u_worldTransform;
            uniform mat3x2 u_viewTransform;

            out vec2 v_worldPosition;
            out vec2 v_uv;
            out vec4 v_color;

            void main() {
                v_worldPosition = u_worldTransform * vec3(a_position, 1);
                v_uv = a_uv;
                v_color = a_color;
                gl_Position = vec4(u_viewTransform * vec3(v_worldPosition, 1), 0.0, 1.0);
            }
        `
        // Filtered sample (via the sampler) at an integer texel index; defaults to CLAMP_TO_EDGE.
        const texture = (index: number): string => `
            vec4 texture${index}(ivec2 texelIndex, int addressMode) {
                if (isBorderTexel(${index}, texelIndex, addressMode)) {
                    return borderColor;
                }
                texelIndex = applyAddressMode(${index}, texelIndex, addressMode);
                vec2 shardUV = uvByTexelIndex(${index}, texelIndex);
                int shardIndex = shardIndexByTexelIndex(${index}, texelIndex);
                vec4 texColor = texture(u_image[${index}], vec3(shardUV, shardIndex));
                texColor = fillMissingChannels(${index}, texColor);
                return texColor;
            }

            vec4 texture${index}(ivec2 texelIndex) {
                return texture${index}(texelIndex, ADDRESS_MODE_CLAMP_TO_EDGE);
            }
        `
        // Same as textureN, but addressed by normalized coordinates over the full (unsharded) image.
        const textureNormalized = (index: number): string => `
            vec4 textureNormalized${index}(vec2 uv, int addressMode) {
                // floor() instead of a plain int cast: the cast truncates toward zero, which would map uv slightly
                // below 0 onto texel 0 instead of texel -1 and break the REPEAT and BORDER address modes.
                ivec2 texelIndex = ivec2(floor(uv * vec2(u_imageSize[${index}])));
                return texture${index}(texelIndex, addressMode);
            }

            vec4 textureNormalized${index}(vec2 uv) {
                return textureNormalized${index}(uv, ADDRESS_MODE_CLAMP_TO_EDGE);
            }
        `
        // Unfiltered texel fetch at an explicit mip level; address mode is applied in the shader.
        const texelFetchLod = (index: number): string => `
            vec4 texelFetchLod${index}(ivec2 texelIndex, int lod, int addressMode) {
                if (isBorderTexel(${index}, texelIndex, addressMode)) {
                    return borderColor;
                }
                texelIndex = applyAddressMode(${index}, texelIndex, addressMode);
                ivec2 shardTexel = shardTexelByTexelIndex(${index}, texelIndex);
                int shardIndex = shardIndexByTexelIndex(${index}, texelIndex);
                vec4 texColor = texelFetch(u_image[${index}], ivec3(shardTexel, shardIndex), lod);
                texColor = fillMissingChannels(${index}, texColor);
                return texColor;
            }
        `
        // texelFetchLodN at level 0; defaults to REPEAT.
        const texelFetch = (index: number): string => `
            vec4 texelFetch${index}(ivec2 texelIndex, int addressMode) {
                return texelFetchLod${index}(texelIndex, 0, addressMode);
            }

            vec4 texelFetch${index}(ivec2 texelIndex) {
                return texelFetchLod${index}(texelIndex, 0, ADDRESS_MODE_REPEAT);
            }
        `
        // Manual bilinear interpolation of four texelFetchLodN results; defaults to REPEAT.
        const texelFetchLodInterpolated = (index: number): string => `
            vec4 texelFetchLodInterpolated${index}(vec2 texelIndex, int lod, int addressMode) {
                texelIndex -= 0.5;  // we offset by 0.5 to sample at the center of the texel instead of the corner
                ivec2 texelIndexI = ivec2(floor(texelIndex));
                vec2 texelIndexF = texelIndex - vec2(texelIndexI);
                vec4 texColor00 = texelFetchLod${index}(texelIndexI + ivec2(0, 0), lod, addressMode);
                vec4 texColor10 = texelFetchLod${index}(texelIndexI + ivec2(1, 0), lod, addressMode);
                vec4 texColor01 = texelFetchLod${index}(texelIndexI + ivec2(0, 1), lod, addressMode);
                vec4 texColor11 = texelFetchLod${index}(texelIndexI + ivec2(1, 1), lod, addressMode);
                vec4 texColor = mix(mix(texColor00, texColor10, texelIndexF.x), mix(texColor01, texColor11, texelIndexF.x), texelIndexF.y);
                return texColor;
            }

            vec4 texelFetchLodInterpolated${index}(vec2 texelIndex, int lod) {
                return texelFetchLodInterpolated${index}(texelIndex, lod, ADDRESS_MODE_REPEAT);
            }
        `
        // texelFetchLodInterpolatedN at level 0; defaults to REPEAT.
        const texelFetchInterpolated = (index: number): string => `
            vec4 texelFetchInterpolated${index}(vec2 texelIndex, int addressMode) {
                return texelFetchLodInterpolated${index}(texelIndex, 0, addressMode);
            }

            vec4 texelFetchInterpolated${index}(vec2 texelIndex) {
                return texelFetchInterpolated${index}(texelIndex, ADDRESS_MODE_REPEAT);
            }
        `
        // Emits one copy of a helper per texture unit (0 .. MAX_TEXTURE_UNITS - 1), so the generated helpers always
        // match the size of the u_image array. See the NOTE in the fragment shader for why the unit cannot be a parameter.
        const forEachTextureUnit = (generate: (index: number) => string): string =>
            Array.from({length: MAX_TEXTURE_UNITS}, (_, index) => generate(index)).join("")

        const fragmentSrc = `#version 300 es
            precision highp float;
            precision highp int;
            precision highp sampler2DArray;

            layout(location = 0) out vec4 color;

            uniform uvec2 u_targetSize;
            uniform sampler2DArray u_image[${MAX_TEXTURE_UNITS}];
            uniform uvec2 u_imageSize[${MAX_TEXTURE_UNITS}];
            uniform uvec2 u_shardSize[${MAX_TEXTURE_UNITS}];
            uniform uvec2 u_numShards[${MAX_TEXTURE_UNITS}];
            uniform int u_numChannels[${MAX_TEXTURE_UNITS}];
            uniform vec4 u_modulationColor;

            in vec2 v_worldPosition;
            in vec2 v_uv;
            in vec4 v_color;

            const vec4 borderColor = vec4(0, 0, 0, 0);

            const int ADDRESS_MODE_CLAMP_TO_EDGE = ${ADDRESS_MODE_CLAMP_TO_EDGE};
            const int ADDRESS_MODE_REPEAT = ${ADDRESS_MODE_REPEAT};
            const int ADDRESS_MODE_BORDER = ${ADDRESS_MODE_BORDER};

            // Euclidean modulo: maps any value into [0, maxValue) (maxValue must be > 0).
            // GLSL ES 3.00 leaves the result of '%' undefined when an operand is negative, so negative values are
            // folded without ever applying '%' to a negative number.
            int wrapInt(int value, int maxValue) {
                if (value >= 0) {
                    return value % maxValue;
                }
                return (maxValue - 1) - ((-value - 1) % maxValue);
            }

            ivec2 wrapTexelIndex(int index, ivec2 texelIndex) {
                ivec2 imageSize = ivec2(u_imageSize[index]);
                texelIndex.x = wrapInt(texelIndex.x, imageSize.x);
                texelIndex.y = wrapInt(texelIndex.y, imageSize.y);
                return texelIndex;
            }

            ivec2 wrapBatchedTexelIndex(int index, ivec2 targetPixel, ivec2 batchSize) {
                ivec2 imageSize = ivec2(u_imageSize[index]);
                ivec2 patchSize = imageSize / batchSize;
                ivec2 batchIndex = targetPixel / patchSize;
                batchIndex.x = wrapInt(batchIndex.x, batchSize.x);
                batchIndex.y = wrapInt(batchIndex.y, batchSize.y);
                ivec2 texelIndex = ivec2(wrapInt(targetPixel.x, patchSize.x), wrapInt(targetPixel.y, patchSize.y));
                return batchIndex * patchSize + texelIndex;
            }

            vec2 uvByTexelIndex(int index, ivec2 texelIndex) {
                texelIndex = wrapTexelIndex(index, texelIndex);
                ivec2 shardTexelIndex = texelIndex % ivec2(u_shardSize[index]);
                return (vec2(shardTexelIndex) + 0.5) / vec2(u_shardSize[index]);
            }

            bool isBorderTexel(int index, ivec2 texelIndex, int addressMode) {
                switch (addressMode) {
                    case ADDRESS_MODE_BORDER:
                        return texelIndex.x < 0
                            || texelIndex.y < 0
                            || texelIndex.x >= int(u_imageSize[index].x)
                            || texelIndex.y >= int(u_imageSize[index].y);
                    default:
                        return false;
                }
            }

            bool isBorderTexel(int index, ivec2 texelIndex) {
                return isBorderTexel(index, texelIndex, ADDRESS_MODE_BORDER);
            }

            ivec2 applyAddressMode(int index, ivec2 texelIndex, int addressMode) {
                switch (addressMode) {
                    case ADDRESS_MODE_CLAMP_TO_EDGE:
                    case ADDRESS_MODE_BORDER:
                        return clamp(texelIndex, ivec2(0), ivec2(u_imageSize[index]) - 1);
                    case ADDRESS_MODE_REPEAT:
                    default:
                        return wrapTexelIndex(index, texelIndex);
                }
            }

            ivec2 shardTexelByTexelIndex(int index, ivec2 texelIndex) {
                ivec2 shardTexelIndex = texelIndex % ivec2(u_shardSize[index]);
                return shardTexelIndex;
            }

            int shardIndexByTexelIndex(int index, ivec2 texelIndex) {
                ivec2 shardIndex = texelIndex / ivec2(u_shardSize[index]);
                return shardIndex.y * int(u_numShards[index].x) + shardIndex.x;
            }

            vec4 fillMissingChannels(int index, vec4 color) {
                switch (u_numChannels[index]) {
                    case 1:
                        return vec4(color.r, color.r, color.r, 1);
                    case 2:
                        return vec4(color.r, color.g, 0, 1);
                    case 3:
                        return vec4(color.r, color.g, color.b, 1);
                    case 4:
                    default:
                        return color;
                }
            }

            // NOTE: We provide separate functions for each texture unit, because if we make the index a parameter we will get "ERROR: array index for samplers must be constant integral expressions" for "texture(u_image[index], ...)" even when it is a constant from the caller's point of view. :(
            ${forEachTextureUnit(texture)}
            ${forEachTextureUnit(textureNormalized)}
            ${forEachTextureUnit(texelFetchLod)}
            ${forEachTextureUnit(texelFetch)}
            ${forEachTextureUnit(texelFetchLodInterpolated)}
            ${forEachTextureUnit(texelFetchInterpolated)}

            ${shadingFunction}

            void main() {
                vec4 computedColor = computeColor(v_worldPosition, v_uv, v_color);
                color = computedColor * u_modulationColor;
            }
        `

        const gl = this.context.gl
        this.program = WebGl2ShaderUtils.compileAndLinkProgram(gl, vertexSrc, fragmentSrc)
        this.locWorldTransform = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_worldTransform")
        this.locViewTransform = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_viewTransform")
        this.locTargetSize = this.getOptionalUniformLocation("u_targetSize")
        this.locImage = this.getOptionalUniformLocation("u_image")
        this.locImageSize = this.getOptionalUniformLocation("u_imageSize")
        this.locShardSize = this.getOptionalUniformLocation("u_shardSize")
        this.locNumShards = this.getOptionalUniformLocation("u_numShards")
        this.locNumChannels = this.getOptionalUniformLocation("u_numChannels")
        this.locModulationColor = WebGl2ShaderUtils.getUniformLocation(gl, this.program, "u_modulationColor")
    }

    // WebGlEntity
    /** Deletes the underlying WebGL program. The instance must not be used afterwards. */
    dispose(): void {
        this.context.gl.deleteProgram(this.program)
    }

    /** Looks up a uniform of this program via `WebGl2ShaderUtils.getUniformLocation` (no fallback for missing uniforms). */
    getUniformLocation(uniformName: string): WebGLUniformLocation {
        return WebGl2ShaderUtils.getUniformLocation(this.context.gl, this.program, uniformName)
    }

    /**
     * Like `getUniformLocation`, but returns `null` instead of propagating the lookup failure. This is used for uniforms
     * that the linker may have removed because the shading function does not use them.
     */
    getOptionalUniformLocation(uniformName: string): WebGLUniformLocation | null {
        try {
            return WebGl2ShaderUtils.getUniformLocation(this.context.gl, this.program, uniformName)
        } catch {
            // Deliberately best-effort: a missing uniform is an expected, non-fatal outcome here.
            return null
        }
    }

    /**
     * Uploads all per-draw uniforms, followed by every parameter registered through `setParameter`.
     * The program must be current (see `setProgramAndData`), because `gl.uniform*` calls target the program in use.
     *
     * @param target - Paintable being rendered into; its size goes to `u_targetSize`.
     * @param worldTransform - Object-to-world transform (`u_worldTransform`).
     * @param images - Image per texture unit; `undefined` marks an unused unit. Only the first MAX_TEXTURE_UNITS
     *   entries are used, and missing entries are treated as `undefined`.
     * @param modulationColor - RGBA factor that the fragment shader multiplies onto the computed color.
     */
    setUniforms(
        target: HalPaintable,
        worldTransform: Matrix3x2Like,
        images: (WebGl2Image | undefined)[],
        modulationColor: [r: number, g: number, b: number, a: number],
    ) {
        // Always upload exactly one entry per texture unit. A zero-length array passed to uniform*v raises
        // INVALID_VALUE, and a short array would leave stale values from a previous draw in the remaining slots.
        const boundImages = Array.from({length: MAX_TEXTURE_UNITS}, (_, unit) => images[unit])
        const targetSize = [target.width, target.height]
        const imageSizes = boundImages.flatMap((image) => (image ? [image.descriptor.width, image.descriptor.height] : [0, 0]))
        const imageShardSizes = boundImages.flatMap((image) => (image ? [image.shardWidth, image.shardHeight] : [0, 0]))
        const imageNumShards = boundImages.flatMap((image) => (image ? [image.numShardsX, image.numShardsY] : [0, 0]))
        const getNumChannelByChannelLayout = (channelLayout: HalImageChannelLayout) => {
            switch (channelLayout) {
                case "RGBA":
                    return 4
                case "RGB":
                    return 3
                case "R":
                    return 1
                default:
                    assertNever(channelLayout)
            }
        }
        const gl = this.context.gl
        gl.uniformMatrix3x2fv(this.locWorldTransform, false, Matrix3x2.fromMatrix3x2Like(worldTransform).toArray())
        if (this.locTargetSize) {
            gl.uniform2uiv(this.locTargetSize, targetSize)
        }
        if (this.locImage) {
            // Sampler u_image[i] reads from texture unit i (matches the binding loop in setProgramAndData).
            gl.uniform1iv(
                this.locImage,
                Array.from({length: MAX_TEXTURE_UNITS}, (_, unit) => unit),
            )
        }
        if (this.locImageSize) {
            gl.uniform2uiv(this.locImageSize, imageSizes)
        }
        if (this.locShardSize) {
            gl.uniform2uiv(this.locShardSize, imageShardSizes)
        }
        if (this.locNumShards) {
            gl.uniform2uiv(this.locNumShards, imageNumShards)
        }
        if (this.locNumChannels) {
            gl.uniform1iv(
                this.locNumChannels,
                boundImages.map((image) => (image ? getNumChannelByChannelLayout(image.descriptor.channelLayout) : 0)),
            )
        }
        gl.uniform4fv(this.locModulationColor, modulationColor)
        this.setParameters()
    }

    /**
     * Uploads the world-to-clip-space transform (`u_viewTransform`, a column-major mat3x2).
     * The program must be current.
     */
    setViewTransform(viewTransform: number[]) {
        const gl = this.context.gl
        gl.uniformMatrix3x2fv(this.locViewTransform, false, viewTransform)
    }

    /**
     * Makes this program current and binds `images[i]` to texture unit i as a TEXTURE_2D_ARRAY, for every unit up to
     * MAX_TEXTURE_UNITS. Units without an image, or whose image has no shards, get `null` bound.
     * Sampler state is set to nearest filtering and CLAMP_TO_EDGE, because address modes and shard lookup are
     * implemented in the fragment shader.
     */
    setProgramAndData(images: (WebGl2Image | undefined)[]) {
        const gl = this.context.gl
        gl.useProgram(this.program)
        for (let i = 0; i < MAX_TEXTURE_UNITS; i++) {
            const image = images[i]
            const texture = image && image.numShards ? image.texture : null
            gl.activeTexture(gl.TEXTURE0 + i)
            gl.bindTexture(gl.TEXTURE_2D_ARRAY, texture)
            if (texture) {
                gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_MAG_FILTER, gl.NEAREST)
                gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_MIN_FILTER, gl.NEAREST_MIPMAP_NEAREST) // float textures require this filtering to be NEAREST
                gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)
                gl.texParameteri(gl.TEXTURE_2D_ARRAY, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)
            }
        }
    }

    /** Unbinds the TEXTURE_2D_ARRAY from every texture unit used by this shader. The current program is left unchanged. */
    unsetProgramAndData() {
        const gl = this.context.gl
        for (let i = 0; i < MAX_TEXTURE_UNITS; i++) {
            gl.activeTexture(gl.TEXTURE0 + i)
            gl.bindTexture(gl.TEXTURE_2D_ARRAY, null)
        }
    }

    /**
     * Records a value for a uniform declared by the shading function. The value is uploaded on every subsequent
     * `setUniforms` call until it is replaced.
     *
     * @param name - GLSL uniform name.
     * @param value - Typed value; its `type` selects the `gl.uniform*` call.
     * @param isOptional - If true, suppresses the warning when the uniform does not exist or was optimized away.
     */
    setParameter(name: string, value: HalPainterParameterValueType, isOptional?: boolean): void {
        const location = this.getOptionalUniformLocation(name)
        if (location) {
            this.parameterValueByName.set(name, {location, valueType: value})
        } else if (!isOptional) {
            console.warn("Parameter not found: " + name)
        }
    }

    /** Uploads every recorded parameter with the `gl.uniform*` variant that matches its value type. */
    private setParameters() {
        const gl = this.context.gl
        this.parameterValueByName.forEach(({location, valueType}) => {
            switch (valueType.type) {
                case "float": {
                    gl.uniform1f(location, valueType.value)
                    break
                }
                case "float2": {
                    gl.uniform2f(location, valueType.value.x, valueType.value.y)
                    break
                }
                case "float3": {
                    gl.uniform3f(location, valueType.value.x, valueType.value.y, valueType.value.z)
                    break
                }
                case "float4": {
                    gl.uniform4f(location, valueType.value.x, valueType.value.y, valueType.value.z, valueType.value.w)
                    break
                }
                case "int": {
                    gl.uniform1i(location, valueType.value)
                    break
                }
                case "int2": {
                    gl.uniform2i(location, valueType.value.x, valueType.value.y)
                    break
                }
                case "int3": {
                    gl.uniform3i(location, valueType.value.x, valueType.value.y, valueType.value.z)
                    break
                }
                case "int4": {
                    gl.uniform4i(location, valueType.value.x, valueType.value.y, valueType.value.z, valueType.value.w)
                    break
                }
                case "uint": {
                    gl.uniform1ui(location, valueType.value)
                    break
                }
                case "uint2": {
                    gl.uniform2ui(location, valueType.value.x, valueType.value.y)
                    break
                }
                case "uint3": {
                    gl.uniform3ui(location, valueType.value.x, valueType.value.y, valueType.value.z)
                    break
                }
                case "uint4": {
                    gl.uniform4ui(location, valueType.value.x, valueType.value.y, valueType.value.z, valueType.value.w)
                    break
                }
                case "float[]": {
                    gl.uniform1fv(location, valueType.value)
                    break
                }
                case "int[]": {
                    gl.uniform1iv(location, valueType.value)
                    break
                }
                case "uint[]": {
                    gl.uniform1uiv(location, valueType.value)
                    break
                }
                case "float2[]": {
                    gl.uniform2fv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y]),
                    )
                    break
                }
                case "int2[]": {
                    gl.uniform2iv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y]),
                    )
                    break
                }
                case "uint2[]": {
                    gl.uniform2uiv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y]),
                    )
                    break
                }
                case "float3[]": {
                    gl.uniform3fv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y, v.z]),
                    )
                    break
                }
                case "int3[]": {
                    gl.uniform3iv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y, v.z]),
                    )
                    break
                }
                case "uint3[]": {
                    gl.uniform3uiv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y, v.z]),
                    )
                    break
                }
                case "float4[]": {
                    gl.uniform4fv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y, v.z, v.w]),
                    )
                    break
                }
                case "int4[]": {
                    gl.uniform4iv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y, v.z, v.w]),
                    )
                    break
                }
                case "uint4[]": {
                    gl.uniform4uiv(
                        location,
                        valueType.value.flatMap((v) => [v.x, v.y, v.z, v.w]),
                    )
                    break
                }
                case "float3x2": {
                    const matrixValues = Matrix3x2.fromMatrix3x2Like(valueType.value).toArray()
                    gl.uniformMatrix3x2fv(location, false, matrixValues)
                    break
                }
                case "float4x4": {
                    const matrixValues = Matrix4.fromMatrix4Like(valueType.value).toArray()
                    gl.uniformMatrix4fv(location, false, matrixValues)
                    break
                }
                default:
                    assertNever(valueType)
                    break
            }
        })
    }

    // Vertex attribute locations baked into the vertex shader's layout qualifiers.
    readonly LOC_POSITION = 0
    readonly LOC_UV = 1
    readonly LOC_COLOR = 2

    private program: WebGLProgram
    private locWorldTransform: WebGLUniformLocation
    private locViewTransform: WebGLUniformLocation
    private locTargetSize: WebGLUniformLocation | null
    private locImage: WebGLUniformLocation | null
    private locImageSize: WebGLUniformLocation | null
    private locShardSize: WebGLUniformLocation | null
    private locNumShards: WebGLUniformLocation | null
    private locNumChannels: WebGLUniformLocation | null
    private locModulationColor: WebGLUniformLocation
    // User parameters recorded by setParameter and uploaded by setParameters.
    private parameterValueByName = new Map<string, ParameterValue>()
}
