// gridimplicit.cpp
//
// Grid-based implicit surface shape plugin for pbrt.
// Note this does *not* play well with texturing - u=v=0 everywhere.
// Use it with a pbrt line like
//    Shape "gridimplicit" "string filename" "myimplicitdata"
// where the actual data is in a file called myimplicitdata.
// The format of that file is three integers for the x, y, and z
// dimensions of the grid, then all the sample values in order
// (x index changing fastest, z index changing slowest).
// The rendered surface is the zero isosurface of the trilinearly
// interpolated samples.
// Note that it will clip any surface contained in the outer-most
// layer of grid cells (voxels).
//
// TODO: Eliminate the extra file and integrate the data into the "Shape"
//       call, at least optionally.
//       Change the representation to an octree or something else faster
//       when first constructed.
//
// Robert Bridson -- January 10, 2005
// Use freely but at your own risk.

#include "shape.h"
#include "paramset.h"

// Shape whose surface is the zero level set of a trilinearly interpolated
// scalar field sampled on a regular grid; sample (i,j,k) is treated as lying
// at object-space point (i,j,k).
class GridImplicit: public Shape
{
   int dim[3];           // number of grid points (samples) in each direction
   int celldim[3];       // number of grid cells (one less than samples)
   BBox bounds;          // object-space extent of the grid, set by the constructor
   float *phi;           // samples, x index fastest and z slowest (see val())
   char *surface_marker; // bit array marking cells that contain isosurface

   public:

   GridImplicit(const Transform &o2w, bool ro, const char *filename);
   ~GridImplicit() { delete[] phi; delete[] surface_marker; }
   BBox ObjectBound() const { return bounds; }
   bool Intersect(const Ray &ray, float *tHit, DifferentialGeometry *dg) const;
   bool IntersectP(const Ray &ray) const;
   void GetShadingGeometry(const Transform &obj2world,
                           const DifferentialGeometry &dg,
                           DifferentialGeometry *dgShading) const;

   private:

   // phi and surface_marker are owned raw arrays released by the destructor,
   // so a member-wise copy would free them twice. Copying is disabled: these
   // are declared but intentionally never defined.
   GridImplicit(const GridImplicit &);
   GridImplicit &operator=(const GridImplicit &);

   // Shared ray/grid traversal; tHit and dg are null for occlusion queries.
   bool IntersectCore(const Ray &ray, float *tHit=0,
                      DifferentialGeometry *dg=0) const;
   // Root search for the isosurface inside one cell over ray params [t0,t1].
   bool IntersectCell(const Ray &ray, const int pos[3], float t0, float t1,
                      float *tHit, DifferentialGeometry *dg) const;

   // Sample (i,j,k); x index varies fastest in the flat phi array.
   float &val(int i, int j, int k)
   { return phi[i+dim[0]*(j+dim[1]*k)]; }

   const float &val(int i, int j, int k) const
   { return phi[i+dim[0]*(j+dim[1]*k)]; }

   // Bit for cell (i,j,k): set if the cell may contain the isosurface.
   bool get_surface_marker(int i, int j, int k) const
   {
      int index=i+celldim[0]*(j+celldim[1]*k);
      return surface_marker[index/8] & 1<<(index%8);
   }

   void set_surface_marker(int i, int j, int k)
   {
      int index=i+celldim[0]*(j+celldim[1]*k);
      surface_marker[index/8] |= 1<<(index%8);
   }

};

// Largest per-axis sample count accepted without a "suspicious dimensions"
// warning; larger grids still load.
static const int MAXDIMENSION=1000;

