package org.sunflow.raytracer;

import org.sunflow.image.Color;
import org.sunflow.math.MathUtils;
import org.sunflow.math.Point3;
import org.sunflow.math.Vector3;

/**
 * Octree-based irradiance cache.
 * <p>
 * Each inserted sample carries a validity radius {@code r0} plus rotational and
 * translational irradiance gradients. A lookup blends every stored sample whose weight
 * {@code 1 / (distance / r0 + sqrt(1 - n . ni))} exceeds {@code 1 / tolerance},
 * extrapolating each sample's irradiance to the query point with its gradients.
 */
class IrradianceCache {
    /** Error tolerance; larger values lower the weight threshold, so more samples contribute. */
    private final double tolerance;
    /** Cached {@code 1 / tolerance}; also the weight a sample must exceed to contribute. */
    private final double invTolerance;
    /** Lower bound applied to {@code r0 * tolerance} on insertion. */
    private final double minSpacing;
    /** Upper bound applied to {@code r0 * tolerance} on insertion; fixed at {@code 100 * minSpacing}. */
    private final double maxSpacing;
    /** Root cell of the octree, centered on the scene bounds. */
    private final Node root;

    /**
     * Creates an empty cache whose root cell covers {@code sceneBounds}.
     *
     * @param tolerance   error tolerance; lookups accept samples with weight above {@code 1 / tolerance}
     * @param minSpacing  minimum value of {@code r0 * tolerance}; the maximum is {@code 100 * minSpacing}
     * @param sceneBounds scene bounds; the root is a cube sized from the largest extent
     */
    IrradianceCache(double tolerance, double minSpacing, BoundingBox sceneBounds) {
        this.tolerance = tolerance;
        invTolerance = 1.0 / tolerance;
        this.minSpacing = minSpacing;
        maxSpacing = 100.0 * minSpacing;
        Vector3 ext = sceneBounds.getExtents();
        // NOTE(review): the 1.0001 padding presumably keeps points lying exactly on the
        // bounds inside the strict isInside() test (assumes getExtents() is the full size).
        root = new Node(sceneBounds.getCenter(), 1.0001 * MathUtils.max(ext.x, ext.y, ext.z));
    }

    /**
     * Adds a sample to the cache.
     * <p>
     * The radius is first clamped so that {@code r0 * tolerance} lies in
     * {@code [minSpacing, maxSpacing]}. The sample is then pushed down the octree, creating
     * cells as needed, until it reaches a cell whose side is smaller than
     * {@code 4 * r0 * tolerance}. Samples outside the root cell, or whose clamped radius is
     * not positive, are stored at the root.
     *
     * @param p             sample position (copied)
     * @param n             surface normal at {@code p} (a normalized copy is stored)
     * @param r0            validity radius before clamping
     * @param irr           irradiance at {@code p} (copied)
     * @param rotGradient   rotational gradient, indexed 0..2 (stored by reference)
     * @param transGradient translational gradient, indexed 0..2 (stored by reference)
     */
    void insert(Point3 p, Vector3 n, double r0, Color irr, Color[] rotGradient, Color[] transGradient) {
        r0 = MathUtils.clamp(r0 * tolerance, minSpacing, maxSpacing) * invTolerance;
        // Smallest cell side allowed to keep subdividing; invariant for the whole descent.
        double minNodeSize = 4.0 * r0 * tolerance;
        Node node = root;
        // A non-positive bound (e.g. minSpacing <= 0) satisfies the descent test at every
        // depth, even once sideLength underflows to 0, so the loop would never terminate.
        // Such samples stay at the root instead. NaN also fails this test, as it did before.
        if (minNodeSize > 0.0 && root.isInside(p)) {
            while (node.sideLength >= minNodeSize) {
                node = node.child(node.childIndex(p));
            }
        }
        Sample s = new Sample(p, n, r0, irr, rotGradient, transGradient);
        s.next = node.first;
        node.first = s;
    }

    /**
     * Interpolates irradiance at {@code p} from the cached samples.
     *
     * @param p query position
     * @param n surface normal at {@code p}
     * @return the weight-normalized sum of the contributing samples' extrapolated
     *         irradiance, or {@code null} if no sample passes the weight threshold
     */
    Color getIrradiance(Point3 p, Vector3 n) {
        Sample x = new Sample(p, n);
        double w = root.find(x);
        return (x.irr == null) ? null : x.irr.mul(1.0 / w);
    }

    /**
     * Cubic octree cell holding a linked list of samples.
     * <p>
     * Deliberately a non-static inner class: {@link #find} reads the enclosing cache's
     * {@code invTolerance}.
     */
    private final class Node {
        /** Child cells indexed by octant bits (1: +x, 2: +y, 4: +z); null until first used. */
        final Node[] children = new Node[8];
        /** Head of this cell's singly linked sample list, or null if empty. */
        Sample first;
        final Point3 center;
        final double sideLength;
        final double halfSideLength;
        final double quadSideLength;

        Node(Point3 center, double sideLength) {
            this.center = new Point3(center);
            this.sideLength = sideLength;
            halfSideLength = 0.5 * sideLength;
            quadSideLength = 0.5 * halfSideLength;
        }

