////////////////////////////////////////
// Jeremy R. Faller
// (c) Digital Sweetener, 2007
//
// This software is distributed under the terms of the MIT license:
//  http://www.opensource.org/licenses/mit-license.php
//
// The original author requests that users send a short e-mail to
//  software-mit@digitalsweetener.com
// describing where and how the software is being used.  Compliance
// with this condition is optional.


////////////////////////////////////////
// Includes
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "mtrx.h"

// Swap two MTRX_Element_t lvalues.
//  The temporary has a deliberately unusual name so it cannot shadow an
//  argument: with a plain `temp`, SWAP_(temp, x) would swap the macro's own
//  uninitialised local instead of the caller's variable.
//  Each argument is evaluated twice, so do not pass expressions with side
//  effects.
#define SWAP_(a_, b_)                       \
do {                                        \
    MTRX_Element_t mtrx_swap_tmp_ = (a_);   \
    (a_) = (b_);                            \
    (b_) = mtrx_swap_tmp_;                  \
} while (0)


////////////////////////////////////////////////////////////////////////////////
#pragma mark -
#pragma mark Creators/Destructors


////////////////////////////////////////
// Create an empty matrix
//  Zeroes every field of *inMat: 0x0 dimensions, no recorded storage,
//  not on the heap.  Nothing is allocated.
void MTRX_CreateEmpty(MTRX_t *inMat)
{
    // sizeof *inMat (the struct), not sizeof(inMat) (a pointer), so the
    // whole structure is cleared rather than just its first few bytes.
    memset(inMat, 0, sizeof *inMat);
}



////////////////////////////////////////
// Create a Matrix
//  Initialises *inMat as an inRows x inCols matrix (row-major storage).
//  You can pass in a pointer (inVals) to a memory location where the
//  matrix will be stored.  It must hold at least inRows * inCols elements,
//  is used as is (not cleared), and is never freed by MTRX_Destroy.
//  If inVals is null, zero-filled storage is allocated on the heap
//  instead, unless MTRX_NO_MALLOC_ is defined, in which case this fails.
//  Returns false if inRows * inCols does not fit in an unsigned, or if
//  heap storage is needed but cannot be obtained.  *inMat is then left
//  as an empty 0x0 matrix with no storage.
bool MTRX_Create(MTRX_t *inMat, unsigned inRows, unsigned inCols, MTRX_Element_t *inVals)
{
    // Every element index in this module is computed in unsigned
    // arithmetic, so refuse shapes whose element count would wrap.
    if (inCols != 0 && inRows > UINT_MAX / inCols)
        goto fail;

    inMat->rows = inRows;
    inMat->cols = inCols;
    inMat->origSize = inRows * inCols;

    if (inVals != NULL)
    {
        inMat->onHeap = false;
        inMat->vals = inVals;
        return true;
    }

#if defined(MTRX_NO_MALLOC_)
    goto fail;
#else
    // calloc zero-fills the storage and also checks the
    // element-count * element-size product for overflow.
    inMat->vals = calloc(inMat->origSize, sizeof(MTRX_Element_t));
    if (inMat->vals == NULL)
        goto fail;
    inMat->onHeap = true;
    return true;
#endif

fail:
    // Never leave a size recorded without storage behind it: MTRX_SetSize
    // trusts origSize and would otherwise write through a null pointer.
    inMat->rows = 0;
    inMat->cols = 0;
    inMat->origSize = 0;
    inMat->onHeap = false;
    inMat->vals = NULL;
    return false;
}


