import {z} from "zod"
import {ZodUtils} from "@cm/utils/zod-utils"

/**
 * Zod schemas (and the TypeScript types inferred from them) describing the nodes of a render graph:
 * data/mesh/image sources, shader nodes, scene objects, lights, the camera, session options and the
 * top-level `render` node, plus a few small constructor helpers at the end of the namespace.
 *
 * Several schemas are recursive (directly or through each other). Those are written with `z.lazy` plus an
 * explicitly declared type alias: TypeScript cannot infer the type of a self-referencing schema, and the
 * lazy callback lets a schema refer to another schema that is only declared further down in this namespace.
 */
export namespace RenderNodes {
    /** Common base that most node schemas intersect with via `.and(...)`. It currently declares no fields. */
    export const RenderNodeBaseSchema = z.object({})
    export type RenderNodeBase = z.infer<typeof RenderNodeBaseSchema>

    /** Refers to an externally stored data object by its numeric id (see `dataObjectReference` below). */
    export const DataObjectReferenceSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("dataObjectReference"),
            dataObjectId: z.number(),
        }),
    )
    export type DataObjectReference = z.infer<typeof DataObjectReferenceSchema>

    /** Any node that supplies raw data. Currently the only variant is a data object reference. */
    export const DataSchema = DataObjectReferenceSchema
    export type Data = z.infer<typeof DataSchema>

    /** Mesh loaded from a data source. */
    export const LoadMeshSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("loadMesh"),
            data: DataSchema,
        }),
    )
    export type LoadMesh = z.infer<typeof LoadMeshSchema>

    /** Subdivision of the `input` mesh. Recursive through `MeshData`, hence the explicit type and `z.lazy`. */
    export type Subdivide = RenderNodeBase & {
        type: "subdivide"
        input: MeshData
        levels: number
    }
    // `MeshDataSchema` is declared further down; it is only dereferenced when the lazy callback runs.
    export const SubdivideSchema: z.ZodType<Subdivide> = z.lazy(() =>
        RenderNodeBaseSchema.and(
            z.object({
                type: z.literal("subdivide"),
                input: MeshDataSchema,
                levels: z.number(),
            }),
        ),
    )

    //TODO: maybe move MeshData nodes to geometry graph module?
    /** Geometry-expression leaf that turns a single `MeshData` node into geometry (operator `"meshToGeom"`). */
    export type MeshDataToGeometry = {op: "meshToGeom"; args: [MeshData]}
    export const MeshDataToGeometrySchema: z.ZodType<MeshDataToGeometry> = z.lazy(() =>
        z.object({
            op: z.literal("meshToGeom"),
            args: z.tuple([MeshDataSchema]),
        }),
    )

    /**
     * Generic geometry-expression node: an operator name plus arguments that are numbers or nested
     * geometry expressions. `op` accepts any string; it is not checked against a fixed set here.
     */
    export type GeometryOperator = {op: string; args: (number | GeometryExpr)[]}
    export const GeometryOperatorSchema: z.ZodType<GeometryOperator> = z.lazy(() =>
        z.object({
            op: z.string(),
            args: z.union([z.number(), GeometryExprSchema]).array(),
        }),
    )

    /** A geometry expression tree: either a generic operator or a mesh-to-geometry leaf. */
    export const GeometryExprSchema = z.union([GeometryOperatorSchema, MeshDataToGeometrySchema])
    export type GeometryExpr = z.infer<typeof GeometryExprSchema>

    /** Mesh described by a geometry expression graph. */
    export const GeomGraphSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("geomGraph"),
            graph: GeometryExprSchema,
        }),
    )
    export type GeomGraph = z.infer<typeof GeomGraphSchema>

    /** An empty mesh. Unlike most node schemas this one is not intersected with `RenderNodeBaseSchema`. */
    export const EmptyMeshSchema = z.object({
        type: z.literal("empty"),
    })
    export type EmptyMesh = z.infer<typeof EmptyMeshSchema>

    /**
     * Plane mesh of the given `width` and `height`. `normalAxis` is one of the six signed coordinate axes.
     * Not intersected with `RenderNodeBaseSchema`.
     */
    export const PlaneMeshSchema = z.object({
        type: z.literal("plane"),
        width: z.number(),
        height: z.number(),
        normalAxis: z.union([z.literal("x+"), z.literal("y+"), z.literal("z+"), z.literal("x-"), z.literal("y-"), z.literal("z-")]),
    })
    export type PlaneMesh = z.infer<typeof PlaneMeshSchema>

    /** Any node that produces mesh data. */
    export const MeshDataSchema = z.union([LoadMeshSchema, SubdivideSchema, GeomGraphSchema, EmptyMeshSchema, PlaneMeshSchema])
    export type MeshData = z.infer<typeof MeshDataSchema>

    /**
     * Generic shader node. `type` is an arbitrary string and is not checked against a fixed set here.
     * - `inputs` maps an input name to a `[sourceNode, string]` pair. The string presumably names an output
     *   of `sourceNode`; TODO confirm with the consumer.
     * - `parameters` holds plain values. Note that `Exclude<any, ...>` evaluates to `any`, so this type does
     *   NOT rule out `null`. That restriction is only enforced at runtime by the refinements in `ShaderNodeSchema`.
     * - `resources` attaches images by name.
     */
    export type ShaderNode = RenderNodeBase & {
        type: string
        inputs?: {[name: string]: readonly [ShaderNode, string]}
        parameters?: {[name: string]: Exclude<any, null | Array<any>> | Array<Exclude<any, null>>}
        resources?: {[name: string]: Image}
    }

    // Recursive through `inputs`. It also references `ImageSchema`, which is declared further down and is
    // only dereferenced when the lazy callback runs.
    export const ShaderNodeSchema: z.ZodType<ShaderNode> = z.lazy(() =>
        RenderNodeBaseSchema.and(
            z.object({
                type: z.string(),
                inputs: z.record(z.tuple([ShaderNodeSchema, z.string()])).optional(),
                parameters: z
                    .record(
                        z.string(),
                        // Any value is accepted except `null` and arrays that contain `null`.
                        z
                            .any()
                            .refine((value) => value !== null, {message: "Shader node parameter value can not be null!"})
                            .refine((value) => (value instanceof Array ? !value.includes(null) : true), {
                                message: "Shader node parameter value can not be an array containing a null!",
                            }),
                    )
                    .optional(),
                resources: z.record(z.string(), ImageSchema).optional(),
            }),
        ),
    )

    /**
     * Scene camera. `transform` is a plain number array whose layout this schema does not constrain.
     * Unlike most nodes it carries no `type` discriminator.
     */
    export const CameraSchema = RenderNodeBaseSchema.and(
        z.object({
            transform: z.number().array(),
            focalLength: z.number(),
            focalDistance: z.number(),
            sensorSize: z.number(),
            fStop: z.number().optional(),
            exposure: z.number().optional(),
            shiftX: z.number().optional(),
            shiftY: z.number().optional(),
            nearClip: z.number().optional(),
            farClip: z.number().optional(),
        }),
    )
    export type Camera = z.infer<typeof CameraSchema>

    /**
     * Mesh object placed in the scene: its geometry (`meshData`), `transform`, shaders keyed by name, and
     * optional cryptomatte names, shadow-catcher flag, adaptive subdivision rate and per-ray-type visibility flags.
     */
    export const MeshSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("mesh").optional(), //Prior to schema version 1, the mesh type was not present in the schema
            id: z.string(),
            meshData: MeshDataSchema,
            transform: z.number().array(),
            shaders: z.record(ShaderNodeSchema),
            cryptomatteObjectName: z.string().optional(),
            cryptomatteAssetName: z.string().optional(),
            shadowCatcher: z.boolean().optional(),
            adaptiveSubdivisionRate: z.number().optional(),
            visibleInCamera: z.boolean().optional(),
            visibleInReflections: z.boolean().optional(),
            visibleInRefractions: z.boolean().optional(),
            visibleToShadowRays: z.boolean().optional(),
            visibleToGlossyRays: z.boolean().optional(),
            visibleToDiffuseRays: z.boolean().optional(),
        }),
    )
    export type Mesh = z.infer<typeof MeshSchema>

    /**
     * Placement of instances, discriminated by `type`:
     * - `"transforms"`: one number array per instance.
     * - `"curveIntersections"`: derived from `meshData` and per-point data given as flat number arrays.
     */
    export const InstancesSchema = z.discriminatedUnion("type", [
        z.object({
            type: z.literal("transforms"),
            transforms: z.array(z.number().array()),
        }),
        z.object({
            type: z.literal("curveIntersections"),
            meshData: MeshDataSchema,
            curvePoints: z.object({
                points: z.array(z.number()),
                normals: z.array(z.number()),
                tangents: z.array(z.number()),
                scales: z.array(z.number()),
            }),
        }),
    ])
    export type Instances = z.infer<typeof InstancesSchema>

    /** A `mesh` repeated according to `instances`, with an additional overall `transform`. */
    export const MeshInstancesSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("meshInstances"),
            id: z.string(),
            mesh: MeshSchema,
            transform: z.number().array(),
            instances: InstancesSchema,
        }),
    )
    export type MeshInstances = z.infer<typeof MeshInstancesSchema>

    /** Any scene object. */
    export const ObjectSchema = z.union([MeshSchema, MeshInstancesSchema])
    export type Object = z.infer<typeof ObjectSchema>

    /** Image loaded from a data source. */
    export const LoadImageSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("loadImage"),
            data: DataSchema,
        }),
    )
    export type LoadImage = z.infer<typeof LoadImageSchema>

    /**
     * Description of a distance texture computed for `meshData` relative to the `targets` objects.
     * `width`/`height` and `uvChannel` describe the output texture. The exact meaning of `range` and
     * `innerValue` is defined by the consumer (not visible here). Not intersected with `RenderNodeBaseSchema`.
     */
    export const DistanceTextureSchema = z.object({
        type: z.literal("distanceTexture"),
        meshData: MeshDataSchema,
        transform: z.number().array(),
        range: z.number(),
        width: z.number(),
        height: z.number(),
        forceOriginalResolution: z.boolean().optional(),
        uvChannel: z.number(),
        targets: z.array(ObjectSchema),
        innerValue: z.number().optional(),
    })
    export type DistanceTexture = z.infer<typeof DistanceTextureSchema>

    /** Image that is computed rather than loaded. Currently only distance textures are supported. */
    export const ComputeImageSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("computeImage"),
            data: DistanceTextureSchema,
        }),
    )
    export type ComputeImage = z.infer<typeof ComputeImageSchema>

    /** Any image source. */
    export const ImageSchema = LoadImageSchema.or(ComputeImageSchema)
    export type Image = z.infer<typeof ImageSchema>

    /** Rectangular (or, with `round`, rounded) light portal. */
    export const LightPortalSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("portal"),
            id: z.string(),
            transform: z.number().array(),
            width: z.number(),
            height: z.number(),
            round: z.boolean().optional(),
        }),
    )
    export type LightPortal = z.infer<typeof LightPortalSchema>

    /** Area light with an optional emission shader and the same per-ray-type visibility flags as meshes. */
    export const AreaLightSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("area"),
            id: z.string(),
            transform: z.number().array(),
            width: z.number(),
            height: z.number(),
            strength: z.number(),
            shader: ShaderNodeSchema.optional(),
            round: z.boolean().optional(),
            directionality: z.number().optional(),
            visibleInCamera: z.boolean().optional(),
            visibleInReflections: z.boolean().optional(),
            visibleInRefractions: z.boolean().optional(),
            visibleToShadowRays: z.boolean().optional(),
            visibleToGlossyRays: z.boolean().optional(),
            visibleToDiffuseRays: z.boolean().optional(),
        }),
    )
    export type AreaLight = z.infer<typeof AreaLightSchema>

    /** Any light. */
    export const LightSchema = z.union([LightPortalSchema, AreaLightSchema])
    export type Light = z.infer<typeof LightSchema>

    /** The scene: one camera, the objects and lights, and an optional environment shader. */
    export const SceneSchema = RenderNodeBaseSchema.and(
        z.object({
            camera: CameraSchema,
            objects: ObjectSchema.array(),
            lights: LightSchema.array(),
            environment: ShaderNodeSchema.optional(),
        }),
    )
    export type Scene = z.infer<typeof SceneSchema>

    /**
     * Names of the render passes that may be requested. The cryptomatte and AOV pass names are built with
     * `ZodUtils.templateLiteral` from a literal prefix followed by a number (cryptomatte) or a string (AOV).
     */
    export const PassNameSchema = z.union([
        z.literal("Combined"),
        z.literal("Depth"),
        z.literal("Mist"),
        z.literal("Roughness"),
        z.literal("SampleCount"),
        z.literal("Position"),
        z.literal("Normal"),
        z.literal("ObjectID"),
        z.literal("UV"),
        z.literal("Motion"),
        z.literal("MaterialID"),
        z.literal("Diffuse"),
        z.literal("Glossy"),
        z.literal("Volume"),
        z.literal("DiffuseDirect"),
        z.literal("GlossyDirect"),
        z.literal("TransmissionDirect"),
        z.literal("VolumeDirect"),
        z.literal("DiffuseIndirect"),
        z.literal("GlossyIndirect"),
        z.literal("TransmissionIndirect"),
        z.literal("VolumeIndirect"),
        z.literal("DiffuseColor"),
        z.literal("GlossyColor"),
        z.literal("TransmissionColor"),
        z.literal("Emission"),
        z.literal("Background"),
        z.literal("AmbientOcclusion"),
        z.literal("Shadow"),
        ZodUtils.templateLiteral(z.literal("CryptoObject"), z.number()),
        ZodUtils.templateLiteral(z.literal("CryptoMaterial"), z.number()),
        ZodUtils.templateLiteral(z.literal("CryptoAsset"), z.number()),
        z.literal("NoisyCombined"),
        z.literal("DenoisingNormal"),
        z.literal("DenoisingAlbedo"),
        z.literal("DenoisingDepth"),
        ZodUtils.templateLiteral(z.literal("AOVColor:"), z.string()),
        ZodUtils.templateLiteral(z.literal("AOVValue:"), z.string()),
        z.literal("ShadowCatcher"),
        z.literal("ShadowCatcherSampleCount"),
        z.literal("ShadowCatcherMatte"),
    ])
    export type PassName = z.infer<typeof PassNameSchema>

    /**
     * Render session settings. The listed option keys are typed. The intersection with `z.record(z.any())`
     * additionally lets any other option keys through without validation, so values of arbitrary type
     * (including `null`) can appear in `options`.
     */
    export const SessionSchema = RenderNodeBaseSchema.and(
        z.object({
            options: z
                .object({
                    threads: z.number().optional(),
                    gpu: z.boolean().optional(),
                    transparent_background: z.boolean().optional(),
                    transparent_glass: z.boolean().optional(),
                    final_render: z.boolean().optional(),
                    use_denoising: z.boolean().optional(),
                    adaptive_subdivision_offscreen_dicing_scale: z.number().optional(),
                    passes: PassNameSchema.array().optional(),
                    previewPass: PassNameSchema.optional(),
                })
                .and(z.record(z.any())),
        }),
    )
    export type Session = z.infer<typeof SessionSchema>

    /** Top-level render node: session settings, the scene, output size and sample count. */
    export const RenderSchema = RenderNodeBaseSchema.and(
        z.object({
            type: z.literal("render"),
            schemaVersion: z.number().optional(), //Missing values are considered to be version 0
            session: SessionSchema,
            scene: SceneSchema,
            width: z.number(),
            height: z.number(),
            samples: z.number(),
        }),
    )
    export type Render = z.infer<typeof RenderSchema>

    /**
     * Union of every node kind declared above. Note that it includes plain pass-name strings
     * (`PassNameSchema`) as well as object nodes.
     */
    export const NodeSchema = z.union([
        PassNameSchema,
        SessionSchema,
        RenderSchema,
        SceneSchema,
        CameraSchema,
        ObjectSchema,
        LightSchema,
        DataSchema,
        ImageSchema,
        MeshDataSchema,
        ShaderNodeSchema,
    ])
    export type Node = z.infer<typeof NodeSchema>

    /** Builds a `dataObjectReference` node whose `dataObjectId` is `legacyId`. */
    export function dataObjectReference(legacyId: number): DataObjectReference {
        return {
            type: "dataObjectReference",
            dataObjectId: legacyId,
        }
    }

    /**
     * Adds the `"distanceTexture"` discriminator to `data`.
     * Note: `type` is written before the spread, so a `type` key present on `data` at runtime would win.
     */
    export function distanceTexture(data: Omit<DistanceTexture, "type">): DistanceTexture {
        return {
            type: "distanceTexture",
            ...data,
        }
    }

    /** Wraps distance-texture `data` in a `computeImage` node. */
    export function computeImage(data: ComputeImage["data"]): ComputeImage {
        return {
            type: "computeImage",
            data,
        }
    }

    /** Builds a `loadImage` node that references the data object with id `legacyId`. */
    export function loadImage(legacyId: number): LoadImage {
        return {
            type: "loadImage",
            data: dataObjectReference(legacyId),
        }
    }
}

