import {createRNG, matInverse4, matVec4Mul, zeros} from "@common/helpers/utils/math-utils"
import {Matrix4} from "@cm/lib/math"
import {Matrix, Tensor} from "@cm/lib/templates/interfaces/connection-solver"

/**
 * A 2D-parameterized surface: a coordinate (coord[0], coord[1]) maps to a point and a scalar weight, both with
 * derivatives with respect to the coordinate.
 *
 * The gridded implementation converts a coordinate to grid space as `(coord + center) * (size - 1) / scale`, i.e. the
 * coordinate extent starts at (-centerX, -centerY) and spans (scaleX, scaleY); the planar implementation uses
 * center = scale / 2, i.e. a rectangle centered on the coordinate origin.
 */
export interface ISurfaceMap {
    /** Matrix associated with the surface (the planar implementation maps coordinates to points with it). */
    readonly matrix: Matrix4
    /** Offset added to coord[0] before it is converted to map space. */
    readonly centerX: number
    /** Offset added to coord[1] before it is converted to map space. */
    readonly centerY: number
    /** Extent of the coordinate domain along coord[0]. */
    readonly scaleX: number
    /** Extent of the coordinate domain along coord[1]. */
    readonly scaleY: number
    /**
     * Draw a random coordinate using `rngFn`.
     * @returns [coordX, coordY, weight]; the planar implementation always reports weight 1, the gridded one a map weight
     *          normalised by the map's total weight.
     */
    generateImportanceSampledPoint(rngFn: () => number): [number, number, number]
    /**
     * Build a matrix of sample coordinates (one row per sample, 4 columns) drawn from this map and from `otherMap`, the
     * latter mapped into this map's coordinates through the inverse of `matrixCoordAtoCoordB`.
     */
    generateSamplingPattern(otherMap: ISurfaceMap, matrixCoordAtoCoordB: Matrix): Matrix
    /**
     * Evaluate the surface point at `coord` into `outPoint` (4 entries) and write d(point)/d(coord[0]) into row 0 and
     * d(point)/d(coord[1]) into row 1 (columns 0-2) of `dPt_dCoord`.
     */
    samplePointWithCoordD(coord: ArrayLike<number>, outPoint: number[], dPt_dCoord: Tensor): void
    /** Return the weight at `coord`; its derivatives w.r.t. coord[0] / coord[1] go to dw_dCoord[0] / dw_dCoord[1]. */
    sampleWeightWithCoordD(coord: ArrayLike<number>, dw_dCoord: number[]): number
}

/** Input data for `GriddedSurfaceMapSampler`: a regular grid of nodes, each with a point and a weight. */
export type GriddedSurfaceMapData = {
    /** Node points, 3 numbers per node; nodes are ordered row by row (x varies fastest), sizeY rows of sizeX nodes. */
    points: number[]
    /** One weight per node, in the same node order as `points`. */
    weights: number[]
    /** Stored unchanged as the sampler's `matrix`. */
    matrix: Matrix4
    /** Number of grid nodes along x (columns). */
    sizeX: number
    /** Number of grid nodes along y (rows). */
    sizeY: number
    /** Coordinate-space distance spanned by the sizeX nodes along x. */
    scaleX: number
    /** Coordinate-space distance spanned by the sizeY nodes along y. */
    scaleY: number
    /** Offset added to coord[0] before converting it to grid space. */
    centerX: number
    /** Offset added to coord[1] before converting it to grid space. */
    centerY: number
}

/**
 * Per-surface adjustments.
 * NOTE(review): not referenced in the visible part of this file; the field descriptions below are inferred from the
 * names only — confirm against the consumer.
 */
export type SurfaceOverrides = {
    /** Presumably a rotation applied to the surface (units unknown — radians? confirm). */
    rotationAngle: number
    /** Presumably a scale factor along the surface's z axis — confirm. */
    scaleZ: number
    /** Presumably an [x, y] offset applied to the surface centroid — confirm. */
    centroidOffset: [number, number]
    /** Presumably flips the surface normals when true — confirm. */
    flipNormals?: boolean
}