// Load the grid from `filename` and precompute which cells may contain the
// isosurface. `ro` is forwarded to Shape as the reverse-orientation flag (the
// inherited reverseOrientation member is not initialized yet at this point, so
// it must not be passed instead). On any failure an Error() is reported, the
// file is closed, and the object is left empty: dim={0,0,0},
// celldim={-1,-1,-1}, phi and surface_marker null, bounds default-constructed.
GridImplicit::
GridImplicit(const Transform &o2w, bool ro, const char *filename)
   : Shape(o2w, ro), phi(0), surface_marker(0)
{
   dim[0]=dim[1]=dim[2]=0;
   celldim[0]=celldim[1]=celldim[2]=-1;

   FILE *fp=fopen(filename, "r");
   if(!fp){
      Error("Couldn't open gridimplicit file \"%s\" for reading\n", filename);
      return;
   }

   // read in dimensions
   if(3!=fscanf(fp, "%d %d %d", dim, dim+1, dim+2)){
      Error("Problem reading gridimplicit file \"%s\" (dimensions)\n", filename);
      dim[0]=dim[1]=dim[2]=0;
      fclose(fp);
      return;
   }
   if(dim[0]<=1 || dim[1]<=1 || dim[2]<=1){
      Error("Bad dimensions (%d, %d, %d) in gridimplicit file \"%s\"\n", dim[0], dim[1], dim[2], filename);
      dim[0]=dim[1]=dim[2]=0;
      fclose(fp);
      return;
   }
   if(dim[0]>MAXDIMENSION || dim[1]>MAXDIMENSION || dim[2]>MAXDIMENSION){
      Warning("Suspicious dimensions (%d %d %d) in gridimplicit file \"%s\"\n", dim[0], dim[1], dim[2], filename);
   }

   // read in samples (x index fastest, z slowest, matching val())
   int numsamples=dim[0]*dim[1]*dim[2];
   phi=new float[numsamples];
   for(int i=0; i<numsamples; ++i){
      if(1!=fscanf(fp, "%f", phi+i)){
         Error("Problem reading gridimplicit file \"%s\" (at value %d)\n", filename, i);
         dim[0]=dim[1]=dim[2]=0;
         delete[] phi;
         phi=0;
         fclose(fp);
         return;
      }
      Assert(phi[i]==phi[i]);
   }

   // finished with the file
   fclose(fp);

   // fill in other fields
   celldim[0]=dim[0]-1;
   celldim[1]=dim[1]-1;
   celldim[2]=dim[2]-1;
   // Sample (i,j,k) sits at object-space point (i,j,k), so the samples (and
   // hence the cells) span [0,celldim] on each axis. Extending the box to dim
   // would make the traversal extrapolate the interpolant past the data.
   bounds.pMin=Point(0.f,0.f,0.f);
   bounds.pMax=Point(celldim[0],celldim[1],celldim[2]);

   // set up surface marker array: one bit per cell, initially all clear
   unsigned int numcells=celldim[0]*celldim[1]*celldim[2];
   unsigned int numbytes=(numcells+7)/8;
   surface_marker=new char[numbytes];
   memset(surface_marker, 0, numbytes);
   for(int k=0; k<celldim[2]; ++k){
      for(int j=0; j<celldim[1]; ++j){
         for(int i=0; i<celldim[0]; ++i){
            // A cell can contain the zero isosurface unless all eight corners
            // are strictly negative or all are strictly positive.
            bool allNegative=true, allPositive=true;
            for(int c=0; c<8; ++c){
               float corner=val(i+(c&1), j+((c>>1)&1), k+((c>>2)&1));
               if(!(corner<0)) allNegative=false;
               if(!(corner>0)) allPositive=false;
            }
            if(!allNegative && !allPositive)
               set_surface_marker(i,j,k);
         }
      }
   }
}

bool GridImplicit::
Intersect(const Ray &r, float *tHit, DifferentialGeometry *dg) const
{
   Ray ray;
   WorldToObject(r, &ray);
   return IntersectCore(ray, tHit, dg);
}

bool GridImplicit::
IntersectP(const Ray &r) const
{
   Ray ray;
   WorldToObject(r, &ray);
   return IntersectCore(ray);
}

