import {MeshDataGraph} from "#template-nodes/geometry-processing/mesh-data"
import {RenderNodes} from "@cm/render-nodes"
// Operator nodes that `compileGeometryGraph` is able to serialize (see `isCompilableExpr`).
type CompilableOpExpr = RenderNodes.GeometryOperator
// Any raw operator node (`{op, args}`, as built by `makeOpExpr`) in a geometry expression graph.
type OpExpr = RenderNodes.GeometryExpr
// An operator argument: either a numeric literal or another operator node.
type Expr = number | OpExpr

// Value types an `AttributeExpr` can carry: scalar "float", float vectors "float2"/"float3", and integer "uint32".
type AttributeType = "float" | "float2" | "float3" | "uint32"

/**
 * Component type of an attribute type.
 *
 * The vector types ("float2", "float3") have "float" components; the scalar types
 * ("float", "uint32") have none and map to `null`.
 */
function elemType(type: AttributeType): AttributeType | null {
    const componentTypeOf: Record<AttributeType, AttributeType | null> = {
        float: null,
        float2: "float",
        float3: "float",
        uint32: null,
    }
    return componentTypeOf[type]
}

/** True for the floating-point attribute types: scalar "float" and the vectors "float2"/"float3". */
function isFloatType(type: AttributeType): boolean {
    return type === "float" || type === "float2" || type === "float3"
}

/**
 * Returns the common type two attribute types can both be promoted to.
 *
 * Rules (symmetric):
 *  - identical types unify to themselves;
 *  - a vector type ("float2"/"float3") absorbs a scalar ("float" or "uint32");
 *  - "float" absorbs "uint32".
 *
 * @param _flip internal: whether the argument-swapped lookup should still be attempted.
 * @throws Error if no common type exists (e.g. "float2" vs "float3").
 */
function unifyTypes(typeA: AttributeType, typeB: AttributeType, _flip = true): AttributeType {
    if (typeA === typeB) return typeA
    const aIsVector = typeA === "float3" || typeA === "float2"
    const bIsScalar = typeB === "float" || typeB === "uint32"
    if (aIsVector && bIsScalar) return typeA
    if (typeA === "float" && typeB === "uint32") return "float"
    // Only the (vector, scalar) / (float, uint32) orderings are listed above; try the swapped order once.
    if (_flip) return unifyTypes(typeB, typeA, false)
    throw new Error(`Cannot unify types: ${typeA} <=> ${typeB}`)
}

/**
 * Brings the two operands of a binary operation to a common attribute type.
 *
 * A numeric `b` is first turned into a constant attribute: `constFloat` when `a` is
 * float-typed, `constInt` otherwise. Both operands are then promoted to
 * `unifyTypes(a._type, b._type)`.
 *
 * @throws Error if the operand types cannot be unified or promoted.
 */
function unifyAndPromote(a: AttributeExpr, b: AttributeExpr | number): [AttributeExpr, AttributeExpr] {
    const rhs: AttributeExpr = typeof b !== "number" ? b : isFloatType(a._type) ? a.constFloat(b) : a.constInt(b)
    const commonType = unifyTypes(a._type, rhs._type)
    return [a.promote(commonType), rhs.promote(commonType)]
}

//TODO: remove this when types are unified with render nodes
/** Type guard: every operator node except "meshToGeom" can be compiled by `compileGeometryGraph`. */
function isCompilableExpr(expr: OpExpr): expr is CompilableOpExpr {
    switch (expr.op) {
        case "meshToGeom":
            return false
        default:
            return true
    }
}

/** Builds a raw operator node `{op, args}` from an operator name and its arguments. */
function makeOpExpr(op: string, ...args: Expr[]): OpExpr {
    const node: OpExpr = {op: op, args: args}
    return node
}

/** Wraps mesh data in a "meshToGeom" node so it can be used as an input to geometry expressions. */
function meshToGeom(meshData: MeshDataGraph): OpExpr {
    const node: OpExpr = {op: "meshToGeom", args: [meshData]}
    return node
}

/**
 * Creates an `AttributeExpr` node of the given result type.
 *
 * `null` is what `elemType` yields for scalar types, so component access on a scalar
 * (e.g. `.x` of a "float") fails here.
 *
 * @throws Error("Invalid attribute type") if `type` is null.
 */