////////////////////////////////////////
// Create an identity matrix
//  Initialises *inMat as an inSize x inSize identity matrix.  It uses the
//  caller's storage inVals if non-null (at least inSize * inSize elements),
//  or heap storage otherwise (see MTRX_Create).
//  Every element is written, so caller-supplied storage need not be
//  pre-zeroed.  Returns false if MTRX_Create fails.
bool MTRX_CreateIdentity(MTRX_t *inMat, unsigned inSize, MTRX_Element_t *inVals)
{
    unsigned i;

    if (MTRX_Create(inMat, inSize, inSize, inVals) == false)
        return false;

    // MTRX_Create only zero-fills storage it allocates itself.  Clear
    // caller-provided storage too, so the off-diagonal elements are 0.
    memset(inMat->vals, 0, sizeof(MTRX_Element_t) * inSize * inSize);
    for (i = 0; i < inSize; ++i)
        inMat->vals[i * inSize + i] = 1;
    return true;
}


////////////////////////////////////////
// Destroy a matrix
//  Frees the element storage if it was allocated on the heap by this
//  module (caller-supplied storage is left alone).  Then zeroes every
//  field of *inMat so it describes an empty 0x0 matrix with no storage.
void MTRX_Destroy(MTRX_t *inMat)
{
#if !defined(MTRX_NO_MALLOC_)
    if (inMat->onHeap)
        free(inMat->vals);
#endif
    // sizeof *inMat (the struct), not sizeof(inMat) (a pointer): clear the
    // whole structure so no stale pointer, onHeap flag or size survives
    // (a stale onHeap/vals pair would cause a double free on a second call).
    memset(inMat, 0, sizeof *inMat);
}



////////////////////////////////////////
// Sets a matrix to a given size
//  Reshapes *outMat to inRows x inCols.  Existing contents must be treated
//  as garbage afterwards.  If the existing storage (origSize elements) is
//  large enough it is reused as is.  Otherwise the old storage is released
//  (MTRX_Destroy) and new zero-filled heap storage is created.
//  Returns false with the matrix unchanged if the element count overflows
//  an unsigned, or if more storage is needed under MTRX_NO_MALLOC_.
//  Returns false with an empty 0x0 matrix if the new allocation fails.
static bool
MTRX_SetSize(MTRX_t *outMat, unsigned inRows, unsigned inCols)
{
    // All indexing is done in unsigned arithmetic, so refuse shapes whose
    // element count would wrap around.
    if (inCols != 0 && inRows > UINT_MAX / inCols)
        return false;

    if (outMat->origSize < inRows * inCols)
    {
#if defined(MTRX_NO_MALLOC_)
        return false;
#else
        MTRX_Destroy(outMat);
        if (MTRX_Create(outMat, inRows, inCols, 0) == false)
        {
            // Never report success, or keep stale bookkeeping, when there
            // is no storage behind the matrix.
            outMat->rows = 0;
            outMat->cols = 0;
            outMat->origSize = 0;
            outMat->onHeap = false;
            outMat->vals = NULL;
            return false;
        }
#endif
    }
    outMat->rows = inRows;
    outMat->cols = inCols;
    return true;
}


////////////////////////////////////////
// Copy a matrix
//  Resizes *outMat to inMat's shape (see MTRX_SetSize) and copies inMat's
//  elements into it.  Returns false if *outMat could not be resized.
bool MTRX_Copy(MTRX_t *outMat, MTRX_t *inMat)
{
    // Copying a matrix onto itself is a no-op.  Skipping it also avoids
    // calling memcpy with identical (i.e. overlapping) source and
    // destination, which is undefined behaviour.
    if (outMat == inMat)
        return true;

    if (MTRX_SetSize(outMat, inMat->rows, inMat->cols) == false)
        return false;
    memcpy(outMat->vals, inMat->vals, inMat->rows * inMat->cols * sizeof(MTRX_Element_t));
    return true;
}




////////////////////////////////////////////////////////////////////////////////
#pragma mark -
#pragma mark Accessors


////////////////////////////////////////
// Get an element
//  Returns the element at (inRow, inCol) of the row-major storage.
//  No bounds checking is performed.
MTRX_Element_t MTRX_Access(MTRX_t *inMat, unsigned inRow, unsigned inCol)
{
    const unsigned offset = inRow * inMat->cols + inCol;
    return inMat->vals[offset];
}


