// @ts-strict-ignore
import {fitAxisAlignedPlaneToPoints} from "@common/helpers/utils/math-utils"
import {Matrix4, Vector3} from "@common/helpers/vector-math"
import {
    GriddedSurfaceMapData,
    GriddedSurfaceMapSampler,
    ISurfaceMap,
    PlanarSurfaceMapSampler,
    SurfaceOverrides,
} from "@editor/helpers/connection-solver/surface-map"
import {SurfacePointCoordinates, IMeshGeometryAccessor} from "@cm/lib/templates/interfaces/scene-object"

export {ISurfaceMap, SurfaceOverrides}

// Number of grid samples along each axis of the map built by sampleMapFromGrid (defaultMapSize x defaultMapSize cells).
const defaultMapSize = 128
// Lower bound for the sample-to-surface distance at or beyond which sampleMapFromGrid marks a grid sample as invalid.
const minFlatDistance = 1

/**
 * Samples a regular `defaultMapSize` x `defaultMapSize` grid over the unit square [-0.5, 0.5]^2 (local z = 0) of
 * `matrix` and looks up the matching mesh point for every grid cell.
 *
 * @param geometryAccessor mesh queries (raycasting / closest-point lookup).
 * @param matrix surface frame. The lengths of its X/Y basis vectors become the map scale; the length of its Z basis
 *               vector (clamped below by `minFlatDistance`) is the distance at which a sample counts as a miss.
 * @param triIndices optional triangle whitelist. Only used in projection mode: ray hits on triangles not in this
 *                   list are treated as misses.
 * @param projection true: cast a ray from each grid point along `matrix.getNormalXYZ()` and store the grid point
 *                   displaced along that direction by the (clamped) hit distance; false: store the closest mesh point.
 * @returns flattened xyz `points`, per-cell `weights` (1 at distance 0, falling linearly to 0 at the Z scale; uniform
 *          1 if all weights are ~0), and the weighted map centroid (`centerX`, `centerY`) in map units.
 */
function sampleMapFromGrid(
    geometryAccessor: IMeshGeometryAccessor,
    matrix: Matrix4,
    triIndices: number[] = undefined,
    projection = true,
): GriddedSurfaceMapData {
    const basis = matrix.toBasis()
    const scaleX = basis[0].norm()
    const scaleY = basis[1].norm()
    const scaleZ = basis[2].norm()
    // samples at least this far from the mesh are marked invalid
    const maxFlatDistance = Math.max(scaleZ, minFlatDistance)
    const sizeX = defaultMapSize
    const sizeY = defaultMapSize
    const points: number[] = []
    const weights: number[] = []
    const valid: boolean[] = []
    const [dgx, dgy, dgz] = matrix.getNormalXYZ()
    // Set lookup instead of Array.includes: the membership test runs once per grid cell (sizeX * sizeY times)
    const allowedTris = triIndices ? new Set(triIndices) : undefined
    // offset the ray for projection mode by a tiny bit to make sure we hit if we generate points directly on target plane
    const offset = 1e-6
    for (let iy = 0; iy < sizeY; iy++) {
        for (let ix = 0; ix < sizeX; ix++) {
            const tx = ix / (sizeX - 1) - 0.5
            const ty = iy / (sizeY - 1) - 0.5
            const [gx, gy, gz] = matrix.multiplyVectorXYZW(tx, ty, 0, 1)
            let px: number, py: number, pz: number, triIdx: number
            if (projection) {
                // ATTENTION: we add an additional tiny offset to the x position of the point since we discovered an edge case
                //            where the raycast implementation misses a triangle on the top face of the cube test object.
                //            Likely due to floating point precision issues in the triangle intersection test.
                //            When we add the point we use the non-offset one to not bias the result
                const [gxo, gyo, gzo] = matrix.multiplyVectorXYZW(tx + offset, ty, 0, 1)
                ;[px, py, pz, triIdx] = geometryAccessor.getPointOnMeshFromRaycasting(gxo - offset * dgx, gyo - offset * dgy, gzo - offset * dgz, dgx, dgy, dgz)
            } else {
                ;[px, py, pz, triIdx] = geometryAccessor.getClosestPointOnMesh(gx, gy, gz)
            }
            let dist = Math.sqrt((gx - px) ** 2 + (gy - py) ** 2 + (gz - pz) ** 2)
            // reject points hitting outside the provided triangles
            if (projection && allowedTris && !allowedTris.has(triIdx)) dist = maxFlatDistance
            if (dist >= maxFlatDistance) {
                dist = maxFlatDistance
                valid.push(false)
            } else {
                valid.push(true)
            }
            if (projection) {
                points.push(gx + dgx * dist, gy + dgy * dist, gz + dgz * dist)
            } else {
                points.push(px, py, pz)
            }
            weights.push(dist)
        }
    }
    // fill in invalid points in map by dilation of valid values (fixes issues with interpolation of maps)
    // note: only `points` is patched; weights below still come from the original (clamped) distances
    dilateInvalidSamples(points, valid, sizeX, sizeY, 2)

    const distScale = 1 / Math.max(0.1, scaleZ)
    let weightSum = 0
    for (let i = 0; i < weights.length; i++) {
        weights[i] = 1 - Math.min(1, weights[i] * distScale)
        weightSum += weights[i]
    }
    if (weightSum <= 1e-3) {
        // nothing (or almost nothing) hit the surface: fall back to uniform weights
        weightSum = 0
        for (let i = 0; i < weights.length; i++) {
            weights[i] = 1
            weightSum += weights[i]
        }
    }
    const weightSc = 1 / Math.max(1e-3, weightSum)
    // weighted centroid in grid-index units
    let centroidIX = 0
    let centroidIY = 0
    let i = 0
    for (let iy = 0; iy < sizeY; iy++) {
        for (let ix = 0; ix < sizeX; ix++) {
            const wsc = weights[i] * weightSc
            centroidIX += ix * wsc
            centroidIY += iy * wsc
            ++i
        }
    }
    return {
        points,
        weights,
        matrix,
        sizeX,
        sizeY,
        scaleX,
        scaleY,
        centerX: (centroidIX * scaleX) / (sizeX - 1),
        centerY: (centroidIY * scaleY) / (sizeY - 1),
    }
}