        /** Returns true if {@code p} lies strictly inside this cell on every axis. */
        final boolean isInside(Point3 p) {
            return (Math.abs(p.x - center.x) < halfSideLength)
                    && (Math.abs(p.y - center.y) < halfSideLength)
                    && (Math.abs(p.z - center.z) < halfSideLength);
        }

        /** Returns the octant index (bit 0: +x, bit 1: +y, bit 2: +z) of {@code p} relative to the center. */
        int childIndex(Point3 p) {
            int k = 0;
            if (p.x > center.x) {
                k |= 1;
            }
            if (p.y > center.y) {
                k |= 2;
            }
            if (p.z > center.z) {
                k |= 4;
            }
            return k;
        }

        /** Returns child {@code k}, creating it (half this cell's side, centered in octant k) if absent. */
        Node child(int k) {
            if (children[k] == null) {
                Point3 c = new Point3(center);
                c.x += ((k & 1) == 0) ? -quadSideLength : quadSideLength;
                c.y += ((k & 2) == 0) ? -quadSideLength : quadSideLength;
                c.z += ((k & 4) == 0) ? -quadSideLength : quadSideLength;
                children[k] = new Node(c, halfSideLength);
            }
            return children[k];
        }

        /**
         * Accumulates into {@code x.irr} the weighted, extrapolated irradiance of every
         * sample in this cell and its relevant descendants whose weight exceeds
         * {@code invTolerance}.
         *
         * @param x query sample; its {@code irr} is created or updated in place
         * @return the sum of the accepted weights
         */
        final double find(Sample x) {
            double weight = 0.0;
            for (Sample s = first; s != null; s = s.next) {
                // Cap the weight so a coincident sample (infinite weight) stays finite.
                double wi = Math.min(1e10, s.weight(x));
                if (wi > invTolerance) {
                    if (x.irr != null) {
                        x.irr.madd(wi, s.getIrradiance(x));
                    } else {
                        x.irr = s.getIrradiance(x).mul(wi);
                    }
                    weight += wi;
                }
            }
            // A child's side equals this cell's halfSideLength. A sample only passes the
            // threshold within r0 * tolerance of itself, and insert() keeps that radius at most
            // half the side of the cell it lands in. Testing |child center - query| <=
            // halfSideLength per axis is therefore a conservative cull.
            for (Node child : children) {
                if ((child != null)
                        && (Math.abs(child.center.x - x.pi.x) <= halfSideLength)
                        && (Math.abs(child.center.y - x.pi.y) <= halfSideLength)
                        && (Math.abs(child.center.z - x.pi.z) <= halfSideLength)) {
                    weight += child.find(x);
                }
            }
            return weight;
        }
    }

    /**
     * A cached irradiance record. It is also used as a lookup query, in which case only
     * {@code pi}, {@code ni} and the accumulated {@code irr} are meaningful.
     */
    private static final class Sample {
        final Point3 pi;
        final Vector3 ni;
        double invR0;
        Color irr;
        Color[] rotGradient;
        Color[] transGradient;
        Sample next;

        /** Creates a query at {@code p} with normal {@code n} (normalized copy); {@code irr} starts null. */
        Sample(Point3 p, Vector3 n) {
            pi = new Point3(p);
            ni = new Vector3(n).normalize();
        }

        /** Creates a stored sample; {@code irr} is copied, the gradient arrays are kept by reference. */
        Sample(Point3 p, Vector3 n, double r0, Color irr, Color[] rotGradient, Color[] transGradient) {
            pi = new Point3(p);
            ni = new Vector3(n).normalize();
            invR0 = 1.0 / r0;
            this.irr = new Color(irr);
            this.rotGradient = rotGradient;
            this.transGradient = transGradient;
        }

        /**
         * Returns this sample's weight at query {@code x}:
         * {@code 1 / (|x.pi - pi| / r0 + sqrt(1 - x.ni . ni))}. The dot product is capped at 1
         * so rounding cannot make the square root NaN. The result can be infinite when
         * {@code x} coincides with the sample in position and normal.
         */
        final double weight(Sample x) {
            return 1.0 / ((x.pi.distanceTo(pi) * invR0) + Math.sqrt(1.0 - Math.min(1.0, Vector3.dot(x.ni, ni))));
        }

        /**
         * Extrapolates this sample's irradiance to query {@code x} to first order:
         * {@code irr + (x.pi - pi) . transGradient + (ni cross x.ni) . rotGradient}.
         *
         * @return a new Color; this sample is not modified
         */
        final Color getIrradiance(Sample x) {
            Color temp = new Color(irr);
            temp.madd(x.pi.x - pi.x, transGradient[0]);
            temp.madd(x.pi.y - pi.y, transGradient[1]);
            temp.madd(x.pi.z - pi.z, transGradient[2]);
            Vector3 cross = Vector3.cross(ni, x.ni, new Vector3());
            temp.madd(cross.x, rotGradient[0]);
            temp.madd(cross.y, rotGradient[1]);
            temp.madd(cross.z, rotGradient[2]);
            return temp;
        }
    }
}