/**
 * Fill the point channels of invalid pixels from a valid neighbour, in place.
 *
 * `map` is an interleaved array of `sizeX * sizeY` pixels with 4 numbers each, rows (y) outer and columns (x) inner. A
 * pixel is valid when its channel 3 (weight) is > 0. Each dilation step copies channels 0-2 of the first valid
 * 8-neighbour (left, right, next row, previous row, then the diagonals) into every invalid pixel and sets that pixel's
 * weight to 0, so a filled pixel still carries no weight but counts as valid for later steps.
 *
 * @param map Interleaved pixel data, modified in place.
 * @param sizeX Pixels per row.
 * @param sizeY Number of rows.
 * @param steps Number of dilation steps; each step grows the filled region by at most one pixel (default 1).
 */
function dilateMapValues(map: number[], sizeX: number, sizeY: number, steps = 1) {
    let valid: boolean[] = []
    for (let idx = 0; idx < sizeX * sizeY; idx++) valid.push(map[4 * idx + 3] > 0)
    const maxX = sizeX - 1
    const maxY = sizeY - 1
    for (let step = 0; step < steps; step++) {
        // Read validity from the previous step only, so values propagate by at most one pixel per step.
        const newValid = valid.slice()
        let idx = 0
        for (let iy = 0; iy < sizeY; iy++) {
            for (let ix = 0; ix < sizeX; ix++, idx++) {
                if (valid[idx]) continue
                // Each candidate is bounds-checked against the row it actually reads:
                // idx + sizeX is the next row (needs iy < maxY), idx - sizeX the previous row (needs iy > 0).
                let foundIdx = -1
                if (ix > 0 && valid[idx - 1]) foundIdx = idx - 1
                else if (ix < maxX && valid[idx + 1]) foundIdx = idx + 1
                else if (iy < maxY && valid[idx + sizeX]) foundIdx = idx + sizeX
                else if (iy > 0 && valid[idx - sizeX]) foundIdx = idx - sizeX
                else if (ix > 0 && iy > 0 && valid[idx - 1 - sizeX]) foundIdx = idx - 1 - sizeX
                else if (ix < maxX && iy > 0 && valid[idx + 1 - sizeX]) foundIdx = idx + 1 - sizeX
                else if (ix > 0 && iy < maxY && valid[idx - 1 + sizeX]) foundIdx = idx - 1 + sizeX
                else if (ix < maxX && iy < maxY && valid[idx + 1 + sizeX]) foundIdx = idx + 1 + sizeX
                if (foundIdx < 0) continue
                map[idx * 4 + 0] = map[foundIdx * 4 + 0]
                map[idx * 4 + 1] = map[foundIdx * 4 + 1]
                map[idx * 4 + 2] = map[foundIdx * 4 + 2]
                map[idx * 4 + 3] = 0
                newValid[idx] = true
            }
        }
        valid = newValid
    }
}

/**
 * Rasterize a surface map into a flat array of `sizeX * sizeY` samples with 4 numbers each
 * ([pointX, pointY, pointZ, weight]), rows (y) outer and columns (x) inner.
 *
 * Samples are spaced evenly over the map's coordinate extent, which starts at (-centerX, -centerY) and spans
 * (scaleX, scaleY) — the same convention the samplers use when converting `(coord + center)` to grid positions.
 * The result is then post-processed in place by `dilateMapValues`.
 *
 * @param map Surface map to sample.
 * @param sizeX Number of samples along x (columns).
 * @param sizeY Number of samples along y (rows).
 * @returns Flat array of length `4 * sizeX * sizeY`.
 */
export function renderMap(map: ISurfaceMap, sizeX: number, sizeY: number): number[] {
    const dw_dCoord = [0, 0, 0, 0]
    const coords = [0, 0, 0, 0]
    const point = [0, 0, 0, 0]
    const dPt_dCoord = zeros(2, 4)
    // A single-sample axis gets spacing 0; dividing by (size - 1) = 0 would produce NaN coordinates (0 * Infinity).
    const sx = sizeX > 1 ? map.scaleX / (sizeX - 1) : 0
    const sy = sizeY > 1 ? map.scaleY / (sizeY - 1) : 0
    // Start at the map's own coordinate origin instead of assuming the center sits at half the scale.
    const ox = -map.centerX
    const oy = -map.centerY
    const outMap: number[] = []
    for (let y = 0; y < sizeY; y++) {
        for (let x = 0; x < sizeX; x++) {
            coords[0] = x * sx + ox
            coords[1] = y * sy + oy
            map.samplePointWithCoordD(coords, point, dPt_dCoord)
            // Channel 3 of the output carries the weight instead of the homogeneous point component.
            point[3] = map.sampleWeightWithCoordD(coords, dw_dCoord)
            outMap.push(...point)
        }
    }
    dilateMapValues(outMap, sizeX, sizeY)
    return outMap
}

