/* Some basic matrix operations on 1-indexed (Numerical Recipes style) matrices (double **) and vectors (double *). */
#include <math.h>

#include "nr.h"

void matsum(double **a,int anr,int anc,double **b,int bnr,int bnc);
void matdif(double **a,int anr,int anc,double **b,int bnr,int bnc);
void matprod(double **a,int anr,int anc,double **b,int bnr,int bnc,
   double **res);
void trans(double **a,int nr,int nc,double **b);
void outerprod(double *vec1,double *vec2, int n,double **result);
double ldet(double **a,int n);
void invert(double **a,double **res, int n);


void matsum(double **a,int anr,int anc,double **b,int bnr,int bnc)
{
/* In-place matrix addition: a <- a + b.
   Both matrices are 1-indexed (rows 1..anr, columns 1..anc); b is only read.
   Calls nrerror if the row or column counts of a and b differ. */
   if (anr != bnr) nrerror("Can't add:A and B have different no. of rows.");
   if (anc != bnc) nrerror("Can't add:A and B have different no. of columns.");
   for (int row = 1; row <= anr; row++) {
      double *arow = a[row];
      const double *brow = b[row];
      for (int col = 1; col <= anc; col++) arow[col] += brow[col];
   }
}

void matdif(double **a,int anr,int anc,double **b,int bnr,int bnc)
{
/* In-place matrix subtraction: a <- a - b.
   Both matrices are 1-indexed (rows 1..anr, columns 1..anc); b is only read.
   Calls nrerror if the row or column counts of a and b differ. */
   if (anr != bnr) nrerror("Can't subtract:A and B have different no. of rows.");
   if (anc != bnc) nrerror("Can't subtract:A and B have different no. of columns.");
   for (int row = 1; row <= anr; row++) {
      double *arow = a[row];
      const double *brow = b[row];
      for (int col = 1; col <= anc; col++) arow[col] -= brow[col];
   }
}
 
void matprod(double **a,int anr,int anc,double **b,int bnr,int bnc,
   double **res)
{
/* Matrix product res = a * b for 1-indexed matrices:
   a is anr x anc, b is bnr x bnc, and res receives the anr x bnc result.
   Calls nrerror if the inner dimensions (anc and bnr) differ.
   res must not share storage with a or b: entries of res are written
   while a and b are still being read. */
   if (anc != bnr) nrerror("A and B can't be multiplied!");
   for (int row = 1; row <= anr; row++) {
      const double *arow = a[row];
      double *out = res[row];
      for (int col = 1; col <= bnc; col++) {
         double acc = 0.0;
         for (int k = 1; k <= anc; k++) acc += arow[k] * b[k][col];
         out[col] = acc;
      }
   }
}
       
void trans(double **a,int nr,int nc,double **b)
{
/* Transpose: b[j][i] = a[i][j] for i = 1..nr, j = 1..nc (1-indexed).
   a is nr x nc and b receives the nc x nr transpose.
   Not usable in place: b must not share storage with a. */
   for (int row = 1; row <= nr; row++) {
      const double *src = a[row];
      for (int col = 1; col <= nc; col++) b[col][row] = src[col];
   }
}
      
void outerprod(double *vec1,double *vec2, int n,double **result)
{
/* Outer product of two 1-indexed vectors of length n:
   result[i][j] = vec1[i] * vec2[j] for i, j = 1..n,
   i.e. result = vec1 * vec2^T, an n x n matrix. */
   for (int row = 1; row <= n; row++) {
      double *dst = result[row];
      for (int col = 1; col <= n; col++) dst[col] = vec1[row] * vec2[col];
   }
}

void choldc(double **a, int n, double p[])
{
/* Cholesky decomposition of a symmetric positive-definite n x n matrix
   (1-indexed, rows/columns 1..n): finds lower-triangular L with
   L * L^T equal to the original matrix.

   Input:  only the upper triangle of a (a[i][j] with j >= i, diagonal
           included) is read.
   Output: the strict lower triangle of a (a[j][i] with j > i) receives the
           sub-diagonal entries of L, and p[1..n] receives its diagonal
           L[i][i].  The diagonal and upper triangle of a are left untouched.

   Calls nrerror("choldc failed") when a pivot is not strictly positive
   (matrix not positive definite, possibly because of rounding) or is NaN. */
   for (int i = 1; i <= n; i++) {
      for (int j = i; j <= n; j++) {
         double sum = a[i][j];
         /* subtract contributions of the already-computed columns of L */
         for (int k = i - 1; k >= 1; k--) sum -= a[i][k] * a[j][k];
         if (i == j) {
            /* !(sum > 0.0) also rejects NaN, which "sum <= 0.0" let through */
            if (!(sum > 0.0)) nrerror("choldc failed");
            p[i] = sqrt(sum);
         } else {
            a[j][i] = sum / p[i];
         }
      }
   }
}

