import {
    NodeParameters,
    NodeParamEvaluator,
    isNodeGraphInstance,
    NodeGraph,
    DeclareNodeGraphTS,
    NodeGraphMeta,
    GraphParameter,
    nodeGraph,
    nodeParameters,
} from "@src/graph-system/node-graph"
import * as THREENodes from "three/examples/jsm/nodes/Nodes"
import {Context, MaterialType, context} from "@src/materials/types"
import {RenderNodes} from "@src/rendering/render-nodes"
import {z} from "zod"
import * as THREE from "three"
import {mapFields} from "@src/utils/utils"

// Cycles-backend slot value: reuses the render pipeline's shader-node schema as-is.
export const cyclesNode = RenderNodes.ShaderNodeSchema
export type CyclesNode = z.infer<typeof cyclesNode>

// three.js-backend slot value: any instance of THREENodes.Node.
export const threeNode = z.instanceof(THREENodes.Node)
export type ThreeNode = z.infer<typeof threeNode>

// A value that can occupy a material slot on either backend.
export const materialSlots = z.union([threeNode, cyclesNode])
export type MaterialSlot = z.infer<typeof materialSlots>

// Input maps: slot name -> slot value.
export const materialInput = z.record(materialSlots)
export type MaterialInput = Record<string, MaterialSlot>

// Output maps: like inputs, but a slot may also hold a finished THREE.Material.
export const materialOutput = z.record(materialSlots.or(z.instanceof(THREE.Material)))
export type MaterialOutput = Record<string, MaterialSlot | THREE.Material>

// Keep only the members of a slot type that the given backend can consume.
// `Extract<T, U>` is defined as `T extends U ? T : never`, distributed over the members of T.
type CycleInput<T extends MaterialSlot> = Extract<T, CyclesNode>
type ThreeInput<T extends MaterialSlot> = Extract<T, ThreeNode | THREE.Material>

// Select the backend filter by material type; any other material type maps to `never`.
type MaterialInputForType<T extends MaterialSlot, M extends MaterialType> =
    M extends "cycles" ? CycleInput<T>
    : M extends "three" ? ThreeInput<T>
    : never

// Apply the backend filter to every slot of an input map.
type MaterialInputsForType<InputTypes extends MaterialInput, M extends MaterialType> = {
    [Slot in keyof InputTypes]: MaterialInputForType<InputTypes[Slot], M>
}

/** A node-graph parameter of value type `TValue`, evaluated with a material `Context`. */
export type MaterialInputParameter<TValue> = GraphParameter<TValue, Context>

/** Wraps each slot of an input map in a `MaterialInputParameter` of the slot's type. */
export type GetMaterialInputs<TSlots extends MaterialInput> = {
    [Slot in keyof TSlots]: MaterialInputParameter<TSlots[Slot]>
}

/**
 * Full parameter shape of a material node: one graph parameter per input slot, plus the
 * node's own settings nested under the `parameters` key.
 */
export type MaterialParameters<TInputs extends MaterialInput, TParams extends NodeParameters> = GetMaterialInputs<TInputs> & {
    parameters: TParams
}

/**
 * Builds the runtime schema for one input slot's parameter: a node graph, made optional
 * exactly when the declared slot schema itself accepts `undefined`.
 */
const inputParameterSchema = (value: z.ZodTypeAny) => (value.safeParse(undefined).success ? nodeGraph.optional() : nodeGraph)

/**
 * Converts a declared input-slot schema into the schema that validates the node's input
 * parameters: every slot becomes a node-graph schema (optional where the slot accepts undefined).
 *
 * Only ZodObject- and ZodRecord-shaped schemas are supported; they are recognised structurally
 * (by their `shape` / `keySchema` + `valueSchema` members) rather than with `instanceof`.
 *
 * @throws Error if `paramsSchema` is neither object- nor record-shaped.
 */