/**
 * Downsample a map by a factor of 2 in both dimensions. Channels 0, 1, 2 are averaged together using channel 3 as a
 * weight; channel 3 becomes the plain average of the four input weights.
 *
 * The map is indexed [row (y), column (x), channel]; an odd trailing row or column is dropped.
 *
 * @param map Tensor of shape [height, width, 4].
 * @returns New tensor of shape [floor(height / 2), floor(width / 2), 4].
 */
function downsampleMap(map: Tensor) {
    const h = map.shape[0]
    const w = map.shape[1]
    const oh = Math.floor(h / 2)
    const ow = Math.floor(w / 2)
    // Keep the input's [row, column, channel] layout; allocating (ow, oh) would transpose non-square outputs.
    const out = zeros(oh, ow, 4)
    for (let oy = 0; oy < oh; oy++) {
        for (let ox = 0; ox < ow; ox++) {
            const ix0 = ox * 2
            const iy0 = oy * 2
            const ix1 = ix0 + 1
            const iy1 = iy0 + 1
            const wx0y0 = map.get(iy0, ix0, 3)
            const wx0y1 = map.get(iy1, ix0, 3)
            const wx1y0 = map.get(iy0, ix1, 3)
            const wx1y1 = map.get(iy1, ix1, 3)
            const sumW = wx0y0 + wx0y1 + wx1y0 + wx1y1
            out.set(oy, ox, 3, sumW * 0.25)
            // The epsilon keeps blocks whose weights are all zero finite (their channels become 0).
            const iSumW = 1 / (1e-8 + sumW)
            for (let c = 0; c < 3; c++) {
                const cx0y0 = map.get(iy0, ix0, c)
                const cx0y1 = map.get(iy1, ix0, c)
                const cx1y0 = map.get(iy0, ix1, c)
                const cx1y1 = map.get(iy1, ix1, c)
                const sumC = cx0y0 * wx0y0 + cx0y1 * wx0y1 + cx1y0 * wx1y0 + cx1y1 * wx1y1
                out.set(oy, ox, c, sumC * iSumW)
            }
        }
    }
    return out
}

/**
 * Build an image pyramid by halving the map for as long as both of its dimensions are at least 4.
 *
 * @param map Full-resolution map.
 * @returns The pyramid ordered coarsest level first; the input map itself is the last element.
 */
function buildMapLevels(map: Tensor): Tensor[] {
    const fineToCoarse: Tensor[] = [map]
    let current = map
    while (current.shape[0] >= 4 && current.shape[1] >= 4) {
        current = downsampleMap(current)
        fineToCoarse.push(current)
    }
    return fineToCoarse.reverse()
}

/**
 * Build a deterministic sampling pattern from two regular grids of sample coordinates, both expressed in map A's
 * coordinate space.
 *
 * The first `numXSamples * numYSamples` rows cover map A's coordinate extent ([-centerX, scaleX - centerX] along x,
 * likewise along y). The remaining rows cover map B's extent and are mapped into A's coordinates with the inverse of
 * `matrixCoordAtoCoordB`. Every row is [coordX, coordY, 0, 1].
 *
 * NOTE(review): no caller is visible in this part of the file.
 *
 * @param mapA Map whose coordinate space the pattern is expressed in.
 * @param mapB Second map; its grid is transformed into A's coordinates.
 * @param matrixCoordAtoCoordB 4x4 transform from A's coordinates to B's coordinates.
 * @param numXSamples Grid resolution along x for each map (default 64).
 * @param numYSamples Grid resolution along y for each map (default 64).
 * @returns Matrix indexed [sampleIdx, coordIdx] with `2 * numXSamples * numYSamples` rows and 4 columns.
 */