function makeAttributeExpr(op: string, args: Expr[], type: AttributeType | null) {
    if (type !== null) {
        return new AttributeExpr(op, args, type)
    }
    throw new Error("Invalid attribute type")
}

/** Applies a one-argument operator to `arg`; the result has type `newType`, or `arg`'s own type when omitted. */
function unaryOp(opName: string, arg: AttributeExpr, newType?: AttributeType) {
    const resultType = newType === undefined ? arg._type : newType
    return makeAttributeExpr(opName, [arg._expr], resultType)
}

/**
 * Applies a two-argument operator.
 *
 * If `arg2` is a plain number and a scalar variant `scalarOpName` is given, the number is
 * passed as a literal to that scalar op without promotion. Otherwise both operands go
 * through `unifyAndPromote` and `opName` is applied.
 *
 * The result type is `newType` if given, else the (possibly promoted) type of the left operand.
 */
function binaryOp(opName: string, scalarOpName: string | null, arg1: AttributeExpr, arg2: AttributeExpr | number, newType?: AttributeType) {
    if (typeof arg2 !== "number" || !scalarOpName) {
        const [lhs, rhs] = unifyAndPromote(arg1, arg2)
        return makeAttributeExpr(opName, [lhs._expr, rhs._expr], newType === undefined ? lhs._type : newType)
    }
    //TODO: support literal vectors
    return makeAttributeExpr(scalarOpName, [arg1._expr, arg2], newType === undefined ? arg1._type : newType)
}

/**
 * Typed handle to one operator node of a geometry-processing expression graph.
 *
 * The node itself is `_expr` (built from `_op` and `_args`); `_type` is the attribute type it yields.
 * Methods never modify the receiver. Each builds and returns a new `AttributeExpr` whose arguments
 * reference existing nodes, so a node used twice is shared in the graph rather than copied.
 */
export class AttributeExpr {
    /** The `{op, args}` node this handle represents. */
    readonly _expr: Expr
    constructor(
        private _op: string,
        private _args: Expr[],
        public _type: AttributeType,
    ) {
        this._expr = makeOpExpr(this._op, ...this._args)
    }

    // --- Component access -----------------------------------------------------
    // Each getter builds a new "getN" node. On scalar types ("float", "uint32") `elemType`
    // is null, so these throw "Invalid attribute type".
    // NOTE(review): nothing here checks that the component exists (e.g. `.z` of a "float2") —
    // confirm the evaluator rejects that case.
    get x() {
        return makeAttributeExpr("get0", [this._expr], elemType(this._type))
    }
    get y() {
        return makeAttributeExpr("get1", [this._expr], elemType(this._type))
    }
    get z() {
        return makeAttributeExpr("get2", [this._expr], elemType(this._type))
    }

    // Two-component swizzles: the selected components are re-packed into a "float2"
    // (`Operators.pack` only accepts "float" components).
    get xy() {
        return Operators.pack(this.x, this.y)
    }
    get xz() {
        return Operators.pack(this.x, this.z)
    }
    get yz() {
        return Operators.pack(this.y, this.z)
    }
    get yx() {
        return Operators.pack(this.y, this.x)
    }
    get zx() {
        return Operators.pack(this.z, this.x)
    }
    get zy() {
        return Operators.pack(this.z, this.y)
    }

    // --- Arithmetic and comparison ----------------------------------------------
    // A numeric operand becomes a constant attribute, and both operands are promoted to a common
    // type (see `unifyAndPromote`). The exception is `mul`, which passes a numeric operand as a
    // literal to the scalar op "muls". Results, including comparisons, carry the left operand's
    // (promoted) type.
    add(x: AttributeExpr | number) {
        return binaryOp("add", null, this, x)
    }
    sub(x: AttributeExpr | number) {
        return binaryOp("sub", null, this, x)
    }
    mul(x: AttributeExpr | number) {
        return binaryOp("mul", "muls", this, x)
    }
    div(x: AttributeExpr | number) {
        return binaryOp("div", null, this, x)
    }
    mod(x: AttributeExpr | number) {
        return binaryOp("mod", null, this, x)
    }
    gt(x: AttributeExpr | number) {
        return binaryOp("gt", null, this, x)
    }
    lt(x: AttributeExpr | number) {
        return binaryOp("lt", null, this, x)
    }
    eq(x: AttributeExpr | number) {
        return binaryOp("eq", null, this, x)
    }
    ne(x: AttributeExpr | number) {
        return binaryOp("ne", null, this, x)
    }
    // NOTE(review): "not" is emitted as a two-operand op like the comparisons — confirm this
    // matches the evaluator's arity for "not".
    not(x: AttributeExpr | number) {
        return binaryOp("not", null, this, x)
    }
    lte(x: AttributeExpr | number) {
        return binaryOp("lte", null, this, x)
    }
    gte(x: AttributeExpr | number) {
        return binaryOp("gte", null, this, x)
    }
    // select(a: AttributeExpr, b: AttributeExpr)