const getInputsSchema = (paramsSchema: z.ZodType<MaterialInput>) => {
    const looksLikeObject = (candidate: object): candidate is z.ZodObject<z.ZodRawShape> => (candidate as any).shape !== undefined

    const looksLikeRecord = (candidate: object): candidate is z.ZodRecord => {
        const loose = candidate as any
        return loose.keySchema !== undefined && loose.valueSchema !== undefined
    }

    if (looksLikeObject(paramsSchema)) {
        const wrappedShape = mapFields(paramsSchema.shape, (field) => inputParameterSchema(field))
        return z.object(wrappedShape)
    }

    if (looksLikeRecord(paramsSchema)) {
        const wrappedValue = inputParameterSchema(paramsSchema.valueSchema)
        return z.record(paramsSchema.keySchema, wrappedValue)
    }

    throw Error("paramsSchema must be a ZodObject or ZodRecord")
}

/**
 * Declares a material node from zod schemas.
 *
 * The TypeScript types of the node's outputs, input slots and `parameters` object are inferred
 * from the schemas in `definition`. The schemas themselves are forwarded to
 * `DeclareMaterialNodeTS` as runtime validation.
 *
 * @param definition
 * - `returns`: schema of the produced outputs.
 * - `inputs`: schema of the input slots. It must be ZodObject- or ZodRecord-shaped;
 *   anything else throws while the node is being declared.
 * - `parameters`: schema of the node's own settings object.
 * @param implementation - optional `toCycles` / `toThree` builders, chosen by the evaluation
 *   context's `type`.
 * @param meta - optional node-graph metadata, typed against the same slot/`parameters` shape
 *   as the node itself and forwarded unchanged.
 */
export function DeclareMaterialNode<
    ZodReturnType extends z.ZodType<MaterialOutput>,
    ZodInputTypes extends z.ZodType<MaterialInput>,
    ZodParamTypes extends z.ZodType<NodeParameters>,
>(
    definition: {
        returns: ZodReturnType
        inputs: ZodInputTypes
        parameters: ZodParamTypes
    },
    implementation: {
        toCycles?: (data: {
            get: NodeParamEvaluator
            context: Context
            inputs: GetMaterialInputs<MaterialInputsForType<z.infer<ZodInputTypes>, "cycles">>
            parameters: z.infer<ZodParamTypes>
        }) => Promise<MaterialInputsForType<z.infer<ZodReturnType>, "cycles">>
        toThree?: (data: {
            get: NodeParamEvaluator
            context: Context
            inputs: GetMaterialInputs<MaterialInputsForType<z.infer<ZodInputTypes>, "three">>
            parameters: z.infer<ZodParamTypes>
        }) => Promise<MaterialInputsForType<z.infer<ZodReturnType>, "three">>
    },
    // `MaterialParameters` already nests its second argument under `parameters`, so the
    // parameter type is passed directly. Wrapping it in `{parameters: ...}` would yield
    // `{parameters: {parameters: ...}}`, which does not match the node or DeclareMaterialNodeTS.
    meta?: NodeGraphMeta<MaterialParameters<z.infer<ZodInputTypes>, z.infer<ZodParamTypes>>>,
) {
    const {returns: returnSchema, inputs: inputsSchema, parameters: paramsSchema} = definition
    type ReturnType = z.infer<ZodReturnType>
    type InputTypes = z.infer<ZodInputTypes>
    type ParamTypes = z.infer<ZodParamTypes>

    return DeclareMaterialNodeTS<ReturnType, InputTypes, ParamTypes>({...implementation, validation: {returnSchema, inputsSchema, paramsSchema}}, meta)
}

/**
 * Declares a material node from explicit TypeScript types.
 *
 * The node's `run` splits its parameters into the input slots and the nested `parameters`
 * object, then dispatches on the evaluation context's `type`: `"cycles"` calls `toCycles` and
 * `"three"` calls `toThree`. The builders are invoked with the node as `this`.
 *
 * @param implementation
 * - `toCycles` / `toThree`: backend-specific builders. Either may be omitted; evaluating in a
 *   context whose builder is missing throws.
 * - `validation`: optional zod schemas for the return value, the input slots and the
 *   `parameters` object. These fall back to `materialOutput`, `materialInput` and
 *   `nodeParameters` respectively.
 * @param meta - optional node-graph metadata, forwarded unchanged to `DeclareNodeGraphTS`.
 * @throws Error (when declared) if the input-slot schema is neither object- nor record-shaped.
 */
