package org.sunflow.raytracer.photonmap;

import org.sunflow.image.Bitmap;
import org.sunflow.image.Color;
import org.sunflow.math.Point2;
import org.sunflow.math.Point3;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.BoundingBox;
import org.sunflow.raytracer.Camera;

/**
 * Photon map for a photon-mapping ray tracer.
 *
 * <p>Photons are appended with {@link #storePhoton(Photon)}. {@link #balance()} then
 * reorganises them into an implicit kd-tree: the node stored at index {@code k} of
 * {@link #photons} has its children at {@code 2k} and {@code 2k + 1}. Index 0 is never
 * used.
 */
class PhotonMap {
    // Decode tables for the 8-bit direction encoding used by Photon:
    // theta index i corresponds to i * PI / 256, and phi index i to i * 2 * PI / 256.
    private static final double[] COS_THETA = new double[256];
    private static final double[] SIN_THETA = new double[256];
    private static final double[] COS_PHI = new double[256];
    private static final double[] SIN_PHI = new double[256];
    /** Photon storage. Slot 0 is unused; stored photons occupy 1..storedPhotons. */
    Photon[] photons;
    /** Number of photons stored so far. */
    int storedPhotons;
    /** {@code storedPhotons / 2}: the highest heap index that can have children. Set by balance(). */
    int halfStoredPhotons;
    /** Per-level squared plane distances used by locatePhotons. Allocated by balance(). */
    float[] dist1d2;
    /** Per-level child index chosen while descending in locatePhotons. Allocated by balance(). */
    int[] chosen;
    // Gather parameters given to the constructor. They are not read inside this
    // class; presumably callers use them to configure lookups (TODO confirm).
    int gatherNum;
    double gatherRadius;
    /** Bounding box of all stored photons. It is temporarily narrowed while balancing. */
    private BoundingBox bounds;

    static {
        // precompute tables to compress unit vectors
        for (int i = 0; i < 256; i++) {
            double angle = (i * Math.PI) / 256.0;
            COS_THETA[i] = Math.cos(angle);
            SIN_THETA[i] = Math.sin(angle);
            COS_PHI[i] = Math.cos(2.0 * angle);
            SIN_PHI[i] = Math.sin(2.0 * angle);
        }
    }

    /**
     * Creates an empty photon map.
     *
     * @param maxPhotons maximum number of photons that can be stored
     * @param gatherNum gather count, stored for callers
     * @param gatherRadius gather radius, stored for callers
     */
    PhotonMap(int maxPhotons, int gatherNum, double gatherRadius) {
        // one extra slot because index 0 is never used
        photons = new Photon[maxPhotons + 1];
        storedPhotons = halfStoredPhotons = 0;
        dist1d2 = null;
        chosen = null;
        this.gatherNum = gatherNum;
        this.gatherRadius = gatherRadius;
        bounds = new BoundingBox();
    }

    /**
     * Appends a photon and grows the bounding box to include its position.
     * The photon is silently dropped when the map is already full.
     *
     * @param p photon to store
     */
    final void storePhoton(Photon p) {
        if (storedPhotons >= (photons.length - 1))
            return;
        storedPhotons++;
        photons[storedPhotons] = p;
        bounds.include(new Point3(p.x, p.y, p.z));
    }

    /** @return number of photons stored */
    public final int size() {
        return storedPhotons;
    }

    /** @return capacity of the map */
    public final int maxSize() {
        return photons.length - 1;
    }

    /** @return true when no further photons can be stored */
    public final boolean isFull() {
        return storedPhotons == (photons.length - 1);
    }

    /**
     * Multiplies the power of every stored photon by {@code scale}. Powers are kept
     * RGBE-encoded, so each one is decoded, scaled and re-encoded.
     *
     * @param scale factor applied to each photon's power
     */
    final void scalePhotonPower(double scale) {
        Color c = new Color();
        for (int i = 1; i <= storedPhotons; i++)
            photons[i].power = c.setRGBE(photons[i].power).mul(scale).toRGBE();
    }

