package org.sunflow.raytracer.light;

import org.sunflow.image.Color;
import org.sunflow.math.OrthoNormalBasis;
import org.sunflow.math.Point3;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.LightSample;
import org.sunflow.raytracer.LightSource;
import org.sunflow.raytracer.Ray;
import org.sunflow.raytracer.RenderState;
import org.sunflow.raytracer.Shader;
import org.sunflow.raytracer.Vertex;
import org.sunflow.raytracer.geometry.Triangle;

/**
 * A one-sided triangular area light emitting a constant radiance.
 * <p>
 * Light leaves the triangle on the side its normal {@code ng} (inherited from
 * {@link Triangle}) points to:
 * <ul>
 * <li>direct-lighting samples are only produced for receivers on that side, and</li>
 * <li>photons are emitted into the hemisphere around {@code ng}.</li>
 * </ul>
 * NOTE(review): {@code ng} is used directly in cosine terms, so it is assumed
 * to be unit length — confirm in {@link Triangle}.
 */
public class TriangleAreaLight extends Triangle implements LightSource {
    /** Surface area of the triangle, computed once in the constructor. */
    private final double area;
    /** Constant radiance emitted by every point of the front face. */
    private final Color radiance;

    /**
     * Creates an emitting triangle.
     *
     * @param v0 first vertex
     * @param v1 second vertex
     * @param v2 third vertex
     * @param radiance emitted radiance. The reference is stored, not copied,
     *            so later changes to this object affect the light.
     */
    public TriangleAreaLight(Vertex v0, Vertex v1, Vertex v2, final Color radiance) {
        super(new Shader() {
                public Color getRadiance(RenderState state) {
                    // emit constant radiance, but only when the render state asks for lights
                    return state.includeLights() ? new Color(radiance) : new Color(Color.BLACK);
                }

                public void scatterPhoton(RenderState state, Color power) {
                    // photons arriving at the light are not scattered any further
                }
            }, v0, v1, v2);
        this.radiance = radiance;
        // |e1 x e2| is the area of the parallelogram spanned by two edges; the triangle is half of it
        area = 0.5 * Vector3.cross(Point3.sub(v1.p, v0.p, new Vector3()), Point3.sub(v2.p, v0.p, new Vector3()), new Vector3()).length();
    }

    /**
     * Conservative culling test: can {@link #getSample} ever return a valid
     * sample for this shading point?
     * <p>
     * This requires both of the following:
     * <ul>
     * <li>the point lies strictly in front of the one-sided, planar light; and</li>
     * <li>at least one vertex lies strictly above the point's tangent plane.</li>
     * </ul>
     * If every vertex is at or below that plane, so is the whole triangle.
     *
     * @param state shading state providing the point and its shading normal
     * @return {@code false} if no part of the light can contribute to the point
     */
    public boolean isVisible(RenderState state) {
        Point3 p = state.getVertex().p;
        Vector3 n = state.getVertex().n;
        Vector3 toVertex = new Vector3();

        // The triangle is planar, so the side test against ng is the same for every vertex.
        Point3.sub(v0.p, p, toVertex);
        if (Vector3.dot(toVertex, ng) >= 0.0)
            return false; // point is behind (or on the plane of) the light
        if (Vector3.dot(toVertex, n) > 0.0)
            return true;

        Point3.sub(v1.p, p, toVertex);
        if (Vector3.dot(toVertex, n) > 0.0)
            return true;

        Point3.sub(v2.p, p, toVertex);
        return Vector3.dot(toVertex, n) > 0.0;
    }