////////////////////////////////////////
// Get the number of rows
unsigned MTRX_Rows(MTRX_t *inMat)
{
    const unsigned numRows = inMat->rows;
    return numRows;
}


////////////////////////////////////////
// Get the number of columns
unsigned MTRX_Cols(MTRX_t *inMat)
{
    const unsigned numCols = inMat->cols;
    return numCols;
}


////////////////////////////////////////
// Get Matrix size
//  Returns the number of elements currently in use (rows * cols).
unsigned MTRX_Size(MTRX_t *inMat)
{
    unsigned count = inMat->cols;
    count *= inMat->rows;
    return count;
}


////////////////////////////////////////
// Get Matrix allocated size
//  Returns the element capacity recorded when the storage was set up
//  (origSize).  This can exceed MTRX_Size() once the matrix has been
//  reshaped smaller while reusing its storage.
unsigned MTRX_AllocatedSize(MTRX_t *inMat)
{
    const unsigned capacity = inMat->origSize;
    return capacity;
}


////////////////////////////////////////
// Is this matrix on the heap?
//  True when the element storage was allocated by this module (and will
//  be freed by MTRX_Destroy); false for caller-supplied storage.
bool MTRX_OnHeap(MTRX_t *inMat)
{
    return inMat->onHeap ? true : false;
}


////////////////////////////////////////
// Set a matrix to a given array
//  Releases whatever storage *ioMat currently owns (via MTRX_Destroy).
//  Then makes *ioMat an inRows x inCols view of inData.  inData is
//  caller-owned: it is used directly, not copied, and MTRX_Destroy will
//  not free it.
void MTRX_SetToArray(MTRX_t *ioMat, unsigned inRows, unsigned inCols, MTRX_Element_t *inData)
{
    MTRX_Destroy(ioMat);

    ioMat->vals = inData;
    ioMat->onHeap = false;
    ioMat->rows = inRows;
    ioMat->cols = inCols;
    ioMat->origSize = inCols * inRows;
}


////////////////////////////////////////
// Copy a given array into a matrix
//  Resizes *ioMat to inRows x inCols (see MTRX_SetSize) and copies
//  inRows * inCols elements from inData into its storage.  Unlike
//  MTRX_SetToArray, the matrix keeps no reference to inData.
//  Returns false if the matrix could not be resized.
bool MTRX_CopyFromArray(MTRX_t *ioMat, unsigned inRows, unsigned inCols, MTRX_Element_t *inData)
{
    size_t numBytes;

    if (!MTRX_SetSize(ioMat, inRows, inCols))
        return false;

    numBytes = inRows * inCols * sizeof(MTRX_Element_t);
    memcpy(ioMat->vals, inData, numBytes);
    return true;
}


////////////////////////////////////////
// Copy out a matrix (raw bytes)
//  Copies the first inSize BYTES of the matrix storage into outData.
//  The caller must ensure inSize fits both buffers; nothing is checked.
void MTRX_CopyToData(MTRX_t *inMat, unsigned inSize, void *outData)
{
    const void *src = inMat->vals;
    memcpy(outData, src, inSize);
}


////////////////////////////////////////
// Copy out a matrix (elements)
//  Copies the first inNumElements elements of the row-major storage into
//  outData.  No check is made against the matrix size.
void MTRX_CopyToArray(MTRX_t *inMat, unsigned inNumElements, MTRX_Element_t *outData)
{
    const size_t numBytes = inNumElements * sizeof(MTRX_Element_t);
    memcpy(outData, inMat->vals, numBytes);
}


////////////////////////////////////////////////////////////////////////////////
#pragma mark -
#pragma mark Math


////////////////////////////////////////
// Adds a scalar to a matrix
//  outMat = inMat + inVal, element-wise; outMat is first resized to
//  inMat's shape.  Can be used in place (outMat == inMat).
//  Returns false if outMat could not be resized.
bool MTRX_ScalarAdd(MTRX_t *outMat, MTRX_t *inMat, MTRX_Element_t inVal)
{
    unsigned idx;
    unsigned count;

    if (!MTRX_SetSize(outMat, inMat->rows, inMat->cols))
        return false;

    count = inMat->rows * inMat->cols;
    for (idx = 0; idx != count; idx++)
        outMat->vals[idx] = inMat->vals[idx] + inVal;
    return true;
}