    /**
     * Debug view: projects every stored photon through {@code cam} and writes its
     * decoded power into a bitmap at the camera's resolution, then saves it to
     * {@code filename}. Photons for which {@code cam.getPoint} returns null are
     * skipped.
     *
     * @param cam camera used for projection
     * @param filename output image path
     */
    public final void display(Camera cam, String filename) {
        Bitmap img = new Bitmap(cam.getImageWidth(), cam.getImageHeight(), true);
        for (int i = 1; i <= storedPhotons; i++) {
            Point2 p = cam.getPoint(new Point3(photons[i].x, photons[i].y, photons[i].z));
            if (p != null)
                img.setPixel((int) p.x, (int) p.y, new Color().setRGBE(photons[i].power));
        }
        img.save(filename);
    }

    /**
     * Walks the balanced kd-tree without recursion and offers photons to
     * {@code np.checkAddNearest}.
     *
     * <p>When moving back up, the far subtree of a node (and the node itself) is
     * skipped if the squared distance to its splitting plane is not below
     * {@code np.dist2[0]}, presumably the current squared search radius. The method
     * must be called after {@link #balance()}; an empty map returns immediately.
     *
     * @param np query state: the query position is read from {@code px, py, pz}
     */
    final void locatePhotons(NearestPhotons np) {
        int i = 1;
        int level = 0;
        int cameFrom;
        while (true) {
            // Internal nodes are exactly 1..halfStoredPhotons (those with
            // 2k <= storedPhotons). Descend through all of them, including node
            // halfStoredPhotons itself: its children would otherwise never be examined.
            while (i <= halfStoredPhotons) {
                float dist1d = photons[i].getDist1(np.px, np.py, np.pz);
                dist1d2[level] = dist1d * dist1d;
                i += i;
                if (dist1d > 0.0f)
                    i++;
                chosen[level++] = i;
            }
            // When storedPhotons is even, node halfStoredPhotons has no right child,
            // so i may be storedPhotons + 1. Treat it as an empty subtree. This
            // guard also covers the empty map (i == 1 > 0).
            if (i <= storedPhotons)
                np.checkAddNearest(photons[i]);
            do {
                cameFrom = i;
                i >>= 1;
                level--;
                if (i == 0)
                    return;
                // keep climbing if the far side is out of range, or if we are
                // returning from the far side already
            } while ((dist1d2[level] >= np.dist2[0]) || (cameFrom != chosen[level]));
            np.checkAddNearest(photons[i]);
            // sibling of the near child, i.e. the far child
            i = chosen[level++] ^ 1;
        }
    }

    /**
     * Rebuilds {@link #photons} as a heap-ordered kd-tree and allocates the per-level
     * scratch arrays used by {@link #locatePhotons(NearestPhotons)}. Any unused
     * capacity is discarded, so no further photons can be stored afterwards.
     */
    final void balance() {
        Photon[] temp = new Photon[storedPhotons + 1];
        // An empty map has nothing to balance; balanceSegment would otherwise write
        // temp[1] past the end of a length-1 array.
        if (storedPhotons > 0)
            balanceSegment(temp, 1, 1, storedPhotons);
        photons = temp;
        halfStoredPhotons = storedPhotons / 2;
        // floor(log2(n)) + 1 levels, which is at least one per internal node on any
        // root-to-leaf path. This is computed exactly with integers; the old
        // Math.log form went negative (Math.log(0)) for an empty map.
        int depth = 32 - Integer.numberOfLeadingZeros(storedPhotons);
        dist1d2 = new float[depth];
        chosen = new int[depth];
    }

