import {ImageProcessingNodes as Nodes} from "@cm/lib/image-processing/image-processing-nodes"
import {ImageOpType} from "app/textures/texture-editor/operator-stack/image-op-system/detail/types"
import {ImagePtr} from "app/textures/texture-editor/operator-stack/image-op-system/image-ref"
import {assertSameSize, getChannelLayoutByCount} from "app/textures/texture-editor/operator-stack/image-op-system/detail/utils"
import {getImgProcChannelLayout, toImgProcResultImage} from "app/textures/texture-editor/operator-stack/image-op-system/detail/utils-img-proc"
import {ImagePtrArrayWebGl2} from "@app/textures/texture-editor/operator-stack/image-op-system/image-webgl2"
import {ImagePtrArrayImgProc} from "@app/textures/texture-editor/operator-stack/image-op-system/image-imgproc"

/** Sources for a 3-channel result: output channel `i` is taken from `sourceImages[i]`. */
type CombineChannelsRgbSources = [red: ImagePtr, green: ImagePtr, blue: ImagePtr]

/** Sources for a 4-channel result: output channel `i` is taken from `sourceImages[i]`. */
type CombineChannelsRgbaSources = [red: ImagePtr, green: ImagePtr, blue: ImagePtr, alpha: ImagePtr]

/** Parameters of the CombineChannels image op. */
export type ParameterType = {
    /**
     * One source image per output channel, in R, G, B(, A) order.
     * All sources are checked to have the same size as the first one.
     */
    sourceImages: CombineChannelsRgbSources | CombineChannelsRgbaSources
    /**
     * Optional destination image. It is forwarded to the backend's result-image handling
     * (`prepareResultImage` / `toImgProcResultImage`). Presumably a new image is used when omitted — confirm there.
     */
    resultImage?: ImagePtr
}

/**
 * Result of the CombineChannels op: a pointer to the image holding the combined channels.
 *
 * NOTE: inside this module, this exported alias shadows TypeScript's built-in `ReturnType<T>` utility type.
 */
export type ReturnType = ImagePtr

/**
 * Awaits every promise (concurrently) and partitions the outcomes.
 *
 * Unlike `Promise.all`, the values that did resolve are always handed back, in input order, even when
 * another promise rejected. This lets the caller release resources that were acquired before the failure.
 *
 * @param pending Promises to await.
 * @returns `values`: the fulfilled values in input order. `failure`: the first rejection in input order,
 *   or `undefined` when every promise fulfilled.
 */
async function settleAll<T>(pending: readonly Promise<T>[]): Promise<{values: Awaited<T>[]; failure: PromiseRejectedResult | undefined}> {
    const outcomes = await Promise.allSettled(pending)
    const values: Awaited<T>[] = []
    let failure: PromiseRejectedResult | undefined
    for (const outcome of outcomes) {
        if (outcome.status === "fulfilled") {
            values.push(outcome.value)
        } else {
            failure ??= outcome
        }
    }
    return {values, failure}
}

/**
 * Image op that assembles a 3- or 4-channel image from one source image per output channel.
 *
 * All sources must have the same size as `sourceImages[0]`. The result takes its width, height, format and
 * sRGB flag from `sourceImages[0]`, and its channel layout from the number of sources.
 */
export const imageOpCombineChannels: ImageOpType<ParameterType, ReturnType> = {
    name: "CombineChannels",

    /**
     * GPU path: a compositor shader reads the red channel of source `i` into output channel `i`.
     * With only 3 sources the shader writes alpha = 1.0.
     */
    WebGL2: async ({context, parameters: {sourceImages, resultImage}}) => {
        // Wrap every ref that was acquired before rethrowing, so a failed `getImage` does not leak the others.
        const acquired = await settleAll(sourceImages.map((image) => context.getImage(image)))
        using sourceImagesWebGl2 = new ImagePtrArrayWebGl2(acquired.values)
        if (acquired.failure) {
            throw acquired.failure.reason
        }
        const primarySourceImageWebGL2 = sourceImagesWebGl2[0]
        for (const sourceImageWebGL2 of sourceImagesWebGl2) {
            assertSameSize(primarySourceImageWebGL2.ref.descriptor, sourceImageWebGL2.ref.descriptor)
        }
        const numChannels = sourceImagesWebGl2.length
        const halCombineChannels = await context.getOrCreateImageCompositor(`
            vec4 computeColor(ivec2 targetPixel) {
                float r = texelFetch0(targetPixel).r;
                float g = ${numChannels >= 2 ? "texelFetch1(targetPixel).r" : "0.0"};
                float b = ${numChannels >= 3 ? "texelFetch2(targetPixel).r" : "0.0"};
                float a = ${numChannels >= 4 ? "texelFetch3(targetPixel).r" : "1.0"};
                return vec4(r, g, b, a);
            }
        `)
        resultImage = await context.prepareResultImage(resultImage, {
            width: primarySourceImageWebGL2.ref.descriptor.width,
            height: primarySourceImageWebGL2.ref.descriptor.height,
            channelLayout: getChannelLayoutByCount(numChannels),
            format: primarySourceImageWebGL2.ref.descriptor.format,
            isSRGB: primarySourceImageWebGL2.ref.descriptor.isSRGB,
        })
        using resultImageWebGl2 = await context.getImage(resultImage)
        await halCombineChannels.paint(
            resultImageWebGl2.ref.halImage,
            sourceImagesWebGl2.map((image) => image.ref.halImage),
        )
        return resultImage
    },

    /**
     * CPU path: builds an ImgProc `combineChannels` node over the source nodes.
     * Throws if the derived ImgProc layout is single-channel ("L").
     */
    ImgProc: async ({context, parameters: {sourceImages, resultImage}}) => {
        // Wrap every ref that was acquired before rethrowing, so a failed `getImage` does not leak the others.
        const acquired = await settleAll(sourceImages.map((image) => context.getImage(image)))
        using sourceImagesImgProc = new ImagePtrArrayImgProc(acquired.values)
        if (acquired.failure) {
            throw acquired.failure.reason
        }
        const primarySourceImage = sourceImagesImgProc[0]
        for (const sourceImageImgProc of sourceImagesImgProc) {
            assertSameSize(primarySourceImage.ref.descriptor, sourceImageImgProc.ref.descriptor)
        }
        const channelLayout = getChannelLayoutByCount(sourceImages.length)
        const imgProcChannelLayout = getImgProcChannelLayout(channelLayout)
        if (imgProcChannelLayout === "L") {
            throw Error("ImgProc's CombineChannels does not support single channel images")
        }
        const resultNode = Nodes.combineChannels(
            sourceImagesImgProc.map((image) => image.ref.node) as
                | [Nodes.ImageNode, Nodes.ImageNode, Nodes.ImageNode]
                | [Nodes.ImageNode, Nodes.ImageNode, Nodes.ImageNode, Nodes.ImageNode],
            imgProcChannelLayout,
        )
        using result = await context.createImage(
            {
                width: primarySourceImage.ref.descriptor.width,
                height: primarySourceImage.ref.descriptor.height,
                channelLayout: channelLayout,
                format: primarySourceImage.ref.descriptor.format,
                isSRGB: primarySourceImage.ref.descriptor.isSRGB,
            },
            resultNode as Nodes.ImageNode,
        )
        // `return await` keeps `result` alive until the conversion finishes; `using` disposes it afterwards.
        return await toImgProcResultImage(context, result, resultImage)
    },
}