function generateSamplingPatternBiDirectionalGrid(
    mapA: ISurfaceMap,
    mapB: ISurfaceMap,
    matrixCoordAtoCoordB: Matrix,
    numXSamples = 64,
    numYSamples = 64,
): Matrix {
    const retSamples = zeros(2 * numXSamples * numYSamples, 4)
    // Grid spacing along one axis; a single-sample axis gets spacing 0 (avoids 0 * Infinity = NaN).
    const spacing = (scale: number, count: number) => (count > 1 ? scale / (count - 1) : 0)
    let idx = 0

    // Grid over map A, already in A's coordinates.
    const sxA = spacing(mapA.scaleX, numXSamples)
    const syA = spacing(mapA.scaleY, numYSamples)
    for (let y = 0; y < numYSamples; y++) {
        for (let x = 0; x < numXSamples; x++) {
            retSamples.set(idx, 0, x * sxA - mapA.centerX)
            retSamples.set(idx, 1, y * syA - mapA.centerY)
            retSamples.set(idx, 3, 1)
            ++idx
        }
    }

    // Grid over map B, mapped back into A's coordinates.
    const matrixCoordBtoCoordA: Matrix = zeros(4, 4)
    matInverse4(matrixCoordAtoCoordB, matrixCoordBtoCoordA)
    const coordA = [0, 0, 0, 1]
    const sxB = spacing(mapB.scaleX, numXSamples)
    const syB = spacing(mapB.scaleY, numYSamples)
    for (let y = 0; y < numYSamples; y++) {
        for (let x = 0; x < numXSamples; x++) {
            matVec4Mul(matrixCoordBtoCoordA, x * sxB - mapB.centerX, y * syB - mapB.centerY, 0, 1, coordA)
            retSamples.set(idx, 0, coordA[0])
            retSamples.set(idx, 1, coordA[1])
            retSamples.set(idx, 3, 1)
            ++idx
        }
    }

    return retSamples
}

/**
 * Build a stochastic sampling pattern in map A's coordinate space from importance samples of both maps.
 *
 * Rows 0..255 hold samples of map A as [x, y, sampleWeight, 1]. Rows 256..511 hold samples of map B, transformed
 * into A's coordinates with the inverse of `matrixCoordAtoCoordB`, as [x, y, 1, sampleWeight].
 * NOTE(review): the weight sits in column 2 for A-samples but in column 3 for B-samples — confirm the consumer
 * expects this layout.
 *
 * The RNG is created from a fixed seed string on every call, so presumably the pattern is reproducible for the same
 * maps (depends on createRNG being a seeded generator — confirm).
 */
function generateSamplingPattern(mapA: ISurfaceMap, mapB: ISurfaceMap, matrixCoordAtoCoordB: Matrix): Matrix {
    const samplesPerMap = 256
    const pattern = zeros(2 * samplesPerMap, 4) // indexed [sampleIdx, coordIdx]
    const rng = createRNG("seedString1234") //TODO: base seed on map IDs?

    // First half: samples of map A, already in A's coordinates.
    for (let row = 0; row < samplesPerMap; row++) {
        const [u, v, weight] = mapA.generateImportanceSampledPoint(rng)
        pattern.set(row, 0, u)
        pattern.set(row, 1, v)
        pattern.set(row, 2, weight)
        pattern.set(row, 3, 1.0)
    }

    // Second half: samples of map B, mapped back into A's coordinates (same RNG stream, continued).
    const matrixCoordBtoCoordA: Matrix = zeros(4, 4)
    matInverse4(matrixCoordAtoCoordB, matrixCoordBtoCoordA)
    const inA = [0, 0, 0, 1]
    for (let row = samplesPerMap; row < 2 * samplesPerMap; row++) {
        const [u, v, weight] = mapB.generateImportanceSampledPoint(rng)
        matVec4Mul(matrixCoordBtoCoordA, u, v, 0, 1, inA)
        pattern.set(row, 0, inA[0])
        pattern.set(row, 1, inA[1])
        pattern.set(row, 2, 1.0)
        pattern.set(row, 3, weight)
    }

    return pattern
}

/**
 * Utility class for sampling maps with interpolation using world units/dimensions.
 *
 * Holds a regular grid of node points (channels 0-2) and weights (channel 3), plus a pyramid of downsampled copies used
 * for importance sampling. A coordinate maps to a fractional grid position as `(coord + center) * (size - 1) / scale`.
 */
export class GriddedSurfaceMapSampler implements ISurfaceMap {
    private readonly sizeX: number
    private readonly sizeY: number
    readonly scaleX: number
    readonly scaleY: number
    /** Grid nodes per coordinate unit along x: (sizeX - 1) / scaleX. */
    private readonly ratioX: number
    /** Grid nodes per coordinate unit along y: (sizeY - 1) / scaleY. */
    private readonly ratioY: number
    readonly centerX: number
    readonly centerY: number
    /** Full-resolution map, indexed [y, x, channel]. */
    private readonly map: Tensor
    /** Pyramid levels, coarsest first; the last entry is `map`. */
    private readonly mapLevels: Tensor[]
    readonly matrix: Matrix4