    /**
     * Samples a point on the light for direct lighting of a shading point.
     * <p>
     * The sample point is chosen uniformly over the triangle. On success,
     * {@code dest} is filled in as follows:
     * <ul>
     * <li>the sample's vertex (position, normal {@code ng}, interpolated texture coordinates);</li>
     * <li>the normalized direction from the shading point to the sample;</li>
     * <li>a shadow ray between the two points;</li>
     * <li>the radiance {@code radiance * area * cosNy / d^2}.</li>
     * </ul>
     * The cosine at the receiving point is only used for rejection here and is
     * not folded into the returned radiance.
     * <p>
     * If the sample is rejected, {@code dest} is marked invalid. The vertex and
     * direction fields may already have been overwritten by then.
     *
     * @param randX uniform random number in [0, 1)
     * @param randY uniform random number in [0, 1)
     * @param state shading state of the receiving point
     * @param dest sample to fill in
     */
    public void getSample(double randX, double randY, RenderState state, LightSample dest) {
        // uniform barycentric coordinates: u, v >= 0 and u + v <= 1
        double s = Math.sqrt(1.0 - randX);
        double u = randY * s;
        double v = 1.0 - s;

        dest.getVertex().p.x = interpolate(v0.p.x, v1.p.x, v2.p.x, u, v);
        dest.getVertex().p.y = interpolate(v0.p.y, v1.p.y, v2.p.y, u, v);
        dest.getVertex().p.z = interpolate(v0.p.z, v1.p.z, v2.p.z, u, v);
        dest.getVertex().n.set(ng);
        dest.getVertex().tex.x = interpolate(v0.tex.x, v1.tex.x, v2.tex.x, u, v);
        dest.getVertex().tex.y = interpolate(v0.tex.y, v1.tex.y, v2.tex.y, u, v);

        Point3.sub(dest.getVertex().p, state.getVertex().p, dest.getDirection());
        dest.getDirection().normalize();

        // reject samples below the tangent plane of the receiver's shading normal
        double cosNx = Vector3.dot(state.getVertex().n, dest.getDirection());
        if (cosNx <= 0.0) {
            dest.setValid(false);
            return;
        }

        // The light only emits from its front face, so it must face the receiver.
        // The positive test (rather than "<= 0") also rejects a NaN cosine.
        double cosNy = -Vector3.dot(dest.getVertex().n, dest.getDirection());
        if (cosNy > 0.0) {
            dest.setValid(true);
            dest.setShadowRay(new Ray(state.getVertex().p, dest.getVertex().p));
            // The shading normal accepted the direction but the geometric normal points away
            // from the light; push the shadow ray start out to avoid self-shadowing.
            // NOTE(review): 0.3 is an absolute distance and therefore scene-scale dependent.
            double cosNg = Vector3.dot(state.getGeoNormal(), dest.getDirection());
            if (cosNg < 0.0) {
                dest.getShadowRay().setMin(0.3);
            }

            // area-measure to solid-angle conversion: cos(theta_light) / distance^2
            double g = cosNy / state.getVertex().p.distanceToSquared(dest.getVertex().p);
            Color.mul(g * area, radiance, dest.getRadiance());
        } else {
            dest.setValid(false);
        }
    }

    /**
     * Emits a photon from the light.
     * <p>
     * The origin is uniform over the triangle, and the direction is
     * cosine-weighted in the hemisphere around {@code ng}. The power is the
     * flux of a Lambertian emitter, {@code pi * area * radiance}.
     *
     * @param randX1 uniform random number in [0, 1) for the origin
     * @param randY1 uniform random number in [0, 1) for the origin
     * @param randX2 uniform random number in [0, 1) for the azimuth
     * @param randY2 uniform random number in [0, 1) for the elevation
     * @param p receives the photon origin
     * @param dir receives the photon direction
     * @param power receives the photon power
     */
    public void getPhoton(double randX1, double randY1, double randX2, double randY2, Point3 p, Vector3 dir, Color power) {
        double s = Math.sqrt(1.0 - randX1);
        double u = randY1 * s;
        double v = 1.0 - s;
        p.x = interpolate(v0.p.x, v1.p.x, v2.p.x, u, v);
        p.y = interpolate(v0.p.y, v1.p.y, v2.p.y, u, v);
        p.z = interpolate(v0.p.z, v1.p.z, v2.p.z, u, v);

        // cosine-weighted hemisphere sample in a local frame whose w axis is ng
        OrthoNormalBasis onb = OrthoNormalBasis.makeFromW(ng);
        double phi = 2 * Math.PI * randX2;
        double sinTheta = Math.sqrt(randY2);
        Vector3 w = new Vector3(Math.cos(phi) * sinTheta, Math.sin(phi) * sinTheta, Math.sqrt(1.0 - randY2));
        onb.transform(w, dir);
        Color.mul(Math.PI * area, radiance, power);
    }

    /**
     * Returns the light's average emitted power.
     *
     * @return {@code pi * area * radiance.getAverage()}, consistent with the
     *         photon power used by {@link #getPhoton}
     */
    public double getAveragePower() {
        return Math.PI * area * radiance.getAverage();
    }

    /**
     * Interpolates one coordinate across the triangle.
     * <p>
     * Computes {@code a + u * (b - a) + v * (c - a)}, where {@code a}, {@code b}
     * and {@code c} are the coordinate at v0, v1 and v2, and {@code (u, v)} are
     * barycentric weights.
     */
    private static double interpolate(double a, double b, double c, double u, double v) {
        return a + (u * (b - a)) + (v * (c - a));
    }
}