    // --- Unary math -------------------------------------------------------------
    // These keep the operand's type, except `norm`: it yields the element type for vectors
    // ("float") and leaves scalar types unchanged.
    sin() {
        return unaryOp("sin", this)
    }
    cos() {
        return unaryOp("cos", this)
    }
    norm() {
        return unaryOp("norm", this, elemType(this._type) ?? undefined)
    }
    normalize() {
        return unaryOp("normalize", this)
    }
    sqrt() {
        return unaryOp("sqrt", this)
    }
    recip() {
        return unaryOp("recip", this)
    }
    recipSqrt() {
        return unaryOp("recipSqrt", this)
    }
    exp() {
        return unaryOp("exp", this)
    }
    log() {
        return unaryOp("log", this)
    }
    neg() {
        return this.mul(-1)
    } //TODO: native negation operator

    /** Broadcasts this scalar into a "float2" (both components = this); only valid for "float". */
    pack2() {
        return Operators.pack(this, this)
    }
    /** Broadcasts this scalar into a "float3" (all components = this); only valid for "float". */
    pack3() {
        return Operators.pack(this, this, this)
    }

    /** Splits a vector into its first two components. */
    unpack2() {
        return [this.x, this.y] as const
    }
    /** Splits a vector into its three components. */
    unpack3() {
        return [this.x, this.y, this.z] as const
    }

    /**
     * Builds a normals attribute from this expression (presumably positions — confirm with callers).
     *
     * With `smoothingAttr` it emits "smoothNormals"; a number is first turned into a "uint32"
     * constant. Without it, it emits "flatNormals". The result keeps this expression's type.
     */
    normals(smoothingAttr?: AttributeExpr | number) {
        if (smoothingAttr !== undefined) {
            if (typeof smoothingAttr === "number") {
                smoothingAttr = this.constInt(smoothingAttr)
            }
            return makeAttributeExpr("smoothNormals", [this._expr, smoothingAttr._expr], this._type)
        } else {
            return makeAttributeExpr("flatNormals", [this._expr], this._type)
        }
    }

    // Constant attributes. `this` is passed as the first operand, presumably so the constant is
    // laid out over the same elements as this expression (TODO confirm with the evaluator).
    constFloat(x: number) {
        return makeAttributeExpr("constFloat", [this._expr, x], "float")
    }
    constInt(x: number) {
        return makeAttributeExpr("constInt", [this._expr, x], "uint32")
    }

    hashFloat() {
        return makeAttributeExpr("hashFloat", [this._expr], "float")
    }
    hashInt() {
        return makeAttributeExpr("hashInt", [this._expr], "uint32")
    }

    castFloat() {
        return makeAttributeExpr("castFloat", [this._expr], "float")
    }
    castInt() {
        return makeAttributeExpr("castInt", [this._expr], "uint32")
    }

    /** Emits "flatten" with this type's element type; throws "Invalid attribute type" on scalar types. */
    flatten() {
        return makeAttributeExpr("flatten", [this._expr], elemType(this._type))
    }

    /**
     * Converts this expression to `type`, as used for mixed-type arithmetic (see `unifyAndPromote`).
     *
     * Supported conversions:
     *  - identity;
     *  - "float" → "float2"/"float3" (broadcast via pack);
     *  - "uint32" → "float" (cast);
     *  - "uint32" → "float2"/"float3" (cast, then broadcast).
     *
     * @throws Error for any other combination (including `type === null`).
     */
    promote(type: AttributeType | null) {
        if (this._type === type) return this
        if (this._type === "float") {
            if (type === "float2") return this.pack2()
            if (type === "float3") return this.pack3()
        } else if (this._type === "uint32") {
            // `unifyTypes` maps float × uint32 to "float", so the uint32 side must be castable to
            // "float" here; without this case every float/uint32 binary op threw.
            if (type === "float") return this.castFloat()
            if (type === "float2") return this.castFloat().pack2()
            if (type === "float3") return this.castFloat().pack3()
        }
        throw new Error(`Cannot promote type ${this._type} to ${type}`)
    }
}