/**
 * Fills invalid grid cells with the point of a valid neighbour, repeated `steps` times. Neighbours are tried in the
 * order left, right, next row, previous row, then the four diagonals. Each step reads the validity mask produced by
 * the previous step, so valid data spreads by at most one cell per step. `points` (flattened xyz, row-major
 * sizeX x sizeY) is modified in place; `valid` itself is not mutated.
 */
function dilateInvalidSamples(points: number[], valid: boolean[], sizeX: number, sizeY: number, steps: number): void {
    const maxX = sizeX - 1
    const maxY = sizeY - 1
    let current = valid
    for (let step = 0; step < steps; step++) {
        // work with a copy of the mask to avoid values propagating by more than one pixel per iteration
        const next = current.slice()
        let idx = 0
        for (let iy = 0; iy < sizeY; iy++) {
            for (let ix = 0; ix < sizeX; ix++, idx++) {
                if (current[idx]) continue
                let src = -1
                if (ix > 0 && current[idx - 1]) src = idx - 1
                else if (ix < maxX && current[idx + 1]) src = idx + 1
                // row guards must match the row being read: +sizeX needs a next row, -sizeX needs a previous row
                else if (iy < maxY && current[idx + sizeX]) src = idx + sizeX
                else if (iy > 0 && current[idx - sizeX]) src = idx - sizeX
                else if (ix > 0 && iy > 0 && current[idx - 1 - sizeX]) src = idx - 1 - sizeX
                else if (ix < maxX && iy > 0 && current[idx + 1 - sizeX]) src = idx + 1 - sizeX
                else if (ix > 0 && iy < maxY && current[idx - 1 + sizeX]) src = idx - 1 + sizeX
                else if (ix < maxX && iy < maxY && current[idx + 1 + sizeX]) src = idx + 1 + sizeX
                if (src < 0) continue
                points[idx * 3 + 0] = points[src * 3 + 0]
                points[idx * 3 + 1] = points[src * 3 + 1]
                points[idx * 3 + 2] = points[src * 3 + 2]
                next[idx] = true
            }
        }
        current = next
    }
}

