package org.sunflow.raytracer.geometry;

import org.sunflow.math.Point3;
import org.sunflow.raytracer.BoundingBox;
import org.sunflow.raytracer.Intersectable;
import org.sunflow.raytracer.Ray;
import org.sunflow.raytracer.RenderState;
import org.sunflow.raytracer.Shader;

/**
 * An analytic sphere primitive defined by a center point and a radius, shaded with a single
 * surface shader.
 */
public class Sphere implements Intersectable {
    private final Point3 center;
    private final double radius;
    // Cached radius * radius; used by every intersection test.
    private final double radiusSquared;
    private final Shader shader;

    /**
     * Creates a sphere.
     *
     * @param shader surface shader returned by {@link #getSurfaceShader()}
     * @param center sphere center (stored by reference, not copied)
     * @param radius sphere radius
     */
    public Sphere(Shader shader, Point3 center, double radius) {
        this.center = center;
        this.radius = radius;
        this.radiusSquared = radius * radius;
        this.shader = shader;
    }

    /**
     * Returns the axis-aligned box spanning center - radius to center + radius on every axis.
     *
     * @return a new bounding box enclosing the sphere
     */
    public BoundingBox getBounds() {
        BoundingBox bounds = new BoundingBox();
        bounds.include(new Point3(center.x - radius, center.y - radius, center.z - radius));
        bounds.include(new Point3(center.x + radius, center.y + radius, center.z + radius));
        return bounds;
    }

    /**
     * Tests whether the sphere's surface (a hollow shell, not the solid ball) overlaps a box.
     *
     * <p>The shell touches the box exactly when the squared distance from the center to the
     * nearest point of the box is at most r^2 and the squared distance to the farthest corner
     * is at least r^2. A box lying entirely inside the ball therefore returns {@code false}.
     *
     * @param box the box to test
     * @return {@code true} if the sphere surface may pass through the box
     */
    public boolean intersects(BoundingBox box) {
        Point3 lo = box.getMinimum();
        Point3 hi = box.getMaximum();
        double dmax = farthestSquared(center.x, lo.x, hi.x)
                + farthestSquared(center.y, lo.y, hi.y)
                + farthestSquared(center.z, lo.z, hi.z);
        double dmin = nearestSquared(center.x, lo.x, hi.x)
                + nearestSquared(center.y, lo.y, hi.y)
                + nearestSquared(center.z, lo.z, hi.z);
        return (dmin <= radiusSquared) && (radiusSquared <= dmax);
    }

    /** Squared distance along one axis from c to the nearest point of [lo, hi]; 0 if inside. */
    private static double nearestSquared(double c, double lo, double hi) {
        if (c < lo) {
            return (c - lo) * (c - lo);
        }
        if (c > hi) {
            return (c - hi) * (c - hi);
        }
        return 0.0;
    }

    /** Squared distance along one axis from c to the farther endpoint of [lo, hi]. */
    private static double farthestSquared(double c, double lo, double hi) {
        double a = (c - lo) * (c - lo);
        double b = (c - hi) * (c - hi);
        return Math.max(a, b);
    }

    /**
     * Returns the shader assigned at construction.
     *
     * @return the surface shader
     */
    public Shader getSurfaceShader() {
        return shader;
    }

    /**
     * Fills in the hit point, unit normal, spherical texture coordinates and geometric normal
     * for the intersection recorded in {@code state}.
     *
     * <p>Texture coordinates: {@code tex.x = phi / (2*pi)} with phi = atan2(n.y, n.x) wrapped
     * to [0, 2*pi), and {@code tex.y = acos(n.z) / pi}. Both lie in [0, 1].
     *
     * @param state render state holding the ray and hit distance; its vertex and geometric
     *     normal are written
     */
    public void setSurfaceLocation(RenderState state) {
        state.getRay().getPoint(state.getT(), state.getVertex().p);
        Point3.sub(state.getVertex().p, center, state.getVertex().n);
        state.getVertex().n.normalize();
        double nx = state.getVertex().n.x;
        double ny = state.getVertex().n.y;
        double nz = state.getVertex().n.z;
        // Clamp: after normalization |nz| can exceed 1 by round-off, and acos would return NaN.
        double theta = Math.acos(Math.max(-1.0, Math.min(1.0, nz)));
        // atan2 covers the full circle and is defined at the poles (nx = ny = 0). The previous
        // acos(nx / sin(theta)) form divided by zero there and mirrored the ny < 0 half.
        double phi = Math.atan2(ny, nx);
        if (phi < 0.0) {
            phi += 2.0 * Math.PI;
        }
        state.getVertex().tex.y = theta / Math.PI;
        state.getVertex().tex.x = phi / (2.0 * Math.PI);
        state.getGeoNormal().set(state.getVertex().n);
    }

    /**
     * Intersects the ray held in {@code state} with the sphere. On a hit, shrinks the ray's
     * maximum distance and records this primitive in {@code state}.
     *
     * <p>Solves |o + t*d - c|^2 = r^2 with the leading coefficient taken as 1. NOTE(review):
     * this assumes the ray direction is unit length; confirm against Ray.
     *
     * @param state render state supplying the ray and receiving the intersection
     */
    public void intersect(RenderState state) {
        Ray ray = state.getRay();
        double ocx = ray.getOrigin().x - center.x;
        double ocy = ray.getOrigin().y - center.y;
        double ocz = ray.getOrigin().z - center.z;
        double qb = (ray.getDirection().x * ocx) + (ray.getDirection().y * ocy)
                + (ray.getDirection().z * ocz);
        double qc = ((ocx * ocx) + (ocy * ocy) + (ocz * ocz)) - radiusSquared;
        double det = (qb * qb) - qc;
        // Tangent rays (det == 0) are treated as misses.
        if (det > 0.0) {
            det = Math.sqrt(det);
            // Try the nearer root first.
            double t = -det - qb;
            if (ray.isInside(t)) {
                ray.setMax(t);
                state.setIntersection(this, t, 0.0, 0.0);
            } else {
                // Near root rejected (presumably outside the ray's valid [min, max] range, e.g.
                // the origin is inside the sphere), so try the farther root.
                t = det - qb;
                if (ray.isInside(t)) {
                    ray.setMax(t);
                    state.setIntersection(this, t, 0.0, 0.0);
                }
            }
        }
    }

    /**
     * Tests whether a ray hits the sphere within the ray's valid range, without recording the
     * hit. Uses the same quadratic, and the same unit-direction assumption, as
     * {@link #intersect(RenderState)}.
     *
     * @param r the ray to test
     * @return {@code true} if either root of the ray/sphere quadratic is accepted by
     *     {@code r.isInside}
     */
    public boolean intersects(Ray r) {
        double ocx = r.getOrigin().x - center.x;
        double ocy = r.getOrigin().y - center.y;
        double ocz = r.getOrigin().z - center.z;
        double qb = (r.getDirection().x * ocx) + (r.getDirection().y * ocy)
                + (r.getDirection().z * ocz);
        double qc = ((ocx * ocx) + (ocy * ocy) + (ocz * ocz)) - radiusSquared;
        double det = (qb * qb) - qc;
        if (det > 0.0) {
            det = Math.sqrt(det);
            if (r.isInside(-det - qb)) {
                return true;
            }
            return r.isInside(det - qb);
        }
        return false;
    }
}