double ldet(double **a,int n)
{
/* Returns the natural log of the determinant of a symmetric positive-definite
   n x n matrix a (1-indexed).  a itself is not modified: it is copied into a
   scratch matrix because choldc overwrites the lower triangle of its input.

   With a = L * L^T from the Cholesky factorisation, det(a) = prod_j(L[j][j])^2,
   so log det(a) = 2 * sum_j log(L[j][j]).  The previous code summed L[j][j]
   itself instead of its log, which is neither the determinant nor its log.
   Summing logs also avoids the overflow/underflow a direct product would hit
   for large or badly scaled matrices.
   choldc calls nrerror if a is not positive definite. */
   double **b = dmatrix(1,n,1,n);
   double *p = dvector(1,n);
   double logsum = 0.0;

   for (int i = 1; i <= n; i++) {
      for (int j = 1; j <= n; j++) b[i][j] = a[i][j];
   }
   choldc(b,n,p);
   /* p[j] now holds the diagonal entry L[j][j] (> 0) of the Cholesky factor */
   for (int j = 1; j <= n; j++) logsum += log(p[j]);
   free_dvector(p,1,n);
   free_dmatrix(b,1,n,1,n);
   return 2.0*logsum;
}
 

void invert(double **a,double **res, int n)
/* Inverts a symmetric positive-definite (e.g. covariance) n x n matrix using
   its Cholesky factorisation a = L * L^T and the identity
   a^-1 = (L^-1)^T * L^-1.  Matrices are 1-indexed (1..n); the full inverse
   (both triangles) is written to res.

   Only the upper triangle of a (diagonal included) is read as input.
   Be careful! a gets overwritten, so don't try to use it again: on return its
   diagonal and strict lower triangle hold L^-1 (the strict upper triangle is
   left as it was).  choldc calls nrerror if a is not positive definite. */
{
   double *p;
   p=dvector(1,n);   
   /* NOTE(review): choldc assigns every p[i] before reading it (or reports
      failure through nrerror), so this initialisation appears redundant. */
   for(int i=1;i<=n;i++){
      p[i]=1.0;
   }   
   choldc(a,n,p);
   /* Now the strict lower triangle of a holds L below the diagonal and
      p[i] = L[i][i].  This loop calculates L^-1 in place, one column i at a
      time:
        (L^-1)[i][i] = 1 / L[i][i]
        (L^-1)[j][i] = -(sum_{k=i}^{j-1} L[j][k] * (L^-1)[k][i]) / L[j][j],  j > i
      It relies on statement order: while column i is computed, a[j][k] for
      k >= i still holds L (later columns are not yet overwritten, and a[j][i]
      is replaced only after its sum), whereas a[k][i] for i <= k < j already
      holds L^-1. */
   for(int i=1;i<=n;i++){
      a[i][i]=1.0/p[i];
      for(int j=i+1;j<=n;j++){
         double sum1=0.0;
         for(int k=i;k<j;k++) sum1 -= a[j][k]*a[k][i];
         a[j][i]=sum1/p[j];
      }
   }
   /* now find (L^-1)^T * L^-1 = a^-1, upper triangle only (j >= i).
      L^-1 is lower triangular, so (L^-1)[k][j] = 0 for k < j and the sum
      over k can start at j. */
   for(int i=1;i<=n;i++){
      for(int j=i;j<=n;j++){
         double sum2=0.0;
         for(int k=j;k<=n;k++) sum2 += a[k][i]*a[k][j];
         res[i][j]=sum2;
      }
   }
   /* the inverse is symmetric: mirror the upper triangle into the lower one */
   for(int i=1;i<=n;i++){
      for(int j=1;j<=i;j++){
         res[i][j]=res[j][i];
      }
   }
   free_dvector(p,1,n);
}