/**
 * Fits an oriented surface frame (translation * basis * scaling) to a point cloud.
 *
 * The plane comes from `fitAxisAlignedPlaneToPoints`, honouring `surfaceOverrides.rotationAngle`. Both in-plane
 * extents are grown by the same absolute margin, the smaller of the two `scaleFac` enlargements. The normal
 * extent is multiplied by `scaleFac`, clamped to at least 1, and then multiplied by `surfaceOverrides.scaleZ`.
 *
 * X and Z are negated (Y kept) when the fitted normal points away from `normalMean`; a truthy
 * `surfaceOverrides.flipNormals` inverts that decision. The origin is then pushed along the (final) normal until no
 * input point lies on its positive side.
 *
 * @param pointCoords flattened xyz coordinates.
 * @param normalMean reference direction used to orient the frame.
 * @param scaleFac enlargement factor applied to the fitted extents.
 * @param surfaceOverrides user overrides (rotation, normal flip, z-scale).
 */
function surfaceMatrixFromPoints(pointCoords: number[], normalMean: Vector3, scaleFac = 1.0, surfaceOverrides: SurfaceOverrides): Matrix4 {
    const plane = fitAxisAlignedPlaneToPoints(pointCoords, surfaceOverrides.rotationAngle)
    const origin = plane.origin

    const extentX = Math.abs(plane.scale[0])
    const extentY = Math.abs(plane.scale[1])
    const extentZ = Math.abs(plane.scale[2])
    // same absolute margin on X and Y so the border around the points is uniform
    const margin = Math.min(extentX * scaleFac - extentX, extentY * scaleFac - extentY)
    const minScaleZ = 1.0
    const depth = Math.max(extentZ * scaleFac, minScaleZ) * surfaceOverrides.scaleZ

    const ax = Vector3.fromArray(plane.axes[0])
    const ay = Vector3.fromArray(plane.axes[1])
    const az = ax.cross(ay)
    const facesAway = normalMean.dot(az) < 0
    // NOTE(review): if flipNormals may be undefined, `facesAway !== undefined` is always true and the frame always flips — confirm it is a required boolean
    if (facesAway !== surfaceOverrides.flipNormals) {
        // negate two axes so the basis keeps its handedness
        for (const axis of [ax, az]) {
            axis.x *= -1
            axis.y *= -1
            axis.z *= -1
        }
    }

    // move the origin outwards past every vertex that lies in front of the plane
    const unitNormal = az.normalized()
    let i = 0
    while (i < pointCoords.length) {
        const fromOrigin = new Vector3(pointCoords[i], pointCoords[i + 1], pointCoords[i + 2]).sub(new Vector3(origin[0], origin[1], origin[2]))
        const height = unitNormal.dot(fromOrigin)
        if (height > 0) {
            origin[0] += height * unitNormal.x
            origin[1] += height * unitNormal.y
            origin[2] += height * unitNormal.z
        }
        i += 3
    }

    return Matrix4.translation(origin[0], origin[1], origin[2])
        .multiply(Matrix4.fromBasis(ax, ay, az))
        .multiply(Matrix4.scaling(extentX + margin, extentY + margin, depth))
}

/**
 * Builds a gridded surface map around a set of surface points.
 *
 * If `matrix` is undefined, a frame is fitted to the point positions with `scaleFac` 1.1. It is oriented by the
 * normalized sum of the triangle normals interpolated at each point. The grid is then sampled in projection or
 * closest-point mode, and `surfaceOverrides.centroidOffset` (if set) is added to the map centroid.
 */
export function surfaceMapForMultiplePoints(
    geometryAccessor: IMeshGeometryAccessor,
    inPoints: SurfacePointCoordinates[],
    projectionMode: boolean,
    matrix: Matrix4,
    surfaceOverrides: SurfaceOverrides,
): ISurfaceMap {
    let surfaceMatrix = matrix
    if (surfaceMatrix === undefined) {
        const coords: number[] = []
        let sumX = 0
        let sumY = 0
        let sumZ = 0
        for (let k = 0; k < inPoints.length; k++) {
            const pt = inPoints[k]
            const n = geometryAccessor.interpolateTriangleNormal(pt)
            sumX += n[0]
            sumY += n[1]
            sumZ += n[2]
            coords.push(pt[0], pt[1], pt[2])
        }
        const meanNormal = new Vector3(sumX, sumY, sumZ).normalized()
        surfaceMatrix = surfaceMatrixFromPoints(coords, meanNormal, 1.1, surfaceOverrides)
    }

    const gridMap = sampleMapFromGrid(geometryAccessor, surfaceMatrix, undefined, projectionMode)
    const shift = surfaceOverrides.centroidOffset
    if (shift) {
        gridMap.centerX += shift[0]
        gridMap.centerY += shift[1]
    }
    return new GriddedSurfaceMapSampler(gridMap)
}

