package org.sunflow.raytracer.photonmap;

import org.sunflow.image.Color;
import org.sunflow.math.Point3;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.RenderState;

/**
 * Photon map that can precompute a per-photon irradiance estimate
 * ({@link #precomputeIrradiance}) and then answer irradiance queries by
 * returning the cached value of the nearest suitable photon
 * ({@link #getIrradiance}).
 */
public final class GlobalPhotonMap extends PhotonMap {
    /**
     * Minimum dot product between the query normal and a stored photon's
     * normal for that photon's cached irradiance to be reused.
     */
    private static final double MIN_NORMAL_COS = 0.9;

    // set once precomputeIrradiance() has filled in IrradiancePhoton.irradiance
    private boolean hasIrradiance;
    // largest power.getMax() among stored photons; rescaled by initialize()
    private double maxPower;
    // heuristic upper bound that gatherRadius is clamped to
    private double maxRadius;

    /**
     * @param maxPhotons capacity passed to the base photon map
     * @param numGather number of photons to gather per lookup
     * @param gatherRadius initial search radius for lookups
     */
    public GlobalPhotonMap(int maxPhotons, int numGather, double gatherRadius) {
        super(maxPhotons, numGather, gatherRadius);
        hasIrradiance = false;
        maxPower = 0;
        maxRadius = 0;
    }

    /**
     * Stores a photon at the current vertex of {@code state}, recording the
     * vertex normal so irradiance can later be estimated per surface.
     *
     * @param state render state whose vertex supplies position and normal
     * @param dir photon travel direction
     * @param power photon power
     * @param isDirect whether the photon is flagged as direct
     * @param isCaustic whether the photon is flagged as caustic
     */
    public void storePhoton(RenderState state, Vector3 dir, Color power, boolean isDirect, boolean isCaustic) {
        super.storePhoton(new IrradiancePhoton(state.getVertex().p, state.getVertex().n, dir, power, isDirect, isCaustic));
        maxPower = Math.max(maxPower, power.getMax());
    }

    /**
     * Balances the stored photons, scales their power by {@code scale}, and
     * clamps the gather radius to a heuristic bound derived from the
     * (scaled) maximum photon power and the gather count.
     *
     * @param scale factor applied to every photon's power
     */
    public void initialize(double scale) {
        balance();
        scalePhotonPower(scale);
        maxPower *= scale;
        maxRadius = 1.4 * Math.sqrt(maxPower * gatherNum);
        if (gatherRadius > maxRadius)
            gatherRadius = maxRadius;
    }

    /**
     * Computes an irradiance estimate for the photons in the first quarter of
     * the tree array (indices 1..storedPhotons/4), then shrinks the map so it
     * contains only those photons. It also divides {@code gatherNum} by 4
     * and re-clamps the gather radius.
     *
     * @param includeDirect if false, photons flagged as direct are skipped
     * @param includeCaustics if false, photons flagged as caustic are skipped
     */
    public void precomputeIrradiance(boolean includeDirect, boolean includeCaustics) {
        if (size() == 0)
            return;

        // precompute the indirect irradiance for all photons that are neither
        // leaves nor parents of leaves in the tree.
        int quadStoredPhotons = halfStoredPhotons / 2;
        Point3 p = new Point3();
        Vector3 n = new Vector3();
        Point3 ppos = new Point3();
        Vector3 pdir = new Vector3();
        Vector3 pvec = new Vector3();
        Color irr = new Color();
        Color pow = new Color();
        double maxDist2 = gatherRadius * gatherRadius;
        NearestPhotons np = new NearestPhotons(p, gatherNum, maxDist2);
        Photon[] temp = new Photon[quadStoredPhotons + 1];
        for (int i = 1; i <= quadStoredPhotons; i++) {
            IrradiancePhoton curr = (IrradiancePhoton) photons[i];
            p.set(curr.x, curr.y, curr.z);
            getUnitVector(curr.normalTheta, curr.normalPhi, n);
            irr.set(Color.BLACK);
            np.reset(p, maxDist2);
            locatePhotons(np);
            // presumably dist2[0] holds the squared search radius after the
            // lookup, so this is 1 / (area of the gathering disc) — confirm
            // against NearestPhotons
            double invArea = 1.0 / (Math.PI * np.dist2[0]);
            // NOTE(review): this bound is compared against a plain distance
            // (pcos) but is derived from a squared distance; behavior kept
            // as-is, confirm whether sqrt was intended
            double maxNDist = np.dist2[0] * 0.05;
            for (int j = 1; j <= np.found; j++) {
                IrradiancePhoton phot = (IrradiancePhoton) np.index[j];
                if (!includeDirect && phot.isDirect())
                    continue;
                if (!includeCaustics && phot.isCaustic())
                    continue;
                getUnitVector(phot.dirTheta, phot.dirPhi, pdir);
                // only count photons arriving on the front side of the surface
                double cos = -Vector3.dot(pdir, n);
                if (cos > 0.01) {
                    ppos.set(phot.x, phot.y, phot.z);
                    Point3.sub(ppos, p, pvec);
                    // reject photons too far off the tangent plane at p
                    double pcos = Vector3.dot(pvec, n);
                    if ((pcos < maxNDist) && (pcos > -maxNDist))
                        irr.add(pow.setRGBE(phot.power));
                }
            }
            irr.mul(invArea);
            curr.irradiance = irr.toRGBE();
            temp[i] = curr;
        }

        // resize photon map to only include irradiance photons
        gatherNum /= 4;
        maxRadius = 1.4 * Math.sqrt(maxPower * gatherNum);
        if (gatherRadius > maxRadius)
            gatherRadius = maxRadius;
        storedPhotons = quadStoredPhotons;
        halfStoredPhotons = storedPhotons / 2;
        // With fewer than 4 photons, quadStoredPhotons is 0. Math.log(0) is
        // -Infinity, which narrows to Integer.MIN_VALUE and would make the
        // array allocations below throw NegativeArraySizeException.
        int log2n = (storedPhotons > 0) ? (int) Math.ceil(Math.log(storedPhotons) / Math.log(2.0)) : 0;
        dist1d2 = new float[log2n];
        chosen = new int[log2n];
        photons = temp;
        hasIrradiance = true;
    }