/**
 * Flattens an attribute set (an object whose values are `AttributeExpr`s) into a single "list" node.
 *
 * Entries follow the object's key enumeration order, except that the entry keyed `primaryAttrID`
 * (if present) is moved to the front. Set operators presumably treat list element 0 as the primary
 * attribute — TODO confirm against the evaluator.
 *
 * @returns `[tokens, listExpr]`, where `tokens[i]` is the `[name, type]` of list element `i`.
 *   This is the layout `listExprToSet` expects when rebuilding the set from an operator's output.
 */
function setToListExpr<T>(primaryAttrID: keyof T | undefined, attrs: T): [[string, AttributeType][], Expr] {
    const tokens: [string, AttributeType][] = []
    const args: Expr[] = []
    for (const key in attrs) {
        const attr = (attrs as any)[key]
        const entry: [string, AttributeType] = [key, attr._type]
        if (key === primaryAttrID) {
            tokens.unshift(entry)
            args.unshift(attr._expr)
        } else {
            tokens.push(entry)
            args.push(attr._expr)
        }
    }
    // The primary key is deliberately not validated: `Operators.flip` passes `undefined`, and a key
    // that is absent from `attrs` simply leaves the enumeration order unchanged.
    return [tokens, makeOpExpr("list", ...args)]
}

/**
 * Rebuilds an attribute set from a list-valued expression.
 *
 * For each `[name, type]` token at position `i`, the result gets a property `name` holding a
 * `get i` node over `listExpr` with that type.
 */
function listExprToSet<T>(tokens: [string, AttributeType][], listExpr: Expr): T {
    const result: any = {}
    tokens.forEach(([name, type], index) => {
        result[name] = makeAttributeExpr("get", [index, listExpr], type)
    })
    return result
}

export namespace Operators {
    /**
     * Applies a single-input set operator `op` to the attribute set `attrs`.
     * Returns a set with the same keys and types, bound to the operator's output.
     */
    function unarySetOp<T>(op: string, primaryAttrID: keyof T | undefined, attrs: T): T {
        const [tokens, listExpr] = setToListExpr(primaryAttrID, attrs)
        return listExprToSet<T>(tokens, makeOpExpr(op, listExpr))
    }

    /**
     * Applies a two-input set operator `op` to two attribute sets.
     * Returns each side re-bound to the operator's output via "first" / "second".
     */
    function binarySetOp<TA, TB>(op: string, primaryAttrIDA: keyof TA, attrSetA: TA, primaryAttrIDB: keyof TB, attrSetB: TB): [TA, TB] {
        const [tokensA, exprListA] = setToListExpr(primaryAttrIDA, attrSetA)
        const [tokensB, exprListB] = setToListExpr(primaryAttrIDB, attrSetB)
        const expr = makeOpExpr(op, exprListA, exprListB)
        return [listExprToSet<TA>(tokensA, makeOpExpr("first", expr)), listExprToSet<TB>(tokensB, makeOpExpr("second", expr))]
    }

    /** Applies "triangulate" to `attrs` (with `primaryAttrID` as the primary attribute). */
    export function triangulate<T>(primaryAttrID: keyof T, attrs: T) {
        return unarySetOp<T>("triangulate", primaryAttrID, attrs)
    }

    /** Applies "join" to two attribute sets; returns both sides re-bound to the joined output. */
    export function join<TA, TB>(primaryAttrIDA: keyof TA, attrSetA: TA, primaryAttrIDB: keyof TB, attrSetB: TB): [TA, TB] {
        return binarySetOp<TA, TB>("join", primaryAttrIDA, attrSetA, primaryAttrIDB, attrSetB)
    }

    /** Applies "product" to two attribute sets; returns both sides re-bound to the product output. */
    export function product<TA, TB>(primaryAttrIDA: keyof TA, attrSetA: TA, primaryAttrIDB: keyof TB, attrSetB: TB): [TA, TB] {
        return binarySetOp<TA, TB>("product", primaryAttrIDA, attrSetA, primaryAttrIDB, attrSetB)
    }