    /**
     * @param data Grid data; `points` holds 3 numbers and `weights` 1 number per node, nodes ordered row by row with x
     *             varying fastest.
     */
    constructor(data: GriddedSurfaceMapData) {
        const points = data.points
        const weights = data.weights
        const nx = data.sizeX
        const ny = data.sizeY
        const topLevelMap = zeros(ny, nx, 4)
        let pIdx = 0
        let wIdx = 0
        for (let y = 0; y < ny; y++) {
            for (let x = 0; x < nx; x++) {
                topLevelMap.set(y, x, 0, points[pIdx++])
                topLevelMap.set(y, x, 1, points[pIdx++])
                topLevelMap.set(y, x, 2, points[pIdx++])
                topLevelMap.set(y, x, 3, weights[wIdx++])
            }
        }
        this.mapLevels = buildMapLevels(topLevelMap)
        this.map = this.mapLevels[this.mapLevels.length - 1]
        this.matrix = data.matrix
        this.sizeX = this.map.shape[1]
        this.sizeY = this.map.shape[0]
        this.scaleX = data.scaleX
        this.scaleY = data.scaleY
        this.ratioX = (this.sizeX - 1) / data.scaleX
        this.ratioY = (this.sizeY - 1) / data.scaleY
        this.centerX = data.centerX
        this.centerY = data.centerY
    }

    /**
     * Pick a full-resolution grid node with probability proportional to its weight, then jitter uniformly within it.
     *
     * The whole coarsest pyramid level is searched first (it can be larger than 2x2 for non-square or
     * non-power-of-two maps); every finer level then refines the choice to one of the four children of the previously
     * chosen cell, like sampling from a quadtree.
     *
     * NOTE(review): jittering the last row/column node by up to one cell lands at or beyond the last grid node, where
     * `samplePointWithCoordD` / `sampleWeightWithCoordD` return zeros — confirm this is acceptable.
     *
     * @param rngFn Random source, presumably uniform in [0, 1).
     * @returns [coordX, coordY, w / totalW] where w is the chosen node's weight and totalW the map's total weight
     *          (0 when the map has no weight at all).
     */
    generateImportanceSampledPoint(rngFn: () => number): [number, number, number] {
        const numLevels = this.mapLevels.length
        const coarsest = this.mapLevels[0]
        const coarseRows = coarsest.shape[0]
        const coarseCols = coarsest.shape[1]

        // Coarsest level: inverse-CDF search over all cells in raster order (rows outer, columns inner).
        // For a 2x2 level this visits (0,0), (1,0), (0,1), (1,1) — the same order as the 2x2 steps below.
        const coarseWeights: number[] = []
        for (let cy = 0; cy < coarseRows; cy++) {
            for (let cx = 0; cx < coarseCols; cx++) coarseWeights.push(coarsest.get(cy, cx, 3))
        }
        const coarseSum = coarseWeights.reduce((acc, w) => acc + w, 0)
        const r0 = rngFn() * coarseSum // TODO: use same seed for every evaluation!
        let cell = coarseWeights.length - 1 // fallback if rounding leaves r0 above the last partial sum
        let cumW = 0
        for (let i = 0; i < coarseWeights.length; i++) {
            cumW += coarseWeights[i]
            if (r0 <= cumW) {
                cell = i
                break
            }
        }
        let x = cell % coarseCols
        let y = Math.floor(cell / coarseCols)

        // Finer levels: choose one of the 2x2 children of the current cell proportionally to their weights.
        for (let levelIdx = 1; levelIdx < numLevels; levelIdx++) {
            x *= 2
            y *= 2
            const mapLevel = this.mapLevels[levelIdx]
            const w00 = mapLevel.get(y, x, 3)
            const w10 = mapLevel.get(y, x + 1, 3)
            const w01 = mapLevel.get(y + 1, x, 3)
            const w11 = mapLevel.get(y + 1, x + 1, 3)
            const r = rngFn() * (w00 + w10 + w01 + w11) // TODO: use same seed for every evaluation!
            if (r <= w00) {
                // keep the top-left child
            } else if (r <= w00 + w10) {
                ++x
            } else if (r <= w00 + w10 + w01) {
                ++y
            } else {
                ++x
                ++y
            }
        }

        // Each pyramid level stores 2x2 averages, i.e. a quarter of the weight of the level below it, so the
        // full-resolution total is the coarsest level's sum scaled by 4^(numLevels - 1).
        const sumTotalW = coarseSum * 4 ** (numLevels - 1)
        const weight = sumTotalW > 0 ? this.map.get(y, x, 3) / sumTotalW : 0
        return [(x + rngFn()) / this.ratioX - this.centerX, (y + rngFn()) / this.ratioY - this.centerY, weight]
    }

