package org.sunflow.raytracer.photonmap;

import org.sunflow.image.Color;
import org.sunflow.math.Point3;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.LightSample;
import org.sunflow.raytracer.RenderState;

/**
 * Photon map specialisation that remembers the largest photon power stored so
 * far, uses it to cap the gather radius, and converts the photons found around
 * a shading point into {@link LightSample}s on the render state.
 */
public final class CausticPhotonMap extends PhotonMap {
    /** Filter constant used to weight each gathered photon by its distance. */
    private final double filterValue;
    /** Largest {@code power.getMax()} seen by {@link #storePhoton}; rescaled in {@link #initialize}. */
    private double maxPower;
    /** Upper bound applied to the gather radius; computed in {@link #initialize}. */
    private double maxRadius;

    /**
     * Creates an empty caustic photon map.
     *
     * @param maxPhotons   forwarded unchanged to the {@code PhotonMap} constructor
     * @param gatherNum    forwarded unchanged to the {@code PhotonMap} constructor
     * @param gatherRadius forwarded unchanged to the {@code PhotonMap} constructor
     * @param filterValue  filter constant used when weighting gathered photons
     */
    public CausticPhotonMap(int maxPhotons, int gatherNum, double gatherRadius, double filterValue) {
        super(maxPhotons, gatherNum, gatherRadius);
        this.filterValue = filterValue;
        this.maxPower = 0;
        this.maxRadius = 0;
    }

    /**
     * Stores a photon located at the current vertex of {@code state} and
     * updates the running maximum of photon power.
     *
     * @param state render state whose vertex position becomes the photon position
     * @param dir   photon direction
     * @param power photon power; its {@code getMax()} value feeds the running maximum
     */
    public void storePhoton(RenderState state, Vector3 dir, Color power) {
        Photon photon = new Photon(state.getVertex().p, dir, power);
        super.storePhoton(photon);
        double peak = power.getMax();
        maxPower = Math.max(maxPower, peak);
    }

    /**
     * Balances the map, rescales the stored photon power and shrinks the
     * gather radius if it exceeds the bound derived from the maximum power.
     *
     * @param scale factor applied to the stored photon power and to the tracked maximum
     */
    public void initialize(double scale) {
        balance();
        scalePhotonPower(scale);
        maxPower = maxPower * scale;
        maxRadius = 1.4 * Math.sqrt(maxPower * gatherNum);
        // Only ever shrink the gather radius, never grow it.
        if (gatherRadius > maxRadius) {
            gatherRadius = maxRadius;
        }
    }

    /**
     * Looks up the photons around the current vertex of {@code state} and adds
     * one valid, shadow-ray-free {@link LightSample} per accepted photon.
     * Nothing is added when the map is empty or fewer than 8 photons are found.
     *
     * @param state render state supplying the lookup position and normal, and
     *              receiving the generated samples
     */
    public void getNearestPhotons(RenderState state) {
        if (size() == 0) {
            return;
        }
        final Point3 hitPoint = state.getVertex().p;
        NearestPhotons np = new NearestPhotons(hitPoint, gatherNum, gatherRadius * gatherRadius);
        locatePhotons(np);
        if (np.found < 8) {
            return;
        }

        final Vector3 normal = state.getVertex().n;
        // Slot 0 of dist2 is read on its own while photons are read from slots
        // 1..found; presumably it holds the squared search radius - TODO confirm
        // against NearestPhotons.
        final double invArea = 1.0 / (Math.PI * np.dist2[0]);
        // NOTE(review): this threshold is 0.05 times a squared distance but is
        // compared against a linear offset along the normal - confirm intended.
        final double maxNDist = np.dist2[0] * 0.05;
        final double f2r2 = 1.0 / (filterValue * filterValue * np.dist2[0]);
        final double fInv = 1.0 / (1.0 - 2.0 / (3.0 * filterValue));

        // Scratch objects reused for every photon in the loop.
        Point3 photonPos = new Point3();
        Vector3 photonDir = new Vector3();
        Vector3 offset = new Vector3();

        for (int i = 1; i <= np.found; i++) {
            Photon photon = np.index[i];
            getUnitVector(photon.dirTheta, photon.dirPhi, photonDir);
            double cosIncident = -Vector3.dot(photonDir, normal);
            // Written as a negation so that NaN is rejected as well.
            if (!(cosIncident > 0.01)) {
                continue;
            }

            photonPos.set(photon.x, photon.y, photon.z);
            Point3.sub(photonPos, hitPoint, offset);
            double planeOffset = Vector3.dot(offset, normal);
            // Skip photons whose offset along the normal is outside (-maxNDist, maxNDist).
            if (!(Math.abs(planeOffset) < maxNDist)) {
                continue;
            }

            double densityScale = invArea / cosIncident;
            // Falloff (1 - d / (filterValue * r)) times the normalisation fInv,
            // which has the shape of a cone filter.
            double falloff = (1.0 - Math.sqrt(np.dist2[i] * f2r2)) * fInv;

            LightSample sample = new LightSample();
            photonDir.negate(sample.getDirection());
            sample.getRadiance().setRGBE(photon.power).mul(densityScale);
            sample.getRadiance().mul(falloff);
            sample.setShadowRay(null);
            sample.setValid(true);
            state.addSample(sample);
        }
    }
}