////////////////////////////////////////
// Subtract a scalar from a matrix
//  outMat = inMat - inVal, implemented as adding -inVal with
//  MTRX_ScalarAdd.  Can be used in place.  Returns false if outMat could
//  not be resized.
bool MTRX_ScalarSubtract(MTRX_t *outMat, MTRX_t *inMat, MTRX_Element_t inVal)
{
    const MTRX_Element_t negated = -inVal;
    return MTRX_ScalarAdd(outMat, inMat, negated);
}


////////////////////////////////////////
// Multiply a matrix by a scalar
//  outMat = inMat * inVal, element-wise; outMat is first resized to
//  inMat's shape.  Can be used in place (outMat == inMat).
//  Returns false if outMat could not be resized.
bool MTRX_ScalarMultiply(MTRX_t *outMat, MTRX_t *inMat, MTRX_Element_t inVal)
{
    unsigned k = 0;
    unsigned total;

    if (MTRX_SetSize(outMat, inMat->rows, inMat->cols) == false)
        return false;

    total = inMat->rows * inMat->cols;
    while (k < total)
    {
        outMat->vals[k] = inMat->vals[k] * inVal;
        ++k;
    }
    return true;
}


////////////////////////////////////////
// Add two matrices together
//  outMat = inMat1 + inMat2, element-wise.  Returns false if the inputs
//  differ in shape or outMat could not be resized.
//  Can be used in place (outMat may be either input).
bool MTRX_Add(MTRX_t *outMat, MTRX_t *inMat1, MTRX_t *inMat2)
{
    const bool sameShape = (inMat1->rows == inMat2->rows) && (inMat1->cols == inMat2->cols);
    unsigned   n, count;

    if (!sameShape)
        return false;
    if (!MTRX_SetSize(outMat, inMat1->rows, inMat1->cols))
        return false;

    count = inMat1->rows * inMat1->cols;
    for (n = 0; n < count; n++)
    {
        outMat->vals[n] = inMat1->vals[n] + inMat2->vals[n];
    }
    return true;
}


////////////////////////////////////////
// Subtract one matrix from another
//  outMat = inMat1 - inMat2, element-wise.  Returns false if the inputs
//  differ in shape or outMat could not be resized.
//  Can be used in place (outMat may be either input).
bool MTRX_Subtract(MTRX_t *outMat, MTRX_t *inMat1, MTRX_t *inMat2)
{
    const bool sameShape = (inMat1->rows == inMat2->rows) && (inMat1->cols == inMat2->cols);
    unsigned   n, count;

    if (!sameShape)
        return false;
    if (!MTRX_SetSize(outMat, inMat1->rows, inMat1->cols))
        return false;

    count = inMat1->rows * inMat1->cols;
    for (n = 0; n < count; n++)
    {
        outMat->vals[n] = inMat1->vals[n] - inMat2->vals[n];
    }
    return true;
}