/**
 * Builds a gridded surface map for a set of mesh faces.
 *
 * The faces are resolved to triangles via `faceIDsToTriangleIndices`. If `matrix` is undefined, a frame is fitted to
 * the unique vertex positions of those triangles with `scaleFac` 1.05. It is oriented by the normalized sum of
 * their vertex normals. Sampling passes the triangle list as a whitelist, so in projection mode hits on other
 * triangles are rejected. `surfaceOverrides.centroidOffset` (if set) is added to the map centroid.
 */
export function surfaceMapForMultipleFaces(
    geometryAccessor: IMeshGeometryAccessor,
    inFaceIDs: number[],
    projectionMode: boolean,
    matrix: Matrix4,
    surfaceOverrides: SurfaceOverrides,
): ISurfaceMap {
    const triList = geometryAccessor.faceIDsToTriangleIndices(inFaceIDs)
    if (matrix === undefined) {
        const vertexCoords: [number, number, number][] = []
        let normalMean = new Vector3(0, 0, 0)

        for (const triIdx of triList) {
            const [v1, v2, v3] = geometryAccessor.getVerticesForTriangle(triIdx)
            vertexCoords.push([v1.x, v1.y, v1.z], [v2.x, v2.y, v2.z], [v3.x, v3.y, v3.z])

            const [n1, n2, n3] = geometryAccessor.getNormalsForTriangle(triIdx)
            normalMean = normalMean.add(Vector3.fromIVector3(n1))
            normalMean = normalMean.add(Vector3.fromIVector3(n2))
            normalMean = normalMean.add(Vector3.fromIVector3(n3))
        }
        normalMean = normalMean.normalized()

        // vertices shared by adjacent triangles appear several times; fit the plane to unique positions only
        const uniqueCoords = uniquePointCoords(vertexCoords)

        matrix = surfaceMatrixFromPoints(uniqueCoords, normalMean, 1.05, surfaceOverrides)
    }
    const fullMap = sampleMapFromGrid(geometryAccessor, matrix, triList, projectionMode)
    if (surfaceOverrides.centroidOffset) {
        fullMap.centerX += surfaceOverrides.centroidOffset[0]
        fullMap.centerY += surfaceOverrides.centroidOffset[1]
    }
    return new GriddedSurfaceMapSampler(fullMap)
}

/**
 * Returns the flattened xyz coordinates of `points` with exact duplicates removed, keeping first occurrences in
 * order. Linear time: uses a Set of string keys instead of a pairwise scan.
 */
function uniquePointCoords(points: [number, number, number][]): number[] {
    const seen = new Set<string>()
    const out: number[] = []
    for (const [x, y, z] of points) {
        // Number-to-string round-trips exactly, and -0/0 share the key "0" — the same equality as numeric `==`.
        // NaN never compares equal to anything, so points containing NaN are always kept and never matched.
        if (!Number.isNaN(x) && !Number.isNaN(y) && !Number.isNaN(z)) {
            const key = `${x},${y},${z}`
            if (seen.has(key)) continue
            seen.add(key)
        }
        out.push(x, y, z)
    }
    return out
}

/**
 * Wraps an explicit plane frame in a planar (non-gridded) surface map sampler.
 *
 * @param matrix plane frame, passed through unchanged to `PlanarSurfaceMapSampler`.
 * @param scaleX map extent along the frame's X axis, passed through unchanged.
 * @param scaleY map extent along the frame's Y axis, passed through unchanged.
 */
export function surfaceMapForPlane(matrix: Matrix4, scaleX: number, scaleY: number): ISurfaceMap {
    const sampler: ISurfaceMap = new PlanarSurfaceMapSampler(matrix, scaleX, scaleY)
    return sampler
}