    /**
     * Returns the precomputed irradiance of the nearest photon within the
     * gather radius whose normal is within {@link #MIN_NORMAL_COS} of
     * {@code n}. Black is returned when irradiance has not been precomputed,
     * the map is empty, or no photon qualifies.
     *
     * @param p query position
     * @param n query surface normal
     * @return a new color; never {@code null}
     */
    public Color getIrradiance(Point3 p, Vector3 n) {
        // storedPhotons == 0 happens after precomputing on a map with fewer
        // than 4 photons; photons[1] would not exist in that case.
        if (!hasIrradiance || (size() == 0) || (storedPhotons == 0))
            return new Color(Color.BLACK);
        float px = (float) p.x;
        float py = (float) p.y;
        float pz = (float) p.z;
        int i = 1;
        int level = 0;
        int cameFrom;
        double dist2;
        double maxDist2 = gatherRadius * gatherRadius;
        IrradiancePhoton nearest = null;
        IrradiancePhoton curr;
        Vector3 photN = new Vector3();
        // iterative (stack-free) walk of the implicit kd-tree: chosen[] and
        // dist1d2[] remember, per level, which child was descended into and
        // the squared distance to that level's splitting plane
        while (true) {
            // descend to a leaf, choosing the child on the query's side
            while (i < halfStoredPhotons) {
                float dist1d = photons[i].getDist1(px, py, pz);
                dist1d2[level] = dist1d * dist1d;
                i += i;
                if (dist1d > 0.0f)
                    i++;
                chosen[level++] = i;
            }
            curr = (IrradiancePhoton) photons[i];
            dist2 = curr.getDist2(px, py, pz);
            if ((dist2 < maxDist2) && hasSimilarNormal(curr, n, photN)) {
                nearest = curr;
                maxDist2 = dist2;
            }
            // climb until reaching a level whose other child may still hold
            // a closer photon (or the root has been passed)
            do {
                cameFrom = i;
                i >>= 1;
                level--;
                if (i == 0)
                    return (nearest == null) ? new Color(Color.BLACK) : new Color().setRGBE(nearest.irradiance);
            } while ((dist1d2[level] >= maxDist2) || (cameFrom != chosen[level]));
            curr = (IrradiancePhoton) photons[i];
            dist2 = curr.getDist2(px, py, pz);
            if ((dist2 < maxDist2) && hasSimilarNormal(curr, n, photN)) {
                nearest = curr;
                maxDist2 = dist2;
            }
            // continue into the sibling subtree not visited on the way down
            i = chosen[level++] ^ 1;
        }
    }

    /**
     * Tests whether {@code photon}'s stored normal is close enough to
     * {@code n} for its cached irradiance to be reused.
     *
     * @param scratch temporary vector overwritten with the decoded normal
     */
    private boolean hasSimilarNormal(IrradiancePhoton photon, Vector3 n, Vector3 scratch) {
        getUnitVector(photon.normalTheta, photon.normalPhi, scratch);
        return Vector3.dot(scratch, n) > MIN_NORMAL_COS;
    }

    /**
     * Photon that additionally stores an encoded surface normal, an
     * RGBE-packed irradiance estimate, and direct/caustic flag bits.
     */
    private static final class IrradiancePhoton extends Photon {
        private static final int DIRECT_FLAG = 0x4;
        private static final int CAUSTIC_FLAG = 0x8;
        // surface normal, encoded by getVectorPhi/getVectorTheta
        byte normalPhi;
        byte normalTheta;
        // RGBE-encoded irradiance, filled in by precomputeIrradiance()
        int irradiance;

        IrradiancePhoton(Point3 p, Vector3 n, Vector3 dir, Color power, boolean isDirect, boolean isCaustic) {
            super(p, dir, power);
            normalPhi = getVectorPhi(n);
            normalTheta = getVectorTheta(n);
            irradiance = 0;
            if (isDirect)
                flags |= DIRECT_FLAG;
            if (isCaustic)
                flags |= CAUSTIC_FLAG;
        }

        boolean isDirect() {
            return (flags & DIRECT_FLAG) != 0;
        }

        boolean isCaustic() {
            return (flags & CAUSTIC_FLAG) != 0;
        }
    }
}