////////////////////////////////////////
// Multiply two matrices
//  Computes outMat = inMat1 * inMat2, resizing outMat to
//  inMat1->rows x inMat2->cols.
//  CANNOT be used in place: if outMat is the same matrix as either input,
//  false is returned.  Also returns false if the inner dimensions disagree
//  or outMat could not be resized.
//  NOTE(review): distinct MTRX_t objects that share one vals buffer are not
//  detected; do not pass those either.
bool MTRX_Multiply(MTRX_t *outMat, MTRX_t *inMat1, MTRX_t *inMat2)
{
    unsigned    i, j, k;
    unsigned    rows, cols, inner;

    // The result is written while the operands are still being read, so an
    // aliased output would corrupt them.  Refuse instead of returning garbage.
    if (outMat == inMat1 || outMat == inMat2)
        return false;

    // Check the matrix size
    if (inMat1->cols != inMat2->rows)
        return false;

    // Setup the result matrix
    rows = inMat1->rows;
    cols = inMat2->cols;
    inner = inMat1->cols;
    if (MTRX_SetSize(outMat, rows, cols) == false)
        return false;

    for (i = 0; i < rows; ++i)
    {
        for (j = 0; j < cols; ++j)
        {
            // Accumulate each dot product in a local, so the output needs
            // no pre-clearing and is written exactly once.
            MTRX_Element_t sum = 0;
            for (k = 0; k < inner; ++k)
                sum += inMat1->vals[i * inner + k] * inMat2->vals[k * cols + j];
            outMat->vals[i * cols + j] = sum;
        }
    }
    return true;
}



////////////////////////////////////////
// Inverts a 2x2 matrix in place
//  Closed form: inv([a b; c d]) = [d -b; -c a] / (a*d - b*c).
//  Returns false, leaving the matrix untouched, if ioMat is not 2x2 or
//  if the determinant is exactly zero.
static bool MTRX_Invert_2x2(MTRX_t *ioMat)
{
    MTRX_Element_t  det;
    MTRX_Element_t  invDet;

    // Guard: this closed form only applies to 2x2 matrices.
    if (!(ioMat->rows == 2 && ioMat->cols == 2))
        return false;

    det = ioMat->vals[0] * ioMat->vals[3] - ioMat->vals[1] * ioMat->vals[2];
    if (det == 0)
        return false;   // singular
    invDet = 1.0 / det;

    // Build the adjugate: swap the diagonal, negate the off-diagonal.
    SWAP_(ioMat->vals[0], ioMat->vals[3]);
    ioMat->vals[2] *= -1;
    ioMat->vals[1] *= -1;

    // Scale every element by 1/det.
    MTRX_ScalarMultiply(ioMat, ioMat, invDet);
    return true;
}