    /**
     * `product`, with both sides' attributes combined into one object.
     * Keys present in both sets take their value from `attrSetB`'s side.
     */
    export function disjointProduct<TA, TB>(primaryAttrIDA: keyof TA, attrSetA: TA, primaryAttrIDB: keyof TB, attrSetB: TB): TA & TB {
        const [outA, outB] = product(primaryAttrIDA, attrSetA, primaryAttrIDB, attrSetB)
        return {...outA, ...outB}
    }

    /** Applies "filter" to `attrs` (with `primaryAttrID` as the primary attribute). */
    export function filter<T>(primaryAttrID: keyof T, attrs: T) {
        return unarySetOp<T>("filter", primaryAttrID, attrs)
    }

    /** Applies "boundary" to `attrs` (with `primaryAttrID` as the primary attribute). */
    export function boundary<T>(primaryAttrID: keyof T, attrs: T) {
        return unarySetOp<T>("boundary", primaryAttrID, attrs)
    }

    /** Applies "flip" to `attrs`; no primary attribute is selected. */
    export function flip<T>(attrs: T) {
        return unarySetOp<T>("flip", undefined, attrs)
    }

    /**
     * Applies "merge" to several attribute sets.
     *
     * The result's keys and types come from the first set, and every set must provide each of
     * those keys. NOTE(review): attribute types are not cross-checked between sets.
     *
     * @throws Error if a set lacks one of the first set's keys.
     */
    export function merge<T>(...attrSets: T[]) {
        const tokens: [string, AttributeType][] = []
        // One "list" node per input set, filled key by key below.
        const listExprs: OpExpr[] = attrSets.map(() => makeOpExpr("list"))
        for (const key in attrSets[0]) {
            for (let n = 0; n < attrSets.length; n++) {
                const attr = (attrSets[n] as any)[key]
                if (attr == null) {
                    throw new Error(`merge: attribute "${key}" is missing from attribute set ${n}`)
                }
                listExprs[n].args.push(attr._expr)
                if (n === 0) {
                    tokens.push([key, attr._type])
                }
            }
        }
        return listExprToSet<T>(tokens, makeOpExpr("merge", makeOpExpr("list", ...listExprs)))
    }

    /**
     * Packs two or three "float" attributes into a "float2" / "float3" vector.
     *
     * @throws Error if the argument types differ or are not "float".
     */
    export function pack(arg1: AttributeExpr, arg2: AttributeExpr, arg3?: AttributeExpr) {
        if (arg3 !== undefined) {
            if (arg1._type !== arg2._type || arg2._type !== arg3._type)
                throw new Error(`Incompatible attribute types for pack3: ${arg1._type}, ${arg2._type}, ${arg3._type}`)
            if (arg1._type !== "float") throw new Error(`TODO: determine correct non-float type for pack3: ${arg1._type}, ${arg2._type}, ${arg3._type}`)
            return makeAttributeExpr("pack3", [arg1._expr, arg2._expr, arg3._expr], "float3")
        } else {
            if (arg1._type !== arg2._type) throw new Error(`Incompatible attribute types for pack2: ${arg1._type}, ${arg2._type}`)
            if (arg1._type !== "float") throw new Error(`TODO: determine correct non-float type for pack2: ${arg1._type}, ${arg2._type}`)
            return makeAttributeExpr("pack2", [arg1._expr, arg2._expr], "float2") //TODO: determine correct type
        }
    }
}

export namespace Primitives {
    //TODO: empty0, empty1, empty2 operators

    /**
     * "float2" grid primitive.
     * Spanned from (originX, originY) by two basis vectors, with segment counts
     * `numXSegments` and `numYSegments`.
     */
    export function grid2(
        originX: number,
        originY: number,
        basis1X: number,
        basis1Y: number,
        basis2X: number,
        basis2Y: number,
        numXSegments: number,
        numYSegments: number,
    ) {
        const origin = [originX, originY]
        const basis1 = [basis1X, basis1Y]
        const basis2 = [basis2X, basis2Y]
        return makeAttributeExpr("grid2", [...origin, ...basis1, ...basis2, numXSegments, numYSegments], "float2")
    }

    /** "float" points between x0 and x1; `numPoints` is forwarded to the "linePoints1" op. */
    export function linePoints1(x0: number, x1: number, numPoints: number) {
        const args = [x0, x1, numPoints]
        return makeAttributeExpr("linePoints1", args, "float")
    }

