package org.sunflow.raytracer.shader;

import org.sunflow.image.Color;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.Ray;
import org.sunflow.raytracer.RenderState;
import org.sunflow.raytracer.Shader;

/**
 * Dielectric (glass-like) shader with tinted specular reflection and refraction.
 *
 * <p>Reflectance is estimated with Schlick's approximation
 * {@code F = f0 + (1 - cos)^5 * (1 - f0)}, where {@code f0} is derived from {@code eta}.
 * Both reflected and transmitted light are tinted by {@code glassColor}.
 *
 * <p>Side effect: when a ray (or photon) arrives from the back side of the surface, both
 * {@link #getRadiance} and {@link #scatterPhoton} flip {@code state.getVertex().n} in place
 * so that it faces the incoming ray.
 */
public class GlassShader extends Shader {
    private final double eta; // refraction index ratio (glass relative to the outside medium)
    private final double f0; // Fresnel reflectance at normal incidence
    private final Color glassColor;

    /**
     * Creates a glass shader.
     *
     * @param eta        refraction index ratio; rays entering the surface use {@code 1 / eta}
     *                   and rays leaving it use {@code eta}
     * @param glassColor tint applied to both reflected and refracted contributions
     */
    public GlassShader(double eta, Color glassColor) {
        this.eta = eta;
        this.glassColor = glassColor;
        double r = (1.0 - eta) / (1.0 + eta);
        this.f0 = r * r;
    }

    /**
     * Computes the specular radiance leaving the glass surface along the current ray.
     *
     * <p>The refracted ray is always traced unless total internal reflection occurs.
     * The reflected ray is traced when the ray hits the outside of the surface, or when
     * total internal reflection makes reflection the only remaining path. Reflections
     * inside the glass are otherwise skipped: they would keep bouncing internally and
     * contribute little energy, so dropping them gives a small speedup.
     *
     * @param state current shading state; its vertex normal may be flipped in place
     * @return a fresh color; black when specular shading is disabled for this state
     */
    public Color getRadiance(RenderState state) {
        if (!state.includeSpecular()) {
            return new Color(Color.BLACK);
        }
        Vector3 d = state.getRay().getDirection();
        double cos = -Vector3.dot(state.getVertex().n, d);
        boolean inside = !(cos > 0.0);
        double neta;
        if (inside) {
            // in going out: flip the normal so it faces the incoming ray
            cos = -cos;
            neta = eta;
            state.getVertex().n.negate();
        } else {
            // out going in
            neta = 1.0 / eta;
        }
        Vector3 n = state.getVertex().n;

        Vector3 refrDir = refract(d, n, cos, neta);
        boolean tir = refrDir == null;
        // under total internal reflection all energy is reflected
        double fr = tir ? 1.0 : schlick(cos);
        double ft = 1.0 - fr;

        Color ret;
        if (tir) {
            ret = new Color(Color.BLACK);
        } else {
            Ray refrRay = new Ray(state.getVertex().p, refrDir);
            ret = Color.mul(ft, traceSpecular(state, refrRay)).mul(glassColor);
        }
        if (!inside || tir) {
            Ray reflRay = new Ray(state.getVertex().p, reflect(d, n, cos));
            ret.add(Color.mul(fr, traceSpecular(state, reflRay)).mul(glassColor));
        }
        return ret;
    }

    /**
     * Scatters a photon off the glass surface using Russian roulette.
     *
     * <p>The photon is reflected with probability {@code avg(f0 * glassColor)}, refracted
     * with probability {@code avg((1 - f0) * glassColor)}, and absorbed otherwise. Its
     * power is rescaled so the expected outgoing power is {@code glassColor * power}.
     * A photon chosen for refraction that undergoes total internal reflection is
     * reflected instead.
     *
     * @param state current shading state; its vertex normal may be flipped in place
     * @param power photon power, modified in place before the photon is traced further
     */
    public void scatterPhoton(RenderState state, Color power) {
        Color refr = Color.mul(1.0 - f0, glassColor);
        Color refl = Color.mul(f0, glassColor);
        double avgR = refl.getAverage();
        double avgT = refr.getAverage();
        double rnd = Math.random();
        if (rnd >= avgR + avgT) {
            // photon is absorbed
            return;
        }
        Vector3 d = state.getRay().getDirection();
        double cos = -Vector3.dot(state.getVertex().n, d);
        double neta;
        if (cos > 0.0) {
            // out going in
            neta = 1.0 / eta;
        } else {
            // in going out: flip the normal so it faces the incoming photon
            cos = -cos;
            neta = eta;
            state.getVertex().n.negate();
        }
        Vector3 n = state.getVertex().n;
        Vector3 dir;
        if (rnd < avgR) {
            // photon is reflected
            power.mul(refl).mul(1.0 / avgR);
            dir = reflect(d, n, cos);
        } else {
            // photon is refracted
            power.mul(refr).mul(1.0 / avgT);
            dir = refract(d, n, cos, neta);
            if (dir == null) {
                // total internal reflection: the transmitted share is reflected instead
                dir = reflect(d, n, cos);
            }
        }
        traceSpecularPhoton(state, new Ray(state.getVertex().p, dir), power);
    }

    /**
     * Evaluates Schlick's approximation of the Fresnel reflectance.
     *
     * @param cos cosine between the (flipped) normal and the reversed incoming direction
     * @return {@code f0 + (1 - cos)^5 * (1 - f0)}
     */
    private double schlick(double cos) {
        double c1 = 1.0 - cos;
        double c2 = c1 * c1;
        return f0 + (c2 * c2 * c1 * (1.0 - f0));
    }

    /**
     * Mirrors {@code dir} about {@code n}.
     *
     * @param dir incoming direction
     * @param n   normal facing the incoming direction
     * @param cos {@code -dot(n, dir)}, non-negative
     * @return a new reflected direction vector
     */
    private static Vector3 reflect(Vector3 dir, Vector3 n, double cos) {
        Vector3 r = new Vector3();
        double dn = 2.0 * cos;
        r.x = (dn * n.x) + dir.x;
        r.y = (dn * n.y) + dir.y;
        r.z = (dn * n.z) + dir.z;
        return r;
    }

    /**
     * Refracts {@code dir} through a surface with normal {@code n} using Snell's law.
     *
     * @param dir  incoming direction
     * @param n    normal facing the incoming direction
     * @param cos  {@code -dot(n, dir)}, non-negative
     * @param neta ratio of the incident to the transmitted refraction index
     * @return a new refracted direction, or {@code null} on total internal reflection
     */
    private static Vector3 refract(Vector3 dir, Vector3 n, double cos, double neta) {
        double arg = 1.0 - (neta * neta * (1.0 - (cos * cos)));
        if (arg < 0.0) {
            // total internal reflection: no transmitted ray exists
            return null;
        }
        double nK = (neta * cos) - Math.sqrt(arg);
        Vector3 t = new Vector3();
        t.x = (neta * dir.x) + (nK * n.x);
        t.y = (neta * dir.y) + (nK * n.y);
        t.z = (neta * dir.z) + (nK * n.z);
        return t;
    }
}