////////////////////////////////////////
// Invert the given matrix (in place)
//  Matrix must be square and non-empty; otherwise false is returned.
//  Returns true if the matrix could be inverted, otherwise false.
//  If the matrix cannot be inverted, the matrix might be trampled.
//  A pivot that is exactly zero is treated as singular (no tolerance).
//  The second parameter is temporary storage for the pivoting, which
//  is not used if n == 2.  For an nxn matrix, the storage requirement
//  is 3n unsigned values.  If no array is passed in, a temporary buffer
//  is allocated from the heap (when MTRX_NO_MALLOC_ is defined, the call
//  fails instead).
//
//  Uses the closed form solution for 2x2, otherwise it uses Gauss-Jordan
//  elimination with full pivoting.
//
//  The algorithm was adapted from Numerical Recipes in C.
bool MTRX_Invert(MTRX_t *ioMat, unsigned *indxc)
{
    int             size;
    int             i, icol, irow, j, k;
    unsigned        *indxr, *ipiv;
    bool            didAlloc, couldInvert;
    
   	// Remove warning when optimization and warnings are high
   	// (the pivot search below normally assigns both before they are used).
   	icol = irow = 0;

    // Handle degenerate and easy cases
    if((ioMat->rows == 0) || (ioMat->cols == 0))
        return false;
    if(ioMat->rows != ioMat->cols)
        return false;
    if ((ioMat->rows == 2) && (ioMat->cols == 2))
        return MTRX_Invert_2x2(ioMat);

    //
    // Okay, we handled the simple cases, now it's time to get dirty
    // We use Gauss-Jordan Elimination to perform inversion from here on in
    //

    // Setup variables
    couldInvert = true;
    size = ioMat->rows;
    didAlloc = false;

    // Allocate temp storage if required.
    if (indxc == 0)
    {
#if defined(MTRX_NO_MALLOC_)
		return false;
#else
        didAlloc = true;
        if ((indxc = malloc(sizeof(indxc[0]) * 3 * size)) == 0)
        	return false;
#endif
    }
    // Carve the 3*size buffer into three size-element arrays:
    //  indxc[i] - column of the pivot chosen at step i
    //  indxr[i] - row in which that pivot was found
    //  ipiv[x]  - how many times index x has been chosen as pivot column
    memset(indxc, 0, sizeof(indxc[0]) * 3 * size);
    indxr = indxc + size;
    ipiv = indxr + size;

    // One elimination step per pivot.
    for(i = 0; i < size; ++i)
    {
        MTRX_Element_t  theMax = 0;
        MTRX_Element_t  pivot;

        // Find the pivot
        //  Largest |a[j][k]| over rows j with ipiv[j] != 1 and columns k
        //  with ipiv[k] == 0 (full pivoting over the unused indices).
        for(j = 0; j < size; j++)
            if(ipiv[j] != 1)
                for(k = 0; k < size; ++k)
                    if(ipiv[k] == 0)
                    {
                        MTRX_Element_t  temp = fabs(ioMat->vals[j * size + k]);
                        if (temp >= theMax)
                        {
                            theMax = temp;
                            irow = j;
                            icol = k;
                        }
                    }
        ++ipiv[icol];

        // We now have the pivot, so interchange rows
        //  This moves the pivot onto the diagonal at (icol, icol).  The
        //  matching column interchange is deferred to the unscramble step.
        if (irow != icol)
            for(j = 0; j < size; ++j) SWAP_(ioMat->vals[irow * size + j], ioMat->vals[icol * size + j]);

        indxr[i] = irow;
        indxc[i] = icol;

        // Check for singularity (exact comparison with zero)
        pivot = ioMat->vals[icol * size + icol];
        if (pivot == 0)
        {
            couldInvert = false;
            break;
        }

        // Now divide through by pivot
        //  The pivot slot is set to 1 first, so after the division it holds
        //  1/pivot; this is how the inverse is built in the same storage.
        ioMat->vals[icol * size + icol] = 1;
        for(j = 0; j < size; j++)
            ioMat->vals[icol * size + j] /= pivot;

        // Now reduce the rows
        //  Eliminate column icol from every other row.  As above, the
        //  eliminated slot is reset (to 0) before the update so that it
        //  receives the corresponding entry of the inverse.
        for(j = 0; j < size; ++j)
            if(j != icol)
            {
                MTRX_Element_t  temp;
                temp = ioMat->vals[j * size + icol];
                ioMat->vals[j * size + icol] = 0;
                for(k = 0; k < size; ++k)
                    ioMat->vals[j * size + k] -= ioMat->vals[icol * size + k] * temp;
            }
    }

    // Unscramble the solution
    //  Undo the implied column interchanges in reverse pivot order
    //  (i = size-1 down to 0) by swapping columns indxr[i] and indxc[i].
    if (couldInvert)
    {
        i = size - 1;
        do
        {
            if(indxr[i] != indxc[i])
                for(j = 0; j < size; ++j)
                    SWAP_(ioMat->vals[j * size + indxr[i]], ioMat->vals[j * size + indxc[i]]);
        }
        while(i-- != 0);
    }

    // Cleanup
    //  Only free the pivot buffer if it was allocated above.
#if !defined(MTRX_NO_MALLOC_)
    if (didAlloc)
		free(indxc);
#endif

    return couldInvert;
}


////////////////////////////////////////
// Find the index to swap
//  For an M x L row-major matrix being transposed in place (into L x M),
//  returns the old index of the element that belongs at new index k.
//  Writing k = q*M + r (0 <= r < M), the result is r*L + q.
static inline unsigned Pred(unsigned k, unsigned M, unsigned L)
{
	const unsigned quotient  = k / M;
	const unsigned remainder = k % M;

	return remainder * L + quotient;
}


