package org.sunflow.raytracer.geometry;

import org.sunflow.math.Point3;
import org.sunflow.math.Vector3;
import org.sunflow.raytracer.BoundingBox;
import org.sunflow.raytracer.Intersectable;
import org.sunflow.raytracer.Ray;
import org.sunflow.raytracer.RenderState;
import org.sunflow.raytracer.Shader;

/**
 * An infinite plane, defined by a point lying on it and a normal direction.
 * <p>
 * The plane has no finite extent: {@link #getBounds()} returns {@code null},
 * and the intersection tests place no limit on where on the plane a hit may
 * occur.
 */
public class Plane implements Intersectable {
    /** A point on the plane. The constructor argument is stored by reference, not copied. */
    private final Point3 center;
    /** Unit-length plane normal (a normalized copy of the constructor argument). */
    private final Vector3 normal;
    /** Shader returned by {@link #getSurfaceShader()}. */
    private final Shader shader;

    /**
     * Creates a plane through {@code center} perpendicular to {@code normal}.
     *
     * @param shader surface shader for this plane
     * @param center any point on the plane; the reference is kept, so
     *            mutating it afterwards moves the plane
     * @param normal plane normal; a normalized copy is stored
     */
    public Plane(Shader shader, Point3 center, Vector3 normal) {
        this.center = center;
        this.normal = new Vector3(normal).normalize();
        this.shader = shader;
    }

    /**
     * Returns {@code null}: an infinite plane has no finite bounding box.
     *
     * @return always {@code null}
     */
    public BoundingBox getBounds() {
        return null;
    }

    /**
     * Tests whether this plane passes through (or touches) {@code box}.
     * <p>
     * The signed offset {@code (center - p) . normal} is linear in {@code p},
     * so over the box it reaches its extremes at two corners. For each axis,
     * one corner takes the coordinate pointing against the normal and the
     * other takes the coordinate pointing along it. The plane crosses the box
     * exactly when the offsets of those two corners straddle zero.
     * <p>
     * Only the box's min/max corners are tested when every normal component
     * is non-negative. Otherwise other corners are tested.
     *
     * @param box axis-aligned box to test
     * @return {@code true} if some point of the box lies on the plane
     */
    public boolean intersects(BoundingBox box) {
        Point3 lo = box.getMinimum();
        Point3 hi = box.getMaximum();
        // Corner minimizing p . normal -> largest signed offset over the box.
        double largest = signedOffset(normal.x >= 0.0 ? lo.x : hi.x, normal.y >= 0.0 ? lo.y : hi.y, normal.z >= 0.0 ? lo.z : hi.z);
        // Corner maximizing p . normal -> smallest signed offset over the box.
        double smallest = signedOffset(normal.x >= 0.0 ? hi.x : lo.x, normal.y >= 0.0 ? hi.y : lo.y, normal.z >= 0.0 ? hi.z : lo.z);
        // Compare directly rather than multiplying, so tiny offsets cannot
        // underflow to zero and produce a spurious hit.
        return smallest <= 0.0 && largest >= 0.0;
    }

    /**
     * Returns the shader supplied at construction.
     *
     * @return this plane's surface shader
     */
    public Shader getSurfaceShader() {
        return shader;
    }

    /**
     * Fills in surface data for the hit recorded in {@code state}.
     * <p>
     * Writes the hit position (the ray evaluated at the recorded distance) and
     * sets both the shading normal and the geometric normal to the plane
     * normal. Texture coordinates are set to (0, 0) because the plane has no
     * parameterization.
     *
     * @param state render state holding the ray and hit distance; its vertex
     *            and geometric normal are overwritten
     */
    public void setSurfaceLocation(RenderState state) {
        state.getRay().getPoint(state.getT(), state.getVertex().p);
        state.getVertex().n.set(normal);
        state.getVertex().tex.x = 0.0;
        state.getVertex().tex.y = 0.0;
        state.getGeoNormal().set(normal);
    }

    /**
     * Intersects the current ray of {@code state} with this plane.
     * <p>
     * If the ray accepts the hit distance ({@code Ray.isInside}), the ray's
     * maximum distance is shortened to the hit and the hit is recorded on
     * {@code state}.
     *
     * @param state render state holding the ray to test
     */
    public void intersect(RenderState state) {
        Ray r = state.getRay();
        double dn = Vector3.dot(r.getDirection(), normal);
        // A ray parallel to the plane never crosses it.
        if (dn == 0.0) {
            return;
        }
        Point3 o = r.getOrigin();
        double t = signedOffset(o.x, o.y, o.z) / dn;
        if (r.isInside(t)) {
            r.setMax(t);
            state.setIntersection(this, t, 0.0, 0.0);
        }
    }

    /**
     * Tests whether {@code r} hits this plane, without modifying the ray or
     * recording the hit.
     *
     * @param r ray to test
     * @return {@code true} if the ray is not parallel to the plane and the
     *         ray accepts the hit distance ({@code Ray.isInside})
     */
    public boolean intersects(Ray r) {
        double dn = Vector3.dot(r.getDirection(), normal);
        // A ray parallel to the plane never crosses it.
        if (dn == 0.0) {
            return false;
        }
        Point3 o = r.getOrigin();
        double t = signedOffset(o.x, o.y, o.z) / dn;
        return r.isInside(t);
    }

    /**
     * Computes {@code (center - p) . normal} for the point {@code p = (x, y, z)}.
     * <p>
     * Because {@code normal} is unit length, this is the signed distance from
     * {@code p} to the plane. It is positive when {@code p} lies on the side
     * the normal points away from.
     */
    private double signedOffset(double x, double y, double z) {
        return (((center.x - x) * normal.x) + ((center.y - y) * normal.y) + ((center.z - z) * normal.z));
    }
}