export function DeclareMaterialNodeTS<ReturnType extends MaterialOutput, InputTypes extends MaterialInput, ParamTypes extends NodeParameters>(
    implementation: {
        toCycles?: (data: {
            get: NodeParamEvaluator
            context: Context
            inputs: GetMaterialInputs<MaterialInputsForType<InputTypes, "cycles">>
            parameters: ParamTypes
        }) => Promise<MaterialInputsForType<ReturnType, "cycles">>
        toThree?: (data: {
            get: NodeParamEvaluator
            context: Context
            inputs: GetMaterialInputs<MaterialInputsForType<InputTypes, "three">>
            parameters: ParamTypes
        }) => Promise<MaterialInputsForType<ReturnType, "three">>
        validation?: {
            returnSchema?: z.ZodType<MaterialOutput>
            inputsSchema?: z.ZodType<MaterialInput>
            paramsSchema?: z.ZodType<NodeParameters>
        }
    },
    meta?: NodeGraphMeta<MaterialParameters<InputTypes, ParamTypes>>,
) {
    const {validation} = implementation
    return DeclareNodeGraphTS<ReturnType, Context, MaterialParameters<InputTypes, ParamTypes>>(
        {
            run: function (this: MaterialNode<ReturnType, InputTypes, ParamTypes>, {get, parameters, context}) {
                // Every key except the reserved `parameters` entry is an input slot.
                const {parameters: currentParameters, ...inputs} = parameters
                const materialType = context.type
                switch (materialType) {
                    case "cycles":
                        if (!implementation.toCycles) throw new Error(`ToCycles method not implemented for ${this.getNodeClass()}`)
                        return implementation.toCycles.call(this, {
                            get,
                            context,
                            inputs: inputs as NodeParameters as GetMaterialInputs<MaterialInputsForType<InputTypes, "cycles">>,
                            parameters: currentParameters,
                        })
                    case "three":
                        if (!implementation.toThree) throw new Error(`ToThree method not implemented for ${this.getNodeClass()}`)
                        return implementation.toThree.call(this, {
                            get,
                            context,
                            inputs: inputs as NodeParameters as GetMaterialInputs<MaterialInputsForType<InputTypes, "three">>,
                            parameters: currentParameters,
                        })
                    default:
                        // Reject unknown backends explicitly instead of letting `run` resolve to undefined.
                        throw new Error(`Unsupported material context type "${String(materialType)}" for ${this.getNodeClass()}`)
                }
            },
            validation: {
                returnSchema: validation?.returnSchema ?? materialOutput,
                contextSchema: context,
                // NOTE(review): when the inputs schema is record-shaped (including the
                // `materialInput` fallback), the record half of this intersection also validates
                // the `parameters` entry against the node-graph slot schema. Confirm that is intended.
                paramsSchema: getInputsSchema(validation?.inputsSchema ?? materialInput).and(
                    z.object({parameters: validation?.paramsSchema ?? nodeParameters}),
                ),
            },
        },
        meta,
    )
}

/**
 * A node graph evaluated with a material `Context` whose parameters are the input slots plus
 * the nested `parameters` object. This is the `this` type seen by material-node `run` logic.
 */
export type MaterialNode<
    TReturn extends MaterialOutput = {},
    TInputs extends MaterialInput = {},
    TParams extends NodeParameters = {},
> = NodeGraph<TReturn, Context, MaterialParameters<TInputs, TParams>>

/**
 * Type guard for material nodes.
 * NOTE(review): this delegates entirely to `isNodeGraphInstance`, so any node graph passes and is
 * narrowed to `MaterialNode`. It does not check material-specific structure.
 */
export const isMaterialNodeNode = (instance: unknown): instance is MaterialNode => {
    return isNodeGraphInstance(instance)
}

/** zod schema accepting any value that passes `isMaterialNodeNode`; its output type is `MaterialNode`. */
export const materialNodeInstance = z.any().refine(isMaterialNodeNode, "Expected material node")