    /**
     * Recursively builds the subtree for {@code photons[start..end]} (inclusive) into
     * {@code temp}, rooted at {@code temp[index]}.
     *
     * <p>How it works:
     * <ul>
     *   <li>The split axis is the axis of largest extent of {@link #bounds}.</li>
     *   <li>The segment is partitioned in place (quickselect) around the chosen
     *       median position.</li>
     *   <li>{@link #bounds} is clipped at the median photon for each recursive
     *       call, then restored.</li>
     *   <li>The median position formula is intended to fill every slot
     *       1..storedPhotons of {@code temp} (left-balanced layout).</li>
     * </ul>
     *
     * @param temp destination heap array
     * @param index heap index of the subtree root
     * @param start first photon index of the segment (inclusive)
     * @param end last photon index of the segment (inclusive)
     */
    private void balanceSegment(Photon[] temp, int index, int start, int end) {
        int median = 1;
        while ((4 * median) <= (end - start + 1))
            median += median;
        if ((3 * median) <= (end - start + 1)) {
            median += median;
            median += (start - 1);
        } else
            median = end - median + 1;
        int axis = Photon.SPLIT_Z;
        Vector3 extents = bounds.getExtents();
        if ((extents.x > extents.y) && (extents.x > extents.z))
            axis = Photon.SPLIT_X;
        else if (extents.y > extents.z)
            axis = Photon.SPLIT_Y;
        // quickselect: move the element of rank `median` into place, with smaller
        // coordinates before it and larger ones after it
        int left = start;
        int right = end;
        while (right > left) {
            double v = photons[right].getCoord(axis);
            int i = left - 1;
            int j = right;
            while (true) {
                while (photons[++i].getCoord(axis) < v) {}
                while ((photons[--j].getCoord(axis) > v) && (j > left)) {}
                if (i >= j)
                    break;
                swap(i, j);
            }
            swap(i, right);
            if (i >= median)
                right = i - 1;
            if (i <= median)
                left = i + 1;
        }
        temp[index] = photons[median];
        temp[index].setSplitAxis(axis);
        // left subtree: clip the box's maximum on the split axis at the median photon
        if (median > start) {
            if (start < (median - 1)) {
                double tmp;
                switch (axis) {
                    case Photon.SPLIT_X:
                        tmp = bounds.getMaximum().x;
                        bounds.getMaximum().x = temp[index].x;
                        balanceSegment(temp, 2 * index, start, median - 1);
                        bounds.getMaximum().x = tmp;
                        break;
                    case Photon.SPLIT_Y:
                        tmp = bounds.getMaximum().y;
                        bounds.getMaximum().y = temp[index].y;
                        balanceSegment(temp, 2 * index, start, median - 1);
                        bounds.getMaximum().y = tmp;
                        break;
                    default:
                        tmp = bounds.getMaximum().z;
                        bounds.getMaximum().z = temp[index].z;
                        balanceSegment(temp, 2 * index, start, median - 1);
                        bounds.getMaximum().z = tmp;
                }
            } else
                temp[2 * index] = photons[start];
        }
        // right subtree: clip the box's minimum on the split axis at the median photon
        if (median < end) {
            if ((median + 1) < end) {
                double tmp;
                switch (axis) {
                    case Photon.SPLIT_X:
                        tmp = bounds.getMinimum().x;
                        bounds.getMinimum().x = temp[index].x;
                        balanceSegment(temp, (2 * index) + 1, median + 1, end);
                        bounds.getMinimum().x = tmp;
                        break;
                    case Photon.SPLIT_Y:
                        tmp = bounds.getMinimum().y;
                        bounds.getMinimum().y = temp[index].y;
                        balanceSegment(temp, (2 * index) + 1, median + 1, end);
                        bounds.getMinimum().y = tmp;
                        break;
                    default:
                        tmp = bounds.getMinimum().z;
                        bounds.getMinimum().z = temp[index].z;
                        balanceSegment(temp, (2 * index) + 1, median + 1, end);
                        bounds.getMinimum().z = tmp;
                }
            } else
                temp[(2 * index) + 1] = photons[end];
        }
    }

    /** Exchanges {@code photons[i]} and {@code photons[j]}. */
    private void swap(int i, int j) {
        Photon tmp = photons[i];
        photons[i] = photons[j];
        photons[j] = tmp;
    }