    generateSamplingPattern(otherMap: ISurfaceMap, matrixCoordAtoCoordB: Matrix): Matrix {
        return generateSamplingPattern(this, otherMap, matrixCoordAtoCoordB)
    }

    /**
     * Sample map point data (channels 0, 1, 2) with bilinear interpolation, additionally returning the derivative of
     * the interpolated data with respect to the input coordinates.
     *
     * Writes the point to outPoint[0..2] with outPoint[3] = 1, d(point)/d(coord[0]) to row 0 and d(point)/d(coord[1])
     * to row 1 (columns 0-2) of `dPt_dCoord`. Coordinates whose cell lies outside the grid — including positions at or
     * past the last row/column of nodes — produce all zeros.
     */
    samplePointWithCoordD(coord: ArrayLike<number>, outPoint: number[], dPt_dCoord: Tensor): void {
        const data = this.map
        const x = (coord[0] + this.centerX) * this.ratioX
        const y = (coord[1] + this.centerY) * this.ratioY
        const ix = Math.floor(x)
        const iy = Math.floor(y)
        if (ix < 0 || ix >= this.sizeX - 1 || iy < 0 || iy >= this.sizeY - 1) {
            outPoint[0] = 0
            outPoint[1] = 0
            outPoint[2] = 0
            outPoint[3] = 0
            ;(dPt_dCoord.data as any).fill(0)
            return
        }
        const fx = x - ix
        const fy = y - iy
        const dfx_dcx = this.ratioX
        const dfy_dcy = this.ratioY
        const v00_0 = data.get(iy, ix, 0)
        const v10_0 = data.get(iy, ix + 1, 0)
        const v01_0 = data.get(iy + 1, ix, 0)
        const v11_0 = data.get(iy + 1, ix + 1, 0)
        const v00_1 = data.get(iy, ix, 1)
        const v10_1 = data.get(iy, ix + 1, 1)
        const v01_1 = data.get(iy + 1, ix, 1)
        const v11_1 = data.get(iy + 1, ix + 1, 1)
        const v00_2 = data.get(iy, ix, 2)
        const v10_2 = data.get(iy, ix + 1, 2)
        const v01_2 = data.get(iy + 1, ix, 2)
        const v11_2 = data.get(iy + 1, ix + 1, 2)
        // Bilinear form a00 + a10*fx + a01*fy + a11*fx*fy per channel.
        const a00_0 = v00_0
        const a10_0 = v10_0 - v00_0
        const a01_0 = v01_0 - v00_0
        const a11_0 = v11_0 + v00_0 - v10_0 - v01_0
        const a00_1 = v00_1
        const a10_1 = v10_1 - v00_1
        const a01_1 = v01_1 - v00_1
        const a11_1 = v11_1 + v00_1 - v10_1 - v01_1
        const a00_2 = v00_2
        const a10_2 = v10_2 - v00_2
        const a01_2 = v01_2 - v00_2
        const a11_2 = v11_2 + v00_2 - v10_2 - v01_2
        outPoint[0] = a00_0 + a10_0 * fx + a01_0 * fy + a11_0 * fx * fy
        outPoint[1] = a00_1 + a10_1 * fx + a01_1 * fy + a11_1 * fx * fy
        outPoint[2] = a00_2 + a10_2 * fx + a01_2 * fy + a11_2 * fx * fy
        outPoint[3] = 1.0
        dPt_dCoord.set(0, 0, (a10_0 + a11_0 * fy) * dfx_dcx)
        dPt_dCoord.set(0, 1, (a10_1 + a11_1 * fy) * dfx_dcx)
        dPt_dCoord.set(0, 2, (a10_2 + a11_2 * fy) * dfx_dcx)
        dPt_dCoord.set(1, 0, (a01_0 + a11_0 * fx) * dfy_dcy)
        dPt_dCoord.set(1, 1, (a01_1 + a11_1 * fx) * dfy_dcy)
        dPt_dCoord.set(1, 2, (a01_2 + a11_2 * fx) * dfy_dcy)
    }