bool GridImplicit::
IntersectCore(const Ray &ray, float *tHit, DifferentialGeometry *dg) const
{
   float rayT;
   if(!bounds.IntersectP(ray, &rayT))
      return false; // ray missed the grid altogether

   // Set up 3D DDA for ray
   Point gridIntersect=ray(rayT);
   float nextCrossingT[3], deltaT[3];
   int pos[3], step[3], out[3];
   for(int axis=0; axis<3; ++axis){
      pos[axis]=Float2Int(gridIntersect[axis]);
      // check for rounding error sending us outside the grid
      if(pos[axis]<0) pos[axis]=0;
      else if(pos[axis]>=celldim[axis]) pos[axis]=celldim[axis]-1;
      if(ray.d[axis]<0){
         // Handle ray with negative direction
         nextCrossingT[axis]=rayT+(pos[axis]-gridIntersect[axis])/ray.d[axis];
         deltaT[axis]=-1.f/ray.d[axis];
         step[axis]=-1;
         out[axis]=-1;
      }else{
         // Handle ray with non-negative direction
         nextCrossingT[axis]=rayT+(pos[axis]+1-gridIntersect[axis])/ray.d[axis];
         deltaT[axis]=1.f/ray.d[axis];
         step[axis]=1;
         out[axis]=celldim[axis];
      }
   }

   // Walk ray through grid
   for(;;){
      // find stepAxis for stepping to next cell
      int stepAxis;
      if(nextCrossingT[0]<nextCrossingT[2]){
	 if(nextCrossingT[0]<nextCrossingT[1]) stepAxis=0;
	 else stepAxis=1;
      }else if(nextCrossingT[2]<nextCrossingT[1]) stepAxis=2;
      else stepAxis=1;
      // check this cell
      if(get_surface_marker(pos[0],pos[1],pos[2]) &&
         IntersectCell(ray, pos, rayT,
                       fminf(ray.maxt, nextCrossingT[stepAxis]), tHit, dg))
         return true;
      if(ray.maxt<nextCrossingT[stepAxis])
         break;
      pos[stepAxis]+=step[stepAxis];
      if(pos[stepAxis]==out[stepAxis])
         break;
      rayT=nextCrossingT[stepAxis];
      nextCrossingT[stepAxis]+=deltaT[stepAxis];
   }
   return false;
}

// Trilinear interpolation of the eight cell-corner values v at local
// coordinates (xfrac, yfrac, zfrac); corner (dx,dy,dz) is v[4*dx+2*dy+dz].
// Everything is carried in double: the complementary weights were previously
// stored as float, which rounded values like 1-1e-10 to exactly 1.
static inline double
eval(const double v[8], double xfrac, double yfrac, double zfrac)
{
   double xrest=1-xfrac, yrest=1-yfrac, zrest=1-zfrac;
   return xrest*( yrest*( zrest*v[0] + zfrac*v[1] ) +
                  yfrac*( zrest*v[2] + zfrac*v[3] ) ) +
          xfrac*( yrest*( zrest*v[4] + zfrac*v[5] ) +
                  yfrac*( zrest*v[6] + zfrac*v[7] ) );
}

// Given a unit normal nn, fill in unit tangents a and b so that (a, b, nn) is
// an orthonormal basis with a x b = nn.
static void
MakeTangentBasis(const Normal &nn, Vector &a, Vector &b)
{
   // Prefer a tangent in the xy-plane, perpendicular to nn's projection on it.
   a=Vector(nn.y, -nn.x, 0);
   float len=a.Length();
   if(len>0)
      a/=len;
   else
      // nn is along +/-z, or its x,y parts are so tiny that their squares
      // underflow to zero; testing the length avoids dividing by zero.
      a=Vector(1,0,0);
   // Derive b the same way in both cases so the frame never flips handedness.
   b=Cross(nn,a);
}