    /** "float2" points between (x0, y0) and (x1, y1); `numPoints` is forwarded to the "linePoints2" op. */
    export function linePoints2(x0: number, y0: number, x1: number, y1: number, numPoints: number) {
        const start = [x0, y0]
        const end = [x1, y1]
        return makeAttributeExpr("linePoints2", [...start, ...end, numPoints], "float2")
    }

    /** "float3" points between (x0, y0, z0) and (x1, y1, z1); `numPoints` is forwarded to the "linePoints3" op. */
    export function linePoints3(x0: number, y0: number, z0: number, x1: number, y1: number, z1: number, numPoints: number) {
        const start = [x0, y0, z0]
        const end = [x1, y1, z1]
        return makeAttributeExpr("linePoints3", [...start, ...end, numPoints], "float3")
    }

    /** "float" line from x0 to x1; `numSeg` is forwarded to the "line1" op. */
    export function line1(x0: number, x1: number, numSeg: number) {
        const args = [x0, x1, numSeg]
        return makeAttributeExpr("line1", args, "float")
    }

    /** "float2" line from (x0, y0) to (x1, y1); `numSeg` is forwarded to the "line2" op. */
    export function line2(x0: number, y0: number, x1: number, y1: number, numSeg: number) {
        //TODO: doesn't need to be a primitive
        const start = [x0, y0]
        const end = [x1, y1]
        return makeAttributeExpr("line2", [...start, ...end, numSeg], "float2")
    }

    /** "float3" line from (x0, y0, z0) to (x1, y1, z1); `numSeg` is forwarded to the "line3" op. */
    export function line3(x0: number, y0: number, z0: number, x1: number, y1: number, z1: number, numSeg: number) {
        //TODO: doesn't need to be a primitive
        const start = [x0, y0, z0]
        const end = [x1, y1, z1]
        return makeAttributeExpr("line3", [...start, ...end, numSeg], "float3")
    }

    /**
     * UV sphere of radius `r` built from a `numU` × `numV` grid.
     *
     * With u, v the grid coordinates, φ = 2π·u and θ = π·v:
     * `position` = -r · (sin φ · sin θ, cos θ, cos φ · sin θ).
     *
     * Returns the "float3" `position` and the generating "float2" `uv`.
     */
    export function uvSphere(r = 1, numU = 10, numV = 10) {
        //TODO: this should probably be a native primitive?
        //TODO: handle degenerate edges
        // v spans [0.0001, 0.9999] rather than [0, 1], presumably to keep exact poles off the grid.
        const uv = grid2(0.0, 0.0001, 1, 0, 0, 0.9998, numU, numV)
        const azimuth = uv.x.mul(2 * Math.PI)
        const polar = uv.y.mul(Math.PI)
        const sinPolar = polar.sin() // shared by the x and z components
        const scale = -r
        const px = azimuth.sin().mul(sinPolar).mul(scale)
        const py = polar.cos().mul(scale)
        const pz = azimuth.cos().mul(sinPolar).mul(scale)
        const position = Operators.pack(px, py, pz)
        return {position, uv}
    }

    //TODO: missing primitives:
    // - Closed 1d line (circle)
    // - Disc
}

export namespace Tilings {
    /**
     * Reads a tiling operator's two outputs.
     * List element 0 holds the per-element coordinates (of `coordType`) and element 1 the
     * "uint32" tile id.
     */
    function tileOutputs(expr: Expr, coordType: AttributeType) {
        const coords = makeAttributeExpr("get", [0, expr], coordType)
        const tileID = makeAttributeExpr("get", [1, expr], "uint32")
        return [coords, tileID] as const
    }

    /** Hexagonal tiling of the rectangle (x0, y0)–(x1, y1) with size `r` and `inset`; yields "float2" uv and tileID. */
    export function hexagonalTiling(x0: number, y0: number, x1: number, y1: number, r: number, inset: number) {
        const [uv, tileID] = tileOutputs(makeOpExpr("hexagonalTiling", x0, y0, x1, y1, r, inset), "float2")
        return {uv, tileID}
    }

    /** Rectangular tiling of (x0, y0)–(x1, y1) with tile size w × h, `skew` and `inset`; yields "float2" uv and tileID. */
    export function rectangularTiling(x0: number, y0: number, x1: number, y1: number, w: number, h: number, skew: number, inset: number) {
        const [uv, tileID] = tileOutputs(makeOpExpr("rectangularTiling", x0, y0, x1, y1, w, h, skew, inset), "float2")
        return {uv, tileID}
    }

