

	/************************************************************************
	 * Code to solve a simple linear system of N equations for N unknowns,
	 * stolen from the "match" program.
	 *
	 * MWR 1/20/2003
	 */


#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <ctype.h>
#include <math.h>


/*
 * shAssert(x): if the condition 'x' is false (zero), print the file name
 * and line number of the failing assertion to stderr and exit with status 1.
 * Any non-zero value of 'x' counts as success.  The do/while(0) wrapper makes
 * the macro behave as a single statement, so it is safe in if/else bodies.
 */
#define shAssert(x) \
   do { \
      if (!(x)) { \
         fprintf(stderr,"assertion fails in file %s, line %d\n",__FILE__,__LINE__); \
         exit(1); \
      } \
   } while (0)


/*
 * MATRIX_TOL: gauss_matrix gives up (treats the system as singular) when
 * the magnitude of a pivot element divided by the largest value in its
 * row falls below this threshold.
 */
#define MATRIX_TOL         1.0e-12
/* return codes used by gauss_matrix and gauss_pivot */
#define SH_SUCCESS         0
#define SH_GENERIC_ERROR   1


/*
 * Debugging switches, forced off here even if defined earlier.
 *   DEBUG   enables print_matrix and the trace output in the solver
 *   DEBUG3  enables test_routine
 */
#undef DEBUG
#undef DEBUG3

/*
 * Forward declarations of the file-local helpers.  All of them are
 * declared 'static' here, so the definitions later in the file have
 * internal linkage even where they omit the keyword.
 */

/* row-pivoting helper used by gauss_matrix */
static int gauss_pivot(double **matrix, int num, double *vector,
                       double *biggest_val, int row);

/* debugging aids, compiled only when the matching switch is defined */
#ifdef DEBUG
static void print_matrix(double **matrix, int n);
#endif
#ifdef DEBUG3
static void test_routine (void);
#endif
/* memory and error-reporting wrappers */
static void *shMalloc(int nbytes);
static void shFree(void *vptr);
static void shError(char *format, ...);
static void shFatal(char *format, ...);


/***************************************************************************
 * PROCEDURE: gauss_matrix
 *
 * DESCRIPTION: 
 * Given a square 2-D 'num'-by-'num' matrix, called "matrix", and given
 * a 1-D vector "vector" of 'num' elements, find the 1-D vector
 * called "solution_vector" which satisfies the equation
 *
 *      matrix * solution_vector  =  vector
 *
 * where the * above represents matrix multiplication.
 *
 * What we do is to use Gaussian elimination (with partial pivoting,
 * each candidate pivot being scaled by the largest value in its row)
 * and back-substitution to find the solution_vector.  
 * We do not pivot in place, but physically move values -- it
 * doesn't take much time in this application.  After we have found the 
 * "solution_vector", we replace the contents of "vector" with the
 * "solution_vector".
 *
 * The contents of "matrix" are destroyed: rows may have been swapped,
 * and the elements on and above the diagonal hold the reduced (upper
 * triangular) system.  Elements below the diagonal are simply left
 * alone, not set to zero, and the matrix is NOT inverted.
 * 
 * This is a common algorithm.  See any book on linear algebra or
 * numerical solutions; for example, "Numerical Methods for Engineers,"
 * by Steven C. Chapra and Raymond P. Canale, McGraw-Hill, 1998,
 * Chapter 9.
 *
 * If an error occurs (bad arguments, or the matrix is singular), this
 * prints an error message and returns with error code.  In that case
 * "matrix" and "vector" may already have been partially modified.
 *
 * RETURN:
 *    SH_SUCCESS          if all goes well
 *    SH_GENERIC_ERROR    if not -- if matrix is singular, if 'num' is
 *                        not positive, or if a pointer argument is NULL
 */