/**
 * Tells whether `passName` denotes a cryptomatte pass (`CryptoObject…`, `CryptoMaterial…`, `CryptoAsset…`).
 * The check is a case-sensitive prefix test, so any string beginning with "Crypto" qualifies.
 */
export function isCryptomattePass(passName: RenderNodes.PassName | string) {
    const cryptomattePrefix = "Crypto"
    const head = passName.slice(0, cryptomattePrefix.length)
    return head === cryptomattePrefix
}

/**
 * Tells whether `passName` denotes an AOV pass (`AOVColor:…` or `AOVValue:…`).
 * The check is a case-sensitive prefix test on "AOV".
 */
export function isAOVPass(passName: RenderNodes.PassName | string) {
    const aovPrefix = "AOV"
    return passName.substring(0, aovPrefix.length) === aovPrefix
}

/**
 * Depth-first search of a render graph for every object satisfying `predicate`.
 *
 * Traversal rules:
 * - Arrays are descended into but never passed to `predicate`.
 * - Binary payloads (typed arrays, DataViews, ArrayBuffers, Blobs) are skipped entirely.
 * - Primitives, functions, `null` and `undefined` are ignored.
 * - Every object and array is visited at most once. Shared sub-graphs are therefore reported once, and
 *   reference cycles terminate.
 *
 * @param graph Root of the (sub-)graph to search. The root itself is tested too.
 * @param predicate Called once for each distinct non-array, non-binary object encountered.
 * @returns The matching objects in pre-order: a parent comes before its children, and children follow
 *   `Object.values` order.
 */