////////////////////////////////////////
// Transpose a matrix (in place)
//  This is an expensive operation for non-square shapes.  When possible,
//  use the out-of-place transpose code.
//  Swaps the row/column counts and permutes the elements in place:
//   - row/column vectors need no data movement;
//   - square matrices swap each pair mirrored across the diagonal once;
//   - other shapes follow the permutation cycles of the index mapping
//     (at most rows*cols swaps, plus the cost of finding cycle leaders).
//  Always returns true.
bool MTRX_Transpose_InPlace(MTRX_t *ioMat)
{
	unsigned	rows = ioMat->rows;
	unsigned	cols = ioMat->cols;
	unsigned	i, j, k, stillToMove;

	// Transpose dimensions
	ioMat->rows = cols;
	ioMat->cols = rows;

	// A row or column vector has the same memory layout as its transpose,
	// so swapping the dimensions is all that is needed.  This is a success,
	// not a failure.
	if ((rows == 1) || (cols == 1))
		return true;

	// Square case: visit each pair above the diagonal exactly once.
	// Starting j at i + 1 (not 1) avoids swapping pairs back again.
	if (rows == cols)
	{
		for (i = 0; i < rows; ++i)
			for (j = i + 1; j < cols; ++j)
				SWAP_(ioMat->vals[i * cols + j], ioMat->vals[j * cols + i]);
		return true;
	}

	// General case: cycle following.  Pred(p) is the old index of the
	// element that belongs at new index p.  A cycle is rotated only from
	// its smallest index, so each cycle is processed exactly once.
	stillToMove = rows * cols;
	for (i = 0; stillToMove > 0; ++i)
	{
		for (j = Pred(i, rows, cols); j > i; j = Pred(j, rows, cols))
			;
		if (j < i)
			continue;	// cycle already handled from a smaller index
		for (k = i, j = Pred(i, rows, cols); j != i; k = j, j = Pred(j, rows, cols))
		{
			SWAP_(ioMat->vals[k], ioMat->vals[j]);
			--stillToMove;
		}
		--stillToMove;	// the cycle's last element is now in place too
	}
	return true;
}


////////////////////////////////////////
// Transpose a matrix out of place
//  Resizes outMat to inMat->cols x inMat->rows (see MTRX_SetSize) and
//  writes the transpose of inMat into it.  Returns false if outMat could
//  not be resized.
//  NOTE(review): outMat should be a different matrix from inMat, because
//  the loop reads inMat while writing outMat.  Use MTRX_Transpose_InPlace
//  to transpose a matrix onto itself.
bool MTRX_Transpose(MTRX_t *outMat, MTRX_t *inMat)
{
	const unsigned	outRows = inMat->cols;
	const unsigned	outCols = inMat->rows;
	unsigned		r, c, dst;

	if (!MTRX_SetSize(outMat, outRows, outCols))
		return false;

	// Fill the destination sequentially (row-major): out(r, c) = in(c, r).
	dst = 0;
	for (r = 0; r < outRows; ++r)
		for (c = 0; c < outCols; ++c)
			outMat->vals[dst++] = inMat->vals[c * outRows + r];
	return true;
}


////////////////////////////////////////
// Print a matrix
//  Writes a header line "(rows,cols) -- origSize N [on heap]" (or
//  "[not on heap]"), then one line per row.  Each element is formatted
//  with %g and followed by a tab.
void MTRX_Print(FILE *inFP, MTRX_t *inMat)
{
    unsigned row, col;
    const MTRX_Element_t *cell = inMat->vals;

    fprintf(inFP, "(%u,%u)", inMat->rows, inMat->cols);
    fprintf(inFP, " -- origSize %u", inMat->origSize);
    fprintf(inFP, " [%son heap]", inMat->onHeap ? "" : "not ");
    fprintf(inFP, "\n");

    // Storage is row-major, so walking it sequentially visits rows in order.
    for (row = 0; row < inMat->rows; ++row)
    {
        for (col = 0; col < inMat->cols; ++col)
            fprintf(inFP, "%g\t", *cell++);
        fprintf(inFP, "\n");
    }
}