bool GridImplicit::
IntersectCell(const Ray &ray, const int pos[3], float t0, float t1,
              float *tHit, DifferentialGeometry *dg) const
{
   Assert(pos[0]>=0 && pos[0]<celldim[0]);
   Assert(pos[1]>=0 && pos[1]<celldim[1]);
   Assert(pos[2]>=0 && pos[2]<celldim[2]);
   const double v[8]={val(pos[0],pos[1],pos[2]), val(pos[0],pos[1],pos[2]+1),
      val(pos[0],pos[1]+1,pos[2]), val(pos[0],pos[1]+1,pos[2]+1),
      val(pos[0]+1,pos[1],pos[2]), val(pos[0]+1,pos[1],pos[2]+1),
      val(pos[0]+1,pos[1]+1,pos[2]), val(pos[0]+1,pos[1]+1,pos[2]+1)};
   Assert(v[0]==v[0]); Assert(v[1]==v[1]);
   Assert(v[2]==v[2]); Assert(v[3]==v[3]);
   Assert(v[4]==v[4]); Assert(v[5]==v[5]);
   Assert(v[6]==v[6]); Assert(v[7]==v[7]);
   double tol=1e-4*(fabs(v[0])+fabs(v[1])+fabs(v[2])+fabs(v[3])
                   +fabs(v[4])+fabs(v[5])+fabs(v[6])+fabs(v[7]));

   Point p0=ray(t0);
   double x0=p0.x-pos[0], y0=p0.y-pos[1], z0=p0.z-pos[2];
   double dt=t1-t0;
   double v0,v1,v2,v3;
   double A,B,C,D;
   double T0, T1, V0, V1;
   double TM, VM, alpha;
   double discrim;

   v0=eval(v,x0,y0,z0); Assert(v0==v0);
   if(fabs(v0)<tol){
      if(!tHit) return true;
      TM=0;
      goto process_intersection;
   }
   v1=eval(v,x0+0.3333333333333*dt*ray.d[0],
             y0+0.3333333333333*dt*ray.d[1],
             z0+0.3333333333333*dt*ray.d[2]); Assert(v1==v1);
   if(fabs(v1)<tol){
      if(!tHit) return true;
      TM=0.3333333333333333;
      goto process_intersection;
   }
   v2=eval(v,x0+0.6666666666667*dt*ray.d[0],
             y0+0.6666666666667*dt*ray.d[1],
             z0+0.6666666666667*dt*ray.d[2]); Assert(v2==v2);
   if(fabs(v2)<tol){
      if(!tHit) return true;
      TM=0.6666666666666667;
      goto process_intersection;
   }
   v3=eval(v,x0+dt*ray.d[0],
             y0+dt*ray.d[1],
             z0+dt*ray.d[2]); Assert(v3==v3);
   if(fabs(v3)<tol){
      if(!tHit) return true;
      TM=1;
      goto process_intersection;
   }

   A=-4.5*v0+13.5*v1-13.5*v2+4.5*v3,
   B=9*v0-22.5*v1+18*v2-4.5*v3,
   C=-5.5*v0+9*v1-4.5*v2+v3,
   D=v0;  // coefficients of cubic: A*T^3+B*T^2+C*T+D  (where T=(t-t0)/(t1-t0))

#define SIGNCHANGE(a,b) (((a)>=0 && (b)<=0) || ((a)<=0 && (b)>=0))
   // look for first subinterval of [0,1] with a sign change in the cubic
   discrim=4*B*B-12*A*C;
   if(discrim<=0){ // if the cubic is strictly monotonic
      if(SIGNCHANGE(v0,v3)){
         T0=0; T1=1;
         V0=v0; V1=v3;
      }else
         return false;
   }else{ // we need to divide up the cubic into chunks to test for roots
      double rootDiscrim=sqrt(discrim), q;
      if(B<0) q=-.5*(B-rootDiscrim);
      else q=-.5*(B+rootDiscrim);
      double E1=q/A, E2=C/q;
      if(E2<E1) swap(E1,E2);
      // look for the intervals we need to check
      if(E1<0){
         if(E2<0){ // check [0,1]
            if(SIGNCHANGE(v0,v3)){
               T0=0; T1=1;
               V0=v0; V1=v3;
            }else
               return false;
         }else{
            if(E2<1){ // check [0,E2] and [E2,1]
               double E2v=D+E2*(C+E2*(B+E2*A));
               if(fabs(E2v)<tol){
                  if(!tHit) return true;
                  TM=E2;
                  goto process_intersection;
               }
               if(SIGNCHANGE(v0,E2v)){
                  T0=0; T1=E2;
                  V0=v0; V1=E2v;
               }else if(SIGNCHANGE(E2v,v3)){
                  T0=E2; T1=1;
                  V0=E2v; V1=v3;
               }else
                  return false;
            }else{ // check [0,1]
               if(SIGNCHANGE(v0,v3)){
                  T0=0; T1=1;
                  V0=v0; V1=v3;
               }else
                  return false;
            }
         }
      }else{
         if(E2<1){ // check [0,E1] [E1,E2] and [E2,1]
            double E1v=D+E1*(C+E1*(B+E1*A));
            if(fabs(E1v)<tol){
               if(!tHit) return true;
               TM=E1;
               goto process_intersection;
            }
            if(SIGNCHANGE(v0,E1v)){
               T0=0; T1=E1;
               V0=v0; V1=E1v;
            }else{
               double E2v=D+E2*(C+E2*(B+E2*A));
               if(fabs(E2v)<tol){
                  if(!tHit) return true;
                  TM=E2;
                  goto process_intersection;
               }
               if(SIGNCHANGE(E1v,E2v)){
                  T0=E1; T1=E2;
                  V0=E1v; V1=E2v;
               }else if(SIGNCHANGE(E2v,v3)){
                  T0=E2; T1=1;
                  V0=E2v; V1=v3;
               }else
                  return false;
            }
         }else{
            if(E1<1){ // check [0,E1] and [E1,1]
               double E1v=D+E1*(C+E1*(B+E1*A));
               if(fabs(E1v)<tol){
                  if(!tHit) return true;
                  TM=E1;
                  goto process_intersection;
               }
               if(SIGNCHANGE(v0,E1v)){
                  T0=0; T1=E1;
                  V0=v0; V1=E1v;
               }else if(SIGNCHANGE(E1v,v3)){
                  T0=E1; T1=1;
                  V0=E1v; V1=v3;
               }else
                  return false;
            }else{ // check [0,1]
               if(SIGNCHANGE(v0,v3)){
                  T0=0; T1=1;
                  V0=v0; V1=v3;
               }else
                  return false;
            }
         }
      }
   }

   // early exit if we don't care exactly where the intersection is
   if(!tHit) return true;

   // now do a few iterations of secant search
   if(V0<0){
      int i;
      for(i=0; i<5; ++i){
         alpha=V1/(V1-V0);
         TM=alpha*T0+(1-alpha)*T1;
         VM=D+TM*(C+TM*(B+TM*A));
         if(fabs(VM)<tol)
            break;
         else if(VM>0){
            T1=TM; V1=VM;
         }else{
            T0=TM; V0=VM;
         }
      }
   }else{
      int i;
      for(i=0; i<5; ++i){
         alpha=V1/(V1-V0);
         TM=alpha*T0+(1-alpha)*T1;
         VM=D+TM*(C+TM*(B+TM*A));
         if(fabs(VM)<tol)
            break;
         else if(VM<0){
            T1=TM; V1=VM;
         }else{
            T0=TM; V0=VM;
         }
      }
   }

   process_intersection:
   dt=dt*TM;
   *tHit=t0+dt;
   Assert(dg!=0);
   dg->p=ObjectToWorld(ray(*tHit));
   float xm=x0+dt*ray.d[0], ym=y0+dt*ray.d[1], zm=z0+dt*ray.d[2];
   float xr=1.f-xm, yr=1.f-ym, zr=1.f-zm;
   dg->nn[0]=yr*(zr*(v[4]-v[0])+zm*(v[5]-v[1]))
            +ym*(zr*(v[6]-v[2])+zm*(v[7]-v[3]));
   dg->nn[1]=xr*(zr*(v[2]-v[0])+zm*(v[3]-v[1]))
            +xm*(zr*(v[6]-v[4])+zm*(v[7]-v[5]));
   dg->nn[2]=xr*(yr*(v[1]-v[0])+ym*(v[3]-v[2]))
            +xm*(yr*(v[5]-v[4])+ym*(v[7]-v[6]));
   //dg->nn=Normal(ray(*tHit)-Point(0.5*dim[0],0.5*dim[1],0.5*dim[2]));
   dg->nn/=dg->nn.Length();
   dg->nn=ObjectToWorld(dg->nn);
   if(reverseOrientation)
      dg->nn*=-1.f;
   dg->u=0;
   dg->v=0;
   dg->shape=this;
   MakeTangentBasis(dg->nn, dg->dpdu, dg->dpdv);
   dg->dndu=Vector(0,0,0);
   dg->dndv=Vector(0,0,0);
   return true;
}

