// File:      csg.cpp
// Author:    Philip Rideout
// Brief:     pbrt plugin to perform CSG intersection
//            An optional boolean parameter "inversions" (one flag per primitive) can be specified.
//            For example: Difference(A, B) = Intersection(A, Inversion(B))

#include "pbrt.h"
#include "primitive.h"


// Aggregate that treats its primitives as CSG operands: a ray counts as
// hitting the aggregate only if it hits every non-inverted primitive and
// misses every inverted one.
class CsgAccel : public Aggregate
{
public:
    CsgAccel(const vector<Reference<Primitive> > &p,
             bool refineImmediately,
             const bool* inversions);

    // Aggregate / Primitive interface.
    BBox WorldBound() const;
    bool CanIntersect() const { return true; }
    bool Intersect(const Ray &ray, Intersection *isect) const;
    bool IntersectP(const Ray &ray) const;

private:
    vector<Reference<Primitive> > prims;  // operands (possibly refined)
    vector<bool> inversions;              // inversions[i] negates prims[i]
    BBox bounds;                          // union of the operands' world bounds
};


// World-space bound of the aggregate: the union of all operand bounds,
// precomputed in the constructor.
BBox CsgAccel::WorldBound() const
{
    return bounds;
}


// Builds the CSG aggregate.
//   p                 - operand primitives.
//   refineImmediately - if true, each operand is FullyRefine()d now and its
//                       pieces are stored; otherwise operands are stored as-is.
//   inversions        - optional array with one flag per element of p (NULL
//                       means no operand is inverted); inversions[i] applies
//                       to p[i].
// Every primitive produced from p[i] inherits p[i]'s flag, so the member flag
// list always has exactly one entry per stored primitive. (Previously it was
// sized to p.size() and overrun when refinement produced extra primitives.)
// NOTE(review): after refinement each piece is its own operand, so a ray must
// hit every piece -- probably only meaningful when an operand refines to a
// single intersectable primitive.
CsgAccel::CsgAccel(const vector<Reference<Primitive> > &p, bool refineImmediately, const bool* inversions)
{
    for (u_int i = 0; i < p.size(); ++i)
    {
        const bool inverted = inversions ? inversions[i] : false;
        if (refineImmediately)
            p[i]->FullyRefine(prims);
        else
            prims.push_back(p[i]);
        // resize() fills only the newly appended slots, i.e. p[i]'s pieces.
        this->inversions.resize(prims.size(), inverted);
    }

    for (u_int i = 0; i < prims.size(); ++i)
        bounds = Union(bounds, prims[i]->WorldBound());
}


bool CsgAccel::Intersect(const Ray &ray, Intersection *isect) const
{
    float rayT;
    if (bounds.Inside(ray(ray.mint)))
        rayT = ray.mint;
    else if (!bounds.IntersectP(ray, &rayT))
        return false;

    bool hitSomething = true;
    for (u_int i = 0; i < prims.size(); ++i)
    {
        hitSomething = hitSomething && (inversions[i] != prims[i]->Intersect(ray, isect));
        ray.maxt = INFINITY;
        ray.mint = RAY_EPSILON;
    }

    return hitSomething;
}


// Shadow-ray test: true only if, within the ray's range, the ray hits every
// non-inverted operand and misses every inverted one, and at least one
// non-inverted operand exists. The last condition keeps this test consistent
// with Intersect(): an aggregate made only of inverted operands has no
// surface, so it must not occlude.
bool CsgAccel::IntersectP(const Ray &ray) const
{
    // Cheap rejection against the union of the operand bounds.
    float rayT;
    if (!bounds.Inside(ray(ray.mint)) && !bounds.IntersectP(ray, &rayT))
        return false;

    bool hitRequiredOperand = false;
    for (u_int i = 0; i < prims.size(); ++i)
    {
        const bool inverted = inversions[i];
        // Fails if a required operand is missed or an excluded one is hit.
        if (prims[i]->IntersectP(ray) == inverted)
            return false;
        if (!inverted)
            hitRequiredOperand = true;
    }

    return hitRequiredOperand;
}


// Plugin entry point: builds a CsgAccel over prims.
// Parameters read from ps:
//   "refineimmediately" (bool, default false) - fully refine operands up front.
//   "inversions" (bool array, optional)       - one flag per primitive;
//                                               true inverts that operand.
extern "C" DLLEXPORT Primitive *CreateAccelerator(const vector<Reference<Primitive> >& prims, const ParamSet& ps)
{
    bool refineImmediately = ps.FindOneBool("refineimmediately", false);

    int nInversions = 0;
    const bool* inversions = ps.FindBool("inversions", &nInversions);
    // Compare in the unsigned domain of prims.size() to avoid a signed/unsigned mix.
    if (inversions && static_cast<size_t>(nInversions) != prims.size())
    {
        Severe("number of inversions is not equal to the number of primitives in the object group");
        // Never let the constructor index past the user's array, even if
        // Severe() were to return.
        inversions = NULL;
    }

    return new CsgAccel(prims, refineImmediately, inversions);
}