    /**
     * Decodes an 8-bit (theta, phi) pair into a unit vector. The bytes are read as
     * unsigned indices into the precomputed tables.
     *
     * @param theta encoded polar angle
     * @param phi encoded azimuth
     * @param dest receives the decoded direction
     */
    static final void getUnitVector(byte theta, byte phi, Vector3 dest) {
        int t = theta & 0xFF;
        int p = phi & 0xFF;
        dest.x = SIN_THETA[t] * COS_PHI[p];
        dest.y = SIN_THETA[t] * SIN_PHI[p];
        dest.z = COS_THETA[t];
    }

    /**
     * Encodes the polar angle {@code acos(v.z)} into 256 buckets over [0, PI].
     *
     * @param v direction, expected to be (approximately) unit length
     * @return bucket index 0..255, stored as an unsigned byte
     */
    static final byte getVectorTheta(Vector3 v) {
        // Clamp to acos's domain: rounding can push |z| slightly above 1, which
        // would otherwise yield NaN (and thus bucket 0).
        double z = Math.max(-1.0, Math.min(1.0, v.z));
        int theta = (int) (Math.acos(z) * (256.0 / Math.PI));
        // acos(-1) == PI maps to 256, which would wrap to 0 (the opposite direction)
        // when narrowed to a byte; use the last bucket instead.
        return (byte) Math.min(theta, 255);
    }

    /**
     * Encodes the azimuth {@code atan2(v.y, v.x)} into 256 buckets over a full turn.
     *
     * @param v direction
     * @return bucket index 0..255, stored as an unsigned byte
     */
    static final byte getVectorPhi(Vector3 v) {
        int phi = (int) (Math.atan2(v.y, v.x) * (128.0 / Math.PI));
        // atan2 lies in [-PI, PI]; shift negative angles into [0, 256)
        return (byte) ((phi < 0) ? (phi + 256) : phi);
    }

    /**
     * A stored photon. The record holds:
     * <ul>
     *   <li>position as floats;</li>
     *   <li>incoming direction compressed to two bytes;</li>
     *   <li>RGBE-encoded power;</li>
     *   <li>the kd-tree split axis, kept in the low bits of {@code flags}.</li>
     * </ul>
     */
    static class Photon {
        float x;
        float y;
        float z;
        byte dirPhi;
        byte dirTheta;
        int power;
        int flags;
        static final int SPLIT_X = 0;
        static final int SPLIT_Y = 1;
        static final int SPLIT_Z = 2;
        static final int SPLIT_MASK = 3;

        /**
         * @param p photon position
         * @param dir photon direction, compressed with getVectorPhi/getVectorTheta
         * @param power photon power, stored RGBE-encoded
         */
        Photon(Point3 p, Vector3 dir, Color power) {
            x = (float) p.x;
            y = (float) p.y;
            z = (float) p.z;
            dirPhi = getVectorPhi(dir);
            dirTheta = getVectorTheta(dir);
            this.power = power.toRGBE();
            flags = SPLIT_X;
        }

        /** Stores the split axis in the low bits of {@code flags}, keeping other bits. */
        void setSplitAxis(int axis) {
            flags &= ~SPLIT_MASK;
            flags |= axis;
        }

        /** @return the coordinate of this photon along {@code axis} */
        float getCoord(int axis) {
            switch (axis) {
                case SPLIT_X:
                    return x;
                case SPLIT_Y:
                    return y;
                default:
                    return z;
            }
        }

        /**
         * Signed distance from this photon's splitting plane to the query point.
         * It is positive when the query lies on the greater-coordinate side.
         */
        float getDist1(float px, float py, float pz) {
            switch (flags & SPLIT_MASK) {
                case SPLIT_X:
                    return px - x;
                case SPLIT_Y:
                    return py - y;
                default:
                    return pz - z;
            }
        }

        /** @return squared Euclidean distance from this photon to the given point */
        float getDist2(float px, float py, float pz) {
            float dx = x - px;
            float dy = y - py;
            float dz = z - pz;
            return (dx * dx) + (dy * dy) + (dz * dz);
        }
    }
}