int
gauss_matrix
   (
   double **matrix,       /* I/O: the square 2-D matrix of coefficients; */
                          /*      overwritten by the elimination steps */
   int num,               /* I: number of rows and cols in matrix */
   double *vector         /* I/O: vector which holds "b" values in input */
                          /*      and the solution vector "x" on output */
   )
{
  int i, j, k;
  int status = SH_GENERIC_ERROR;    /* pessimistic until we finish */
  double *biggest_val = NULL;
  double *solution_vector = NULL;
  double factor;
  double sum;

  /* 
   * refuse degenerate input: everything below indexes matrix[num-1],
   * so a non-positive 'num' would read outside the arrays
   */
  if (matrix == NULL || vector == NULL || num <= 0) {
    shError("gauss_matrix: invalid arguments (num = %d)", num);
    return(SH_GENERIC_ERROR);
  }

#ifdef DEBUG
  print_matrix(matrix, num);
#endif

  biggest_val = (double *) shMalloc(num*sizeof(double));
  solution_vector = (double *) shMalloc(num*sizeof(double));

  /* 
   * step 1: we find the largest value in each row of matrix,
   *         and store those values in 'biggest_val' array.
   *         We use this information to pivot the matrix.
   *         A row of all zeroes means the matrix is singular.
   */
  for (i = 0; i < num; i++) {
    biggest_val[i] = fabs(matrix[i][0]);
    for (j = 1; j < num; j++) {
      if (fabs(matrix[i][j]) > biggest_val[i]) {
        biggest_val[i] = fabs(matrix[i][j]);
      }
    }
    if (biggest_val[i] == 0.0) {
      shError("gauss_matrix: biggest val in row %d is zero", i);
      goto cleanup;
    }
  }

  /* 
   * step 2: we use Gaussian elimination to convert the "matrix"
   *         into a triangular system.  Only the elements on and above
   *         the diagonal are updated; those below it are never read
   *         again, so we don't bother to zero them.
   */
  for (i = 0; i < num - 1; i++) {

    /* pivot this row (if necessary) */
    if (gauss_pivot(matrix, num, vector, biggest_val, i) == SH_GENERIC_ERROR) {
      shError("gauss_matrix: singular matrix");
      goto cleanup;
    }
      
    /* even the best available pivot is too small relative to its row */
    if (fabs(matrix[i][i]/biggest_val[i]) < MATRIX_TOL) {
      shError("gauss_matrix: Y: row %d has tiny value %f / %f", 
                  i, matrix[i][i], biggest_val[i]);
      goto cleanup;
    }

    /* we eliminate this variable in all rows below the current one */
    for (j = i + 1; j < num; j++) {
      factor = matrix[j][i]/matrix[i][i];
      for (k = i + 1; k < num; k++) {
        matrix[j][k] -= factor*matrix[i][k];
      }
      /* and in the vector, too */
      vector[j] -= factor*vector[i];
    }

  }

  /* 
   * make sure that the last row's single remaining element
   * isn't too tiny (the row index reported is zero-based, as above)
   */
  if (fabs(matrix[num-1][num-1]/biggest_val[num-1]) < MATRIX_TOL) {
    shError("gauss_matrix: Z: row %d has tiny value %f / %f", 
                num - 1, matrix[num-1][num-1], biggest_val[num-1]);
    goto cleanup;
  }

  /* 
   * step 3: we can now calculate the solution_vector values
   *         via back-substitution; we start at the last value in the
   *         vector (at the "bottom" of the vector) and work 
   *         upwards towards the top.
   */
  solution_vector[num-1] = vector[num-1] / matrix[num-1][num-1];
  for (i = num - 2; i >= 0; i--) {
    sum = 0.0;
    for (j = i + 1; j < num; j++) {
      sum += matrix[i][j]*solution_vector[j];
    }
    solution_vector[i] = (vector[i] - sum) / matrix[i][i];
  }

  /*
   * step 4: okay, we've found the values in the solution vector!
   *         We now replace the input values in 'vector' with these
   *         solution_vector values, and we're done.
   */
  for (i = 0; i < num; i++) {
    vector[i] = solution_vector[i];
  }

  status = SH_SUCCESS;

  /* clean up -- shared by the success path and every error path */
cleanup:
  shFree(solution_vector);
  shFree(biggest_val);

  return(status);
}


/***************************************************************************
 * PROCEDURE: gauss_pivot
 *
 * DESCRIPTION: 
 * This routine is called by "gauss_matrix".  Given a square "matrix"
 * of "num"-by-"num" elements, and given a "vector" of "num" elements,
 * and given a particular "row" value, this routine finds the largest
 * value in the matrix at/below the given "row" position.  If that
 * largest value isn't in the given "row", this routine switches
 * rows in the matrix (and in the vector) so that the largest value
 * will now be in "row".
 *
 * RETURN:
 *    SH_SUCCESS          if all goes well
 *    SH_GENERIC_ERROR    if not -- if matrix is singular
 *
 * </AUTO>
 */

#define SWAP(a,b)  { double temp = (a); (a) = (b); (b) = temp; }

static int
gauss_pivot
   (
   double **matrix,       /* I/O: a square 2-D matrix we are inverting */
   int num,               /* I: number of rows and cols in matrix */
   double *vector,        /* I/O: vector which holds "b" values in input */
   double *biggest_val,   /* I: largest value in each row of matrix */
   int row                /* I: want to pivot around this row */
   )
{
  int i;
  int col, pivot_row;
  double big, other_big;

  /* sanity checks */
  shAssert(matrix != NULL);
  shAssert(vector != NULL);
  shAssert(row < num);


  pivot_row = row;
  big = fabs(matrix[row][row]/biggest_val[row]); 
#ifdef DEBUG
  print_matrix(matrix, num);
  printf(" biggest_val:  ");
  for (i = 0; i < num; i++) {
    printf("%9.4e ", biggest_val[i]);
  }
  printf("\n");
  printf("  gauss_pivot: row %3d  %9.4e %9.4e %12.5e ", 
                  row, matrix[row][row], biggest_val[row], big);
#endif

  for (i = row + 1; i < num; i++) {
    other_big = fabs(matrix[i][row]/biggest_val[i]);
    if (other_big > big) {
      big = other_big;
      pivot_row = i;
    }
  }
#ifdef DEBUG
  printf("  pivot_row %3d  %9.4e %9.4e %12.5e ", 
                  pivot_row, matrix[pivot_row][pivot_row], 
                  biggest_val[pivot_row], big);
#endif

  /* 
   * if another row is better for pivoting, switch it with 'row' 
   *    and switch the corresponding elements in 'vector'
   *    and switch the corresponding elements in 'biggest_val'
   */
  if (pivot_row != row) {
#ifdef DEBUG
    printf("   will swap \n");
#endif
    for (col = row; col < num; col++) {
      SWAP(matrix[pivot_row][col], matrix[row][col]);
    }
    SWAP(vector[pivot_row], vector[row]);
    SWAP(biggest_val[pivot_row], biggest_val[row]);
  }
  else {
#ifdef DEBUG
    printf("    no swap \n");
#endif
  }
  
  return(SH_SUCCESS);
}




/************************************************************************
 * 
 *
 * ROUTINE: print_matrix 
 *
 * DESCRIPTION:
 * print out a nice picture of the given matrix.  
 *
 * For debugging purposes.
 *
 * RETURNS:
 *   nothing
 *
 * </AUTO>
 */

#ifdef DEBUG

static void
print_matrix
   (
   double **matrix,   /* I: rows of the square array to display */
   int n              /* I: the array has n rows of n elements each */
   )
{
   int r;

   /* one line of output per row, each element in %12.5e format */
   for (r = 0; r < n; r++) {
      int c = 0;

      while (c < n) {
         printf(" %12.5e", matrix[r][c]);
         c++;
      }
      printf("\n");
   }
}

#endif /* DEBUG */




/*
 * check to see if my versions of NR routines have bugs.
 * Try to invert a matrix.
 * 
 * debugging only.
 */

#ifdef DEBUG3