    /**
     * Sample map weight data (channel 3) with bilinear interpolation, additionally returning the derivative of the
     * interpolated weight with respect to the input coordinates (dw_dCoord[0..1]; entries 2 and 3 are set to 0).
     * Coordinates outside the grid return 0 with zero derivatives.
     */
    sampleWeightWithCoordD(coord: ArrayLike<number>, dw_dCoord: number[]): number {
        const data = this.map
        const x = (coord[0] + this.centerX) * this.ratioX
        const y = (coord[1] + this.centerY) * this.ratioY
        const ix = Math.floor(x)
        const iy = Math.floor(y)
        if (ix < 0 || ix >= this.sizeX - 1 || iy < 0 || iy >= this.sizeY - 1) {
            dw_dCoord[0] = 0
            dw_dCoord[1] = 0
            dw_dCoord[2] = 0
            dw_dCoord[3] = 0
            return 0
        }
        const fx = x - ix
        const fy = y - iy
        const dfx_dcx = this.ratioX
        const dfy_dcy = this.ratioY
        const v00 = data.get(iy, ix, 3)
        const v10 = data.get(iy, ix + 1, 3)
        const v01 = data.get(iy + 1, ix, 3)
        const v11 = data.get(iy + 1, ix + 1, 3)
        const a00 = v00
        const a10 = v10 - v00
        const a01 = v01 - v00
        const a11 = v11 + v00 - v10 - v01
        const w = a00 + a10 * fx + a01 * fy + a11 * fx * fy
        dw_dCoord[0] = (a10 + a11 * fy) * dfx_dcx
        dw_dCoord[1] = (a01 + a11 * fx) * dfy_dcy
        dw_dCoord[2] = 0.0
        dw_dCoord[3] = 0.0
        return w
    }
}

/**
 * Flat rectangular surface map: a coordinate (u, v) maps to `matrix * (u, v, 0, 1)` and every coordinate has weight 1.
 * The rectangle is centered on the coordinate origin (center = scale / 2).
 */
export class PlanarSurfaceMapSampler implements ISurfaceMap {
    readonly centerX: number
    readonly centerY: number

    constructor(
        readonly matrix: Matrix4,
        readonly scaleX: number,
        readonly scaleY: number,
    ) {
        this.centerX = 0.5 * scaleX
        this.centerY = 0.5 * scaleY
    }

    /**
     * Sample the rectangle [-scaleX/2, scaleX/2] x [-scaleY/2, scaleY/2] (uniformly, assuming rngFn is uniform in
     * [0, 1)); the reported weight is always 1. rngFn is consumed for x first, then y.
     */
    generateImportanceSampledPoint(rngFn: () => number): [number, number, number] {
        // TODO: use same seed for every evaluation!
        const u = rngFn() - 0.5
        const v = rngFn() - 0.5
        return [u * this.scaleX, v * this.scaleY, 1.0]
    }

    generateSamplingPattern(otherMap: ISurfaceMap, matrixCoordAtoCoordB: Matrix): Matrix {
        return generateSamplingPattern(this, otherMap, matrixCoordAtoCoordB)
    }

    /**
     * Map `coord` through the matrix into `outPoint` (4 entries). The partial derivatives are the matrix applied to the
     * direction vectors (1, 0, 0, 0) and (0, 1, 0, 0); they go to rows 0 and 1 (columns 0-2) of `dPt_dCoord`.
     */
    samplePointWithCoordD(coord: ArrayLike<number>, outPoint: number[], dPt_dCoord: Tensor): void {
        const m = this.matrix
        const position = m.multiplyVectorXYZW(coord[0], coord[1], 0, 1)
        const alongX = m.multiplyVectorXYZW(1, 0, 0, 0)
        const alongY = m.multiplyVectorXYZW(0, 1, 0, 0)
        for (let k = 0; k < 4; k++) outPoint[k] = position[k]
        for (let k = 0; k < 3; k++) {
            dPt_dCoord.set(0, k, alongX[k])
            dPt_dCoord.set(1, k, alongY[k])
        }
    }

    /** Sample map weight data (channel 3): constant 1 everywhere, so all four derivative entries are 0. */
    sampleWeightWithCoordD(coord: ArrayLike<number>, dw_dCoord: number[]): number {
        for (let k = 0; k < 4; k++) dw_dCoord[k] = 0.0
        return 1.0
    }
}
