import {ImageProcessingNodes as Nodes} from "@cm/lib/image-processing/image-processing-nodes"
import {ImageOpType} from "app/textures/texture-editor/operator-stack/image-op-system/detail/types"
import {ImagePtr} from "app/textures/texture-editor/operator-stack/image-op-system/image-ref"
import {toImgProcResultImage} from "app/textures/texture-editor/operator-stack/image-op-system/detail/utils-img-proc"

/** Parameters accepted by the ExtractChannel image op (`imageOpExtractChannel`). */
export type ParameterType = {
    /** Image whose channel is read. */
    sourceImage: ImagePtr
    /** Component of `sourceImage` to copy into the result. */
    channel: "R" | "G" | "B" | "A"
    /**
     * Optional destination image. It is forwarded to the backend's result helper
     * (`prepareResultImage` for WebGL2, `toImgProcResultImage` for ImgProc);
     * presumably a fresh image is allocated when omitted — TODO confirm.
     */
    resultImage?: ImagePtr
}

/**
 * Pointer to the single-channel ("R" layout) result image.
 * NOTE(review): this alias shadows TypeScript's built-in `ReturnType<F>` utility type within this module.
 */
export type ReturnType = ImagePtr

/**
 * Maps a channel name to its component index (R=0, G=1, B=2, A=3).
 *
 * Shared by both backends so they accept exactly the same channel names. The
 * `default` branch guards against values that bypass the type system (for
 * example parameters deserialized from untyped data).
 *
 * @throws Error if `channel` is not one of "R", "G", "B", "A".
 */
function getChannelIndex(channel: ParameterType["channel"]): number {
    switch (channel) {
        case "R":
            return 0
        case "G":
            return 1
        case "B":
            return 2
        case "A":
            return 3
        default:
            throw new Error(`Invalid channel: ${channel}`)
    }
}

/**
 * Image op that copies one channel of `sourceImage` into an image with channel
 * layout "R". The result inherits the source's width, height, format and
 * `isSRGB` flag.
 *
 * Both backends validate `channel` before doing any work and throw
 * `Error("Invalid channel: …")` for unsupported values.
 *
 * NOTE(review): `isSRGB` is copied from the source even when extracting "A";
 * confirm that alpha taken from an sRGB image should be tagged as sRGB.
 */
export const imageOpExtractChannel: ImageOpType<ParameterType, ReturnType> = {
    name: "ExtractChannel",

    WebGL2: async ({context, parameters: {sourceImage, channel, resultImage}}) => {
        // Validate before interpolating into GLSL: otherwise an unchecked string would be
        // pasted verbatim into the shader source. For valid channels the swizzle is the
        // lowercase channel name, so the generated shader text is unchanged.
        const swizzle = "rgba".charAt(getChannelIndex(channel))
        const halExtractChannel = await context.getOrCreateImageCompositor(`
            vec4 computeColor(ivec2 targetPixel) {
                float r = texelFetch0(targetPixel).${swizzle};
                float g = 0.0;
                float b = 0.0;
                float a = 1.0;
                return vec4(r, g, b, a);
            }
        `)
        using sourceImageWebGl2 = await context.getImage(sourceImage)
        const sourceDescriptor = sourceImageWebGl2.ref.descriptor
        const preparedResultImage = await context.prepareResultImage(resultImage, {
            width: sourceDescriptor.width,
            height: sourceDescriptor.height,
            channelLayout: "R",
            format: sourceDescriptor.format,
            isSRGB: sourceDescriptor.isSRGB,
        })
        using resultImageWebGl2 = await context.getImage(preparedResultImage)
        await halExtractChannel.paint(resultImageWebGl2.ref.halImage, sourceImageWebGl2.ref.halImage)
        return preparedResultImage
    },

    ImgProc: async ({context, parameters: {sourceImage, channel, resultImage}}) => {
        // Validate first so an invalid channel fails before any image is fetched.
        const channelIndex = getChannelIndex(channel)
        using sourceImageImgProc = await context.getImage(sourceImage)
        const sourceDescriptor = sourceImageImgProc.ref.descriptor
        const resultNode = Nodes.extractChannel(sourceImageImgProc.ref.node, channelIndex)
        using result = await context.createImage(
            {
                width: sourceDescriptor.width,
                height: sourceDescriptor.height,
                channelLayout: "R",
                format: sourceDescriptor.format,
                isSRGB: sourceDescriptor.isSRGB,
            },
            resultNode,
        )
        // `return await` (not a bare `return`) is required: `result` is disposed when this
        // scope exits, so the conversion must finish before the function returns.
        return await toImgProcResultImage(context, result, resultImage)
    },
}