function findNodes(graph: RenderNodes.Node, predicate: (node: RenderNodes.Node) => boolean): RenderNodes.Node[] {
    const nodes: RenderNodes.Node[] = []
    // Arrays and objects seen so far (see the "visited at most once" rule above).
    const visited = new Set<object>()
    // `Blob` is not a global in every runtime (e.g. older Node.js versions); only test against it when present.
    const blobAvailable = typeof Blob !== "undefined"

    const traverse = (node: unknown): void => {
        // `typeof null === "object"`, so null must be rejected explicitly. Otherwise it would reach
        // `predicate` and `Object.values`, and the latter throws a TypeError on null.
        if (typeof node !== "object" || node === null) return
        if (visited.has(node)) return
        if (Array.isArray(node) || node instanceof Array) {
            visited.add(node)
            // Arrays are only containers: walk their elements without testing the array itself.
            for (const value of Object.values(node)) traverse(value)
            return
        }
        // Binary payloads cannot contain nodes, so do not walk their contents.
        if (ArrayBuffer.isView(node) || node instanceof ArrayBuffer || (blobAvailable && node instanceof Blob)) return
        visited.add(node)
        if (predicate(node as RenderNodes.Node)) nodes.push(node as RenderNodes.Node)
        for (const value of Object.values(node)) traverse(value)
    }

    traverse(graph)
    return nodes
}

/** Numeric id of a data object, as stored in `RenderNodes.DataObjectReference["dataObjectId"]`. */
type LegacyId = number

/**
 * Collects the distinct `dataObjectId`s of all `dataObjectReference` nodes reachable from `graph`.
 *
 * @param graph Root node to search, e.g. a complete `render` node.
 * @returns The unique ids, in the order in which the traversal first encounters them.
 */
export function gatherDataObjectReferences(graph: RenderNodes.Node): LegacyId[] {
    const isDataObjectReference = (n: RenderNodes.Node): boolean =>
        // `typeof null === "object"`, so null is excluded before the `in` operator, which throws on null.
        typeof n === "object" && n !== null && "type" in n && n["type"] === "dataObjectReference"
    const nodes = findNodes(graph, isDataObjectReference)
    // A Set removes duplicates while keeping insertion (first-encounter) order.
    return Array.from(new Set(nodes.map((n) => (n as {dataObjectId: number})["dataObjectId"])))
}
