import {
    DeclareTemplateNode,
    DeclareTemplateNodeTS,
    TemplateNodeImplementation,
    TemplateNodeMeta,
    TemplateNodeTSImplementation,
} from "#template-nodes/declare-template-node"
import {Variable} from "#template-nodes/node-types"
import {R1Variable, S1Variable, S3Variable} from "#template-nodes/nodes/variable"
import {BuilderOutlet} from "#template-nodes/runtime-graph/graph-builder"
import {GraphBuilderScope} from "#template-nodes/runtime-graph/graph-builder-scope"
import {SolverRelationData} from "#template-nodes/runtime-graph/nodes/solver/relation-data"
import {SolverVariableData} from "#template-nodes/runtime-graph/nodes/solver/variable-data"
import {ThisStructID} from "#template-nodes/runtime-graph/types"
import {OnCompileContext, TemplateNode} from "#template-nodes/types"
import {nodeInstance} from "@cm/graph/instance"
import {NodeGraphClass, NodeParameters, nodeParameters} from "@cm/graph/node-graph"
import {registerNode} from "@cm/graph/register-node"
import {z} from "zod"

/**
 * Translation part of a transform node.
 *
 * Each of `x`, `y`, `z` is either a fixed number or an `R1Variable` node.
 * A transform node's `setupTranslation` passes numbers through unchanged.
 * It registers variable components as solver variables with topology "boundedReal".
 */
@registerNode
export class Translation extends DeclareTemplateNode(
    {
        parameters: z.object({
            x: z.number().or(nodeInstance(R1Variable)),
            y: z.number().or(nodeInstance(R1Variable)),
            z: z.number().or(nodeInstance(R1Variable)),
        }),
    },
    {},
    {nodeClass: "Translation"},
) {}

/**
 * Rotation with constant `x`, `y`, `z` values; no solver variables are involved.
 *
 * The values are forwarded as `rx`/`ry`/`rz` of the "fixed" rotation relation.
 * NOTE(review): angle units and rotation order are not defined in this file.
 * Confirm them against the solver before relying on either.
 */
@registerNode
export class FixedRotation extends DeclareTemplateNode(
    {
        parameters: z.object({x: z.number(), y: z.number(), z: z.number()}),
    },
    {},
    {nodeClass: "FixedRotation"},
) {}

/**
 * Single-axis rotation whose angle is an `S1Variable` (solver topology "1-sphere").
 *
 * `axis` is one of "x", "y", "z" and is passed unchanged to the "hinge" rotation relation.
 * NOTE(review): presumably the axis is expressed in the node's local frame; confirm in the solver.
 */
@registerNode
export class HingeRotation extends DeclareTemplateNode(
    {
        parameters: z.object({axis: z.enum(["x", "y", "z"]), rotation: nodeInstance(S1Variable)}),
    },
    {},
    {nodeClass: "HingeRotation"},
) {}

/**
 * Unconstrained rotation whose value is an `S3Variable` (solver topology "3-sphere").
 *
 * NOTE(review): the 3-sphere topology suggests a unit-quaternion parameterisation.
 * Confirm this in the solver.
 */
@registerNode
export class BallRotation extends DeclareTemplateNode(
    {
        parameters: z.object({rotation: nodeInstance(S3Variable)}),
    },
    {},
    {nodeClass: "BallRotation"},
) {}

/**
 * Zod schema for the parameters every transform node shares: a `Translation` node
 * plus exactly one rotation node (fixed, hinge, or ball).
 * `DeclareTransformNodeTS` intersects this schema with each node's own parameter schema.
 */
const transformNode = z.object({
    translation: nodeInstance(Translation),
    rotation: nodeInstance(FixedRotation).or(nodeInstance(HingeRotation)).or(nodeInstance(BallRotation)),
})
/** Static type of the shared transform-node parameters, inferred from `transformNode`. */
type TransformNode = z.infer<typeof transformNode>

/**
 * Declares a transform node class from a Zod schema describing its node-specific parameters.
 *
 * The schema is handed to `DeclareTransformNodeTS` as `validation.paramsSchema`.
 * There it is combined with the shared `translation`/`rotation` schema.
 *
 * @param definition - Carries the Zod schema (`parameters`) for the node-specific parameters.
 * @param implementation - Node implementation, typed against the node-specific parameters plus the shared transform fields.
 * @param meta - Node metadata, forwarded unchanged.
 * @returns The class produced by `DeclareTransformNodeTS`.
 */
export function DeclareTransformNode<ZodParamTypes extends z.ZodType<NodeParameters>>(
    definition: {
        parameters: ZodParamTypes
    },
    implementation: TemplateNodeImplementation<z.infer<typeof definition.parameters> & TransformNode>,
    meta: TemplateNodeMeta<z.infer<typeof definition.parameters> & TransformNode>,
) {
    const paramsSchema = definition.parameters

    return DeclareTransformNodeTS<z.infer<ZodParamTypes>>(
        {
            ...implementation,
            validation: {paramsSchema},
        },
        meta,
    )
}