void GridImplicit::
GetShadingGeometry(const Transform &obj2world,
                   const DifferentialGeometry &dg,
                   DifferentialGeometry *dgShading) const
{
   *dgShading=dg;

   Point p=WorldToObject(dg.p);
   int i=Float2Int(p.x), j=Float2Int(p.y), k=Float2Int(p.z);
   // we rely on the user providing a band of empty cells around surface
   if(i<1) i=1; else if(i>dim[0]-3) i=dim[0]-3;
   if(j<1) j=1; else if(j>dim[1]-3) j=dim[1]-3;
   if(k<1) k=1; else if(k>dim[2]-3) k=dim[2]-3;
   // figure out coefficients for trilinear interpolation
   float xm=p.x-i, xr=1.f-xm;
   float ym=p.y-j, yr=1.f-ym;
   float zm=p.z-k, zr=1.f-zm;
   float c0=xr*yr*zr, c1=xr*yr*zm, c2=xr*ym*zr, c3=xr*ym*zm,
         c4=xm*yr*zr, c5=xm*yr*zm, c6=xm*ym*zr, c7=xm*ym*zm;
   dgShading->nn[0]=c0*(val(i+1,j,  k  ) - val(i-1,j,  k  ))
                   +c1*(val(i+1,j,  k+1) - val(i-1,j,  k+1))
                   +c2*(val(i+1,j+1,k  ) - val(i-1,j+1,k  ))
                   +c3*(val(i+1,j+1,k+1) - val(i-1,j+1,k+1))
                   +c4*(val(i+2,j,  k  ) - val(i,  j,  k  ))
                   +c5*(val(i+2,j,  k+1) - val(i,  j,  k+1))
                   +c6*(val(i+2,j+1,k  ) - val(i,  j+1,k  ))
                   +c7*(val(i+2,j+1,k+1) - val(i,  j+1,k+1));
   dgShading->nn[1]=c0*(val(i,  j+1,k  ) - val(i,  j-1,k  ))
                   +c1*(val(i,  j+1,k+1) - val(i,  j-1,k+1))
                   +c2*(val(i,  j+2,k  ) - val(i,  j,  k  ))
                   +c3*(val(i,  j+2,k+1) - val(i,  j,  k+1))
                   +c4*(val(i+1,j+1,k  ) - val(i+1,j-1,k  ))
                   +c5*(val(i+1,j+1,k+1) - val(i+1,j-1,k+1))
                   +c6*(val(i+1,j+2,k  ) - val(i+1,j,  k  ))
                   +c7*(val(i+1,j+2,k+1) - val(i+1,j,  k+1));
   dgShading->nn[2]=c0*(val(i,  j,  k+1) - val(i,  j,  k-1))
                   +c1*(val(i,  j,  k+2) - val(i,  j,  k  ))
                   +c2*(val(i,  j+1,k+1) - val(i,  j+1,k-1))
                   +c3*(val(i,  j+1,k+2) - val(i,  j+1,k  ))
                   +c4*(val(i+1,j,  k+1) - val(i+1,j,  k-1))
                   +c5*(val(i+1,j,  k+2) - val(i+1,j,  k  ))
                   +c6*(val(i+1,j+1,k+1) - val(i+1,j+1,k-1))
                   +c7*(val(i+1,j+1,k+2) - val(i+1,j+1,k  ));
   dgShading->nn/=dgShading->nn.Length();
   dgShading->nn=ObjectToWorld(dgShading->nn);
   if(reverseOrientation)
      dgShading->nn*=-1.f;
   MakeTangentBasis(dgShading->nn, dgShading->dpdu, dgShading->dpdv);
}

//============================== CreateShape =================================
extern "C" DLLEXPORT
Shape *CreateShape(const Transform &o2w,
                   bool reverseOrientation,
                   const ParamSet &params)
{
   // Plugin entry point: the sample data lives in the external file named by
   // the "filename" parameter (empty string if the parameter is absent).
   const string dataFile=params.FindOneString("filename","");
   Shape *shape=new GridImplicit(o2w, reverseOrientation, dataFile.c_str());
   return shape;
}