    /** 1D dash tiling of [x0, x1] with dash width `w` and `inset`; yields "float" u and tileID. */
    export function dashTiling(x0: number, x1: number, w: number, inset: number) {
        const [u, tileID] = tileOutputs(makeOpExpr("dashTiling", x0, x1, w, inset), "float")
        return {u, tileID}
    }
}

/**
 * Serializes a geometry graph into a flat textual program.
 *
 * The graph is walked depth-first in post-order, so every node is emitted after all of its
 * arguments. Each distinct node object is emitted once, with shared subexpressions deduplicated
 * by identity, as `op arg1 arg2 ...;`. In that entry:
 *  - a child node is written as its distance back from the current entry (current index minus child index);
 *  - a numeric literal is written as `#<value>`.
 *
 * @param meshDataGraph graph to compile; its `type` must be "geomGraph".
 * @returns the concatenated program string.
 * @throws Error if the graph is not a "geomGraph", reaches an ArrayBuffer, or contains a node
 *   that `isCompilableExpr` rejects (currently "meshToGeom").
 */
export function compileGeometryGraph(meshDataGraph: MeshDataGraph): string {
    // Node object → entry index in the emitted program.
    const indexMap = new Map<any, number>()
    let programStr = ""
    let entryCounter = 0
    if (meshDataGraph.type !== "geomGraph") {
        throw new Error("Cannot evaluate non-geomGraph")
    }
    // Returns a node's entry index (emitting the node first if needed), or `#value` for a numeric literal.
    const traverse = (expr: OpExpr | number) => {
        if (typeof expr === "object") {
            if (expr instanceof ArrayBuffer) {
                throw new Error("TODO: gather ArrayBuffer references")
            }
            let entryIndex = indexMap.get(expr)
            if (entryIndex === undefined) {
                const op = expr.op
                if (!isCompilableExpr(expr)) {
                    // Report the op name; interpolating the node itself would only print "[object Object]".
                    throw new Error(`Cannot compile expression: ${op}`)
                }
                // Arguments are emitted first, so every child index is smaller than this entry's index.
                const args = expr.args.map(traverse)
                entryIndex = entryCounter++
                indexMap.set(expr, entryIndex)
                programStr += op
                for (const arg of args) {
                    programStr += " "
                    if (typeof arg === "number") {
                        const relative = entryIndex - arg
                        programStr += `${relative}`
                    } else {
                        programStr += arg
                    }
                }
                programStr += ";"
            }
            return entryIndex
        } else if (typeof expr === "number") {
            return `#${expr}`
        } else {
            // Not reachable per the parameter type; any other value is written through verbatim.
            return expr
        }
    }
    traverse(meshDataGraph.graph)
    return programStr
}

/**
 * Fixed attribute set exchanged with render-node meshes by `packStandardGeometryAttributes` and
 * `unpackStandardGeometryAttributes`.
 *
 * The unpack layout assigns position "float3", normal "float3", materialID "uint32" and
 * uv "float2". This type does not enforce those attribute types.
 */
export type StandardGeometryAttributes = {
    position: AttributeExpr
    normal: AttributeExpr
    materialID: AttributeExpr
    uv: AttributeExpr
}

/**
 * Packs the standard attributes into one "list" node.
 * The fixed order is position, normal, materialID, uv — the same order
 * `unpackStandardGeometryAttributes` reads back.
 */
export function packStandardGeometryAttributes(attrs: StandardGeometryAttributes) {
    //TODO: multiple UV channels
    const {position, normal, materialID, uv} = attrs
    const fields = [position, normal, materialID, uv].map((attr) => attr._expr)
    return makeOpExpr("list", ...fields)
}

/**
 * Exposes a render-node mesh as standard geometry attributes.
 *
 * The mesh is wrapped in a "meshToGeom" node, and list elements 0–3 are read as
 * position ("float3"), normal ("float3"), materialID ("uint32") and uv ("float2").
 */
export function unpackStandardGeometryAttributes(meshData: RenderNodes.MeshData): StandardGeometryAttributes {
    const layout: [string, AttributeType][] = [
        ["position", "float3"],
        ["normal", "float3"],
        ["materialID", "uint32"],
        ["uv", "float2"], //TODO: multiple UV channels
    ]
    const geometry = meshToGeom(meshData)
    return listExprToSet(layout, geometry)
}