/**
 * Declares a template node class whose parameters are `ParamTypes` plus the shared
 * `translation` and `rotation` fields.
 *
 * Validation uses the shared transform schema intersected with
 * `implementation.validation.paramsSchema`. When no schema is supplied, it falls back to `nodeParameters`.
 * The generated class adds helpers for building solver data during compilation.
 *
 * @param implementation - Node implementation; its optional `validation.paramsSchema` is merged with the shared transform schema.
 * @param meta - Node metadata, forwarded unchanged to `DeclareTemplateNodeTS`.
 * @returns The generated node class.
 */
export function DeclareTransformNodeTS<ParamTypes extends NodeParameters>(
    implementation: TemplateNodeTSImplementation<ParamTypes & TransformNode>,
    meta: TemplateNodeMeta<ParamTypes & TransformNode>,
): NodeGraphClass<TemplateTransformNode<ParamTypes>> {
    const nodeSpecificSchema = implementation.validation?.paramsSchema ?? nodeParameters
    const paramsSchema = transformNode.and(nodeSpecificSchema)

    // Keep the class bound to a named const so its runtime `.name` stays "retClass".
    const retClass = class extends DeclareTemplateNodeTS<ParamTypes & TransformNode>({...implementation, validation: {paramsSchema}}, meta) {
        /**
         * Creates a `SolverVariableData` struct for `variable`.
         * Appends it to the current template's `solverVariables` and returns it.
         * Throws if `variable` is not an R1, S1, or S3 variable.
         */
        setupVariable(scope: GraphBuilderScope, context: OnCompileContext, variable: Variable) {
            const {solverVariables} = context.currentTemplate.solverData

            // Map the variable node's class to the solver topology tag.
            const topologyOf = (candidate: Variable) => {
                if (candidate instanceof R1Variable) return "boundedReal"
                if (candidate instanceof S1Variable) return "1-sphere"
                if (candidate instanceof S3Variable) return "3-sphere"
                throw new Error("Invalid variable type")
            }

            const created = scope.struct<SolverVariableData>("SolverVariableData", {
                id: ThisStructID,
                topology: topologyOf(variable),
                default: variable.parameters.default,
                // Only bounded-real (R1) variables carry a range.
                range: variable instanceof R1Variable ? variable.parameters.range : undefined,
            })

            solverVariables.push(created)
            return created
        }

        /**
         * Builds the "RelationTranslation" struct.
         * Numeric components pass through; variable components become solver variables (in x, y, z order).
         */
        setupTranslation(scope: GraphBuilderScope, context: OnCompileContext, translation: Translation) {
            const {x, y, z} = translation.parameters
            const component = (value: typeof x) => (typeof value === "number" ? value : this.setupVariable(scope, context, value))

            return scope.struct<SolverRelationData["translation"]>("RelationTranslation", {
                tx: component(x),
                ty: component(y),
                tz: component(z),
            })
        }

        /**
         * Builds the rotation relation struct matching the concrete rotation node.
         * Hinge and ball rotations also register their angle variable.
         */
        setupRotation(scope: GraphBuilderScope, context: OnCompileContext, rotation: FixedRotation | HingeRotation | BallRotation) {
            if (rotation instanceof HingeRotation) {
                const {axis, rotation: angle} = rotation.parameters
                return scope.struct<SolverRelationData["rotation"]>("RelationRotation_hinge", {
                    type: "hinge" as const,
                    axis,
                    angleVariable: this.setupVariable(scope, context, angle),
                })
            }

            if (rotation instanceof BallRotation) {
                const orientation = rotation.parameters.rotation
                return scope.struct<SolverRelationData["rotation"]>("RelationRotation_ball", {
                    type: "ball" as const,
                    angleVariable: this.setupVariable(scope, context, orientation),
                })
            }

            if (rotation instanceof FixedRotation) {
                const {x: rx, y: ry, z: rz} = rotation.parameters
                return scope.struct<SolverRelationData["rotation"]>("RelationRotation_fixed", {
                    type: "fixed" as const,
                    rx,
                    ry,
                    rz,
                })
            }

            throw new Error("Invalid rotation type")
        }
    }

    return retClass
}

/**
 * Instance type of classes produced by `DeclareTransformNodeTS`.
 *
 * It is a template node whose parameters are `ParamTypes` plus the shared
 * `translation`/`rotation` fields. It also carries helpers for building solver data during compilation.
 */
export type TemplateTransformNode<ParamTypes extends NodeParameters = {}> = TemplateNode<ParamTypes & TransformNode> & {
    /** Creates a solver-variable struct for `variable`, appends it to the current template's solver variables, and returns it. */
    setupVariable(scope: GraphBuilderScope, context: OnCompileContext, variable: Variable): BuilderOutlet<SolverVariableData>
    /** Builds the translation relation struct; variable components are turned into solver variables. */
    setupTranslation(scope: GraphBuilderScope, context: OnCompileContext, translation: Translation): BuilderOutlet<SolverRelationData["translation"]>
    /** Builds the rotation relation struct for a fixed, hinge, or ball rotation node. */
    setupRotation(
        scope: GraphBuilderScope,
        context: OnCompileContext,
        rotation: FixedRotation | HingeRotation | BallRotation,
    ): BuilderOutlet<SolverRelationData["rotation"]>
}