static void
test_routine (void)
{
	int i, j, k, n;
	int *permutations;
	double **matrix1, **matrix2, **inverse;
	double *vector;
	double *col;
	double sum;

    fflush(stdout);
    fflush(stderr);
	n = 2;
	matrix1 = (double **) shMalloc(n*sizeof(double *));
	matrix2 = (double **) shMalloc(n*sizeof(double *));
	inverse = (double **) shMalloc(n*sizeof(double *));
	vector = (double *) shMalloc(n*sizeof(double));
   	for (i = 0; i < n; i++) {
		matrix1[i] = (double *) shMalloc(n*sizeof(double));
		matrix2[i] = (double *) shMalloc(n*sizeof(double));
		inverse[i] = (double *) shMalloc(n*sizeof(double));
	}
	permutations = (int *) shMalloc(n*sizeof(int));
	col = (double *) shMalloc(n*sizeof(double));


	/* fill the matrix */
	matrix1[0][0] = 1.0;
	matrix1[0][1] = 2.0;
	matrix1[1][0] = 3.0;
	matrix1[1][1] = 4.0;

	/* fill the vector */
	for (i = 0; i < n; i++) {
		vector[i] = 0;
	}

	/* copy matrix1 into matrix2, so we can compare them later */
	for (i = 0; i < n; i++) {
		for (j = 0; j < n; j++) {
			matrix2[i][j] = matrix1[i][j];
		}
	}

	/* now check */
	printf(" here comes original matrix \n");
	print_matrix(matrix1, n);


	/* now invert matrix1 */
	for (i = 0; i < n; i++) {
		for (j = 0; j < n; j++) {
			inverse[i][j] = matrix1[i][j];
		}
	}
	gauss_matrix(inverse, n, vector);

	/* now check */
	printf(" here comes inverse matrix \n");
	print_matrix(inverse, n);

	/* find out if the product of "inverse" and "matrix2" is identity */
	sum = 0.0;
	for (i = 0; i < n; i++) {
		for (j = 0; j < n; j++) {
			for (k = 0; k < n; k++) {
				sum += inverse[i][k]*matrix2[k][j];
			}
			matrix1[i][j] = sum;
			sum = 0.0;
		}
	}

	printf(" here comes what we hope is identity matrix \n");
	print_matrix(matrix1, n);

    fflush(stdout);
    fflush(stderr);
}

#endif /* DEBUG3 */


   /*********************************************************************
    * ROUTINE: shMalloc
    *
    * Attempt to allocate the given number of bytes.  Return the 
    * memory, if we succeeded, or print an error message and
    * exit with error code if we failed.
    *
    * A request for zero bytes is satisfied with a minimal, non-NULL
    * chunk which may be passed to shFree.  A negative request is
    * treated as a failure.
    * 
    * RETURNS:
    *      void *             to new memory, if we got it (never NULL)
    */

void *
shMalloc
   (
   int nbytes                /* I: allocate a chunk of this many bytes */
   )
{
   void *vptr;

   /* a negative count would turn into an enormous size inside malloc */
   if (nbytes < 0) {
      shError("shMalloc: failed to allocate for %d bytes", nbytes);
      exit(1);
   }

   /*
    * malloc(0) is allowed to return NULL even when memory is available;
    * ask for one byte instead, so that an empty request is not mistaken
    * for an out-of-memory condition
    */
   vptr = malloc((nbytes > 0) ? (size_t) nbytes : (size_t) 1);
   if (vptr == NULL) {
      shError("shMalloc: failed to allocate for %d bytes", nbytes);
      exit(1);
   }
   return(vptr);
}


   /*********************************************************************
    * ROUTINE: shFree
    *
    * Release the given chunk of memory by handing it straight to the
    * C library's free().
    * 
    * RETURNS:
    *      nothing
    */

void shFree(void *vptr /* I: chunk of memory to release */)
{
   free(vptr);
}



   /*********************************************************************
    * ROUTINE: shError
    *
    * Write a printf-style message, followed by a newline, to stderr,
    * then flush both stdout and stderr.  Execution continues normally
    * afterwards.
    * 
    * RETURNS:
    *      nothing
    */

void shError(char *format, /* I: printf-style format string */
             ...)          /* I: optional values for the format */
{
   va_list args;

   /* the formatted text goes out first ... */
   va_start(args, format);
   (void) vfprintf(stderr, (const char *) format, args);
   va_end(args);

   /* ... then the line terminator, and both streams are flushed */
   fputc('\n', stderr);
   fflush(stdout);
   fflush(stderr);
}


   /*********************************************************************
    * ROUTINE: shFatal
    *
    * Write a printf-style message, followed by a newline, to stderr,
    * flush both stdout and stderr, and then terminate the program
    * with exit status 1.
    * 
    * RETURNS:
    *      nothing (does not return to the caller)
    */

void shFatal(char *format, /* I: printf-style format string */
             ...)          /* I: optional values for the format */
{
   va_list args;

   /* the formatted text goes out first ... */
   va_start(args, format);
   (void) vfprintf(stderr, (const char *) format, args);
   va_end(args);

   /* ... then the line terminator, and both streams are flushed */
   fputc('\n', stderr);
   fflush(stdout);
   fflush(stderr);

   /* fatal: stop the program here */
   exit(1);
}


