package org.sunflow.raytracer.shader;

import org.sunflow.image.Color;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.Ray;
import org.sunflow.raytracer.RenderState;
import org.sunflow.raytracer.Shader;

/**
 * Perfect specular (mirror) shader whose reflectance is modulated by Schlick's
 * approximation of the Fresnel term.
 */
public class MirrorShader extends Shader {
    /** Reflectance color at normal incidence; a private copy of the constructor argument. */
    private final Color reflect;

    /**
     * Creates a mirror shader.
     *
     * @param reflect reflectance color at normal incidence; it is copied, so later changes
     *                to the argument do not affect this shader
     */
    public MirrorShader(Color reflect) {
        this.reflect = new Color(reflect);
    }

    /**
     * Computes the radiance seen along the current ray by tracing a single mirror-reflected
     * ray and weighting it with a Fresnel-tinted reflectance.
     *
     * <p>Side effect: if the shading normal faces away from the incoming ray it is negated
     * in place on {@code state.getVertex().n}.
     *
     * @param state current render state
     * @return black when specular contributions are excluded for this state, otherwise the
     *         Fresnel-weighted radiance returned by {@code traceSpecular}
     */
    public Color getRadiance(RenderState state) {
        if (!state.includeSpecular()) {
            return new Color(Color.BLACK);
        }
        double cos = faceForward(state);
        Ray refRay = new Ray(state.getVertex().p, reflectDirection(state, cos));

        // Schlick's approximation of the Fresnel term:
        // F = reflect + (1 - reflect) * (1 - cos)^5
        double oneMinusCos = 1.0 - cos;
        float oneMinusCos2 = (float) (oneMinusCos * oneMinusCos);
        float oneMinusCos5 = (float) (oneMinusCos2 * oneMinusCos2 * oneMinusCos);
        Color ret = new Color(Color.WHITE);
        ret.sub(reflect);
        ret.mul(oneMinusCos5);
        ret.add(reflect);
        return ret.mul(traceSpecular(state, refRay));
    }

    /**
     * Reflects a photon off the mirror and continues tracing it.
     *
     * <p>The photon power is tinted by the normalized reflectance
     * ({@code reflect / reflect.getAverage()}). A mirror whose average reflectance is not
     * positive absorbs the photon, because normalizing would divide by zero and produce
     * NaN/Infinity power.
     *
     * <p>Side effects: {@code power} is modified in place when the photon is reflected, and
     * the shading normal may be negated in place so that it faces the incoming ray.
     *
     * @param state current render state
     * @param power photon power; scaled in place before the reflected photon is traced
     */
    public void scatterPhoton(RenderState state, Color power) {
        double avg = reflect.getAverage();
        if (avg <= 0.0) {
            // A non-reflective mirror absorbs the photon; without this guard 1.0 / avg
            // would turn the photon power into NaN/Infinity.
            return;
        }
        double cos = faceForward(state);
        power.mul(reflect).mul(1.0 / avg);
        // photon is reflected
        traceSpecularPhoton(state, new Ray(state.getVertex().p, reflectDirection(state, cos)), power);
    }

    /**
     * Makes the shading normal face the incoming ray and returns the cosine between them.
     *
     * <p>Negates {@code state.getVertex().n} in place when it points away from the ray origin.
     *
     * @param state current render state
     * @return {@code -dot(n, d)} after orienting {@code n}; always non-negative
     */
    private static double faceForward(RenderState state) {
        Vector3 n = state.getVertex().n;
        double cos = -Vector3.dot(n, state.getRay().getDirection());
        if (cos < 0.0) {
            n.negate();
            cos = -cos;
        }
        return cos;
    }

    /**
     * Computes the mirror reflection of the incoming ray direction about the shading normal.
     *
     * @param state current render state; its normal must already face the incoming ray
     * @param cos   {@code -dot(n, d)}, as returned by {@link #faceForward(RenderState)}
     * @return a new vector equal to {@code d + 2 * cos * n}
     */
    private static Vector3 reflectDirection(RenderState state, double cos) {
        Vector3 n = state.getVertex().n;
        Vector3 d = state.getRay().getDirection();
        double dn = 2.0 * cos;
        Vector3 r = new Vector3();
        r.x = (dn * n.x) + d.x;
        r.y = (dn * n.y) + d.y;
        r.z = (dn * n.z) + d.z;
        return r;
    }
}