umf_utsolve.c 9.07 KB
/* ========================================================================== */
/* === UMF_utsolve ========================================================== */
/* ========================================================================== */

/* -------------------------------------------------------------------------- */
/* UMFPACK Copyright (c) Timothy A. Davis, CISE,                              */
/* Univ. of Florida.  All Rights Reserved.  See ../Doc/License for License.   */
/* web: http://www.cise.ufl.edu/research/sparse/umfpack                       */
/* -------------------------------------------------------------------------- */

/*  Solves U'x = b or U.'x = b, where U is the upper triangular factor of a */
/*  matrix.  X holds b on input and is overwritten with the solution x. */
/*  Returns the floating point operation count, or 0 if the matrix is not */
/*  square (in which case X is not modified). */

#include "umf_internal.h"

GLOBAL double
#ifdef CONJUGATE_SOLVE
UMF_uhsolve			/* solve U'x=b  (complex conjugate transpose) */
#else
UMF_utsolve			/* solve U.'x=b (array transpose) */
#endif
(
    NumericType *Numeric,
    Entry X [ ],		/* b on input, solution x on output */
    Int Pattern [ ]		/* a work array of size n */
)
{
    /* ---------------------------------------------------------------------- */
    /* Overview:
     *
     * Rows of U are processed in pivot order k = 0, 1, ..., npiv-1.  For
     * each row, x [k] is obtained by dividing X [k] by the pivot D [k]
     * (by conjugate (D [k]) for the conjugate solve).  Then x [k] times the
     * off-diagonal entries of row k of U is subtracted from the entries of
     * X they refer to.
     *
     * Rows 0..n1-1 (singletons) store their column pattern explicitly.
     * Rows n1..npiv-1 are grouped into "Uchains".  Within a chain only
     * incremental pattern changes are stored, so each chain's patterns are
     * first rebuilt backwards into Pattern, then replayed forwards while
     * solving.  Entries npiv..n-1 correspond to zero pivots.
     *
     * Returns an estimate of the flop count, or 0 (without touching X) if
     * the matrix is not square. */
    /* ---------------------------------------------------------------------- */

    /* ---------------------------------------------------------------------- */
    /* local variables */
    /* ---------------------------------------------------------------------- */

    Entry xk ;
    Entry *xp, *D, *Uval ;
    Int k, deg, j, *ip, col, *Upos, *Uilen, kstart, kend, up,
	*Uip, n, uhead, ulen, pos, npiv, n1, *Ui ;

    /* ---------------------------------------------------------------------- */
    /* get parameters */
    /* ---------------------------------------------------------------------- */

    /* only square systems are solved */
    if (Numeric->n_row != Numeric->n_col) return (0.) ;
    n = Numeric->n_row ;
    npiv = Numeric->npiv ;
    Upos = Numeric->Upos ;
    Uilen = Numeric->Uilen ;
    Uip = Numeric->Uip ;
    D = Numeric->D ;
    kend = 0 ;
    n1 = Numeric->n1 ;

#ifndef NDEBUG
    DEBUG4 (("Utsolve start: npiv "ID" n "ID"\n", npiv, n)) ;
    for (j = 0 ; j < n ; j++)
    {
	DEBUG4 (("Utsolve start "ID": ", j)) ;
	EDEBUG4 (X [j]) ;
	DEBUG4 (("\n")) ;
    }
#endif

    /* ---------------------------------------------------------------------- */
    /* singletons */
    /* ---------------------------------------------------------------------- */

    /* Row k of a singleton is stored explicitly at Numeric->Memory + Uip [k]:
     * first Uilen [k] column indices, then (after UNITS (Int, deg) units)
     * the Uilen [k] corresponding numerical values. */
    for (k = 0 ; k < n1 ; k++)
    {
	DEBUG4 (("Singleton k "ID"\n", k)) ;

#ifndef NO_DIVIDE_BY_ZERO
	/* Go ahead and divide by zero if D [k] is zero. */
#ifdef CONJUGATE_SOLVE
	/* xk = X [k] / conjugate (D [k]) ; */
	DIV_CONJ (xk, X [k], D [k]) ;
#else
	/* xk = X [k] / D [k] ; */
	DIV (xk, X [k], D [k]) ;
#endif
#else
	/* Do not divide by zero.  If D [k] is zero, X [k] is left unchanged.
	 * xk must be loaded first; otherwise it would be read uninitialized
	 * (or hold a stale value from a previous k) below. */
	xk = X [k] ;
	if (IS_NONZERO (D [k]))
	{
#ifdef CONJUGATE_SOLVE
	    /* xk = X [k] / conjugate (D [k]) ; */
	    DIV_CONJ (xk, X [k], D [k]) ;
#else
	    /* xk = X [k] / D [k] ; */
	    DIV (xk, X [k], D [k]) ;
#endif
	}
#endif

	X [k] = xk ;
	deg = Uilen [k] ;
	if (deg > 0 && IS_NONZERO (xk))
	{
	    up = Uip [k] ;
	    Ui = (Int *) (Numeric->Memory + up) ;
	    up += UNITS (Int, deg) ;
	    Uval = (Entry *) (Numeric->Memory + up) ;
	    for (j = 0 ; j < deg ; j++)
	    {
		DEBUG4 (("  k "ID" col "ID" value", k, Ui [j])) ;
		EDEBUG4 (Uval [j]) ;
		DEBUG4 (("\n")) ;
#ifdef CONJUGATE_SOLVE
		/* X [Ui [j]] -= xk * conjugate (Uval [j]) ; */
		MULT_SUB_CONJ (X [Ui [j]], xk, Uval [j]) ;
#else
		/* X [Ui [j]] -= xk * Uval [j] ; */
		MULT_SUB (X [Ui [j]], xk, Uval [j]) ;
#endif
	    }
	}
    }

    /* ---------------------------------------------------------------------- */
    /* nonsingletons */
    /* ---------------------------------------------------------------------- */

    /* A Uchain is a run of rows kstart..kend.  Only its first row has
     * Uip [kstart] <= 0, and every following row has Uip [k] > 0. */
    for (kstart = n1 ; kstart < npiv ; kstart = kend + 1)
    {

	/* ------------------------------------------------------------------ */
	/* find the end of this Uchain */
	/* ------------------------------------------------------------------ */

	DEBUG4 (("kstart "ID" kend "ID"\n", kstart, kend)) ;
	/* for (kend = kstart ; kend < npiv && Uip [kend+1] > 0 ; kend++) ; */
	kend = kstart ;
	while (kend < npiv && Uip [kend+1] > 0)
	{
	    kend++ ;
	}

	/* ------------------------------------------------------------------ */
	/* scan the whole Uchain to find the pattern of the first row of U */
	/* ------------------------------------------------------------------ */

	k = kend+1 ;
	DEBUG4 (("\nKend "ID" K "ID"\n", kend, k)) ;

	/* ------------------------------------------------------------------ */
	/* start with last row in Uchain of U in Pattern [0..deg-1] */
	/* ------------------------------------------------------------------ */

	if (k == npiv)
	{
	    /* no following chain: the pattern of the last row is kept in
	     * Numeric->Upattern [0..Numeric->ulen-1] */
	    deg = Numeric->ulen ;
	    if (deg > 0)
	    {
		/* :: make last pivot row of U (singular matrices only) :: */
		for (j = 0 ; j < deg ; j++)
		{
		    Pattern [j] = Numeric->Upattern [j] ;
		}
	    }
	}
	else
	{
	    /* The head of the next chain (row k, Uip [k] <= 0) stores the
	     * Uilen [k] column indices of row k-1 = kend at Memory - Uip [k]. */
	    ASSERT (k >= 0 && k < npiv) ;
	    up = -Uip [k] ;
	    ASSERT (up > 0) ;
	    deg = Uilen [k] ;
	    DEBUG4 (("end of chain for row of U "ID" deg "ID"\n", k-1, deg)) ;
	    ip = (Int *) (Numeric->Memory + up) ;
	    for (j = 0 ; j < deg ; j++)
	    {
		col = *ip++ ;
		DEBUG4 (("  k "ID" col "ID"\n", k-1, col)) ;
		ASSERT (k <= col) ;
		Pattern [j] = col ;
	    }
	}

	/* Pattern [deg..uhead-1] is free.  Pattern [uhead..n-1] is a stack of
	 * column indices removed while walking the chain backwards; they are
	 * popped again, in reverse order, during the forward solve. */
	/* empty the stack at the bottom of Pattern */
	uhead = n ;

	for (k = kend ; k > kstart ; k--)
	{
	    /* Pattern [0..deg-1] is the pattern of row k of U */

	    /* -------------------------------------------------------------- */
	    /* make row k-1 of U in Pattern [0..deg-1] */
	    /* -------------------------------------------------------------- */

	    ASSERT (k >= 0 && k < npiv) ;
	    ulen = Uilen [k] ;
	    /* delete, and push on the stack */
	    for (j = 0 ; j < ulen ; j++)
	    {
		ASSERT (uhead >= deg) ;
		Pattern [--uhead] = Pattern [--deg] ;
	    }
	    DEBUG4 (("middle of chain for row of U "ID" deg "ID"\n", k, deg)) ;
	    ASSERT (deg >= 0) ;

	    pos = Upos [k] ;
	    if (pos != EMPTY)
	    {
		/* add the pivot column */
		DEBUG4 (("k "ID" add pivot entry at position "ID"\n", k, pos)) ;
		ASSERT (pos >= 0 && pos <= deg) ;
		Pattern [deg++] = Pattern [pos] ;
		Pattern [pos] = k ;
	    }
	}

	/* Pattern [0..deg-1] is now the pattern of the first row in Uchain */

	/* ------------------------------------------------------------------ */
	/* solve using this Uchain, in reverse order */
	/* ------------------------------------------------------------------ */

	DEBUG4 (("Unwinding Uchain\n")) ;
	for (k = kstart ; k <= kend ; k++)
	{

	    /* -------------------------------------------------------------- */
	    /* construct row k */
	    /* -------------------------------------------------------------- */

	    ASSERT (k >= 0 && k < npiv) ;
	    pos = Upos [k] ;
	    if (pos != EMPTY)
	    {
		/* remove the pivot column */
		DEBUG4 (("k "ID" add pivot entry at position "ID"\n", k, pos)) ;
		ASSERT (k > kstart) ;
		ASSERT (pos >= 0 && pos < deg) ;
		ASSERT (Pattern [pos] == k) ;
		Pattern [pos] = Pattern [--deg] ;
	    }

	    up = Uip [k] ;
	    ulen = Uilen [k] ;
	    if (k > kstart)
	    {
		/* concatenate the deleted pattern; pop from the stack */
		for (j = 0 ; j < ulen ; j++)
		{
		    ASSERT (deg <= uhead && uhead < n) ;
		    Pattern [deg++] = Pattern [uhead++] ;
		}
		DEBUG4 (("middle of chain, row of U "ID" deg "ID"\n", k, deg)) ;
		ASSERT (deg >= 0) ;
	    }

	    /* -------------------------------------------------------------- */
	    /* use row k of U */
	    /* -------------------------------------------------------------- */

#ifndef NO_DIVIDE_BY_ZERO
	    /* Go ahead and divide by zero if D [k] is zero. */
#ifdef CONJUGATE_SOLVE
	    /* xk = X [k] / conjugate (D [k]) ; */
	    DIV_CONJ (xk, X [k], D [k]) ;
#else
	    /* xk = X [k] / D [k] ; */
	    DIV (xk, X [k], D [k]) ;
#endif
#else
	    /* Do not divide by zero.  If D [k] is zero, X [k] is left
	     * unchanged.  xk must be loaded first; otherwise it would hold a
	     * stale (or uninitialized) value below. */
	    xk = X [k] ;
	    if (IS_NONZERO (D [k]))
	    {
#ifdef CONJUGATE_SOLVE
		/* xk = X [k] / conjugate (D [k]) ; */
		DIV_CONJ (xk, X [k], D [k]) ;
#else
		/* xk = X [k] / D [k] ; */
		DIV (xk, X [k], D [k]) ;
#endif
	    }
#endif

	    X [k] = xk ;
	    if (IS_NONZERO (xk))
	    {
		if (k == kstart)
		{
		    /* the chain head's values follow the ulen column indices
		     * stored there (the pattern of the previous chain's last
		     * row), so skip over them */
		    up = -up ;
		    xp = (Entry *) (Numeric->Memory + up + UNITS (Int, ulen)) ;
		}
		else
		{
		    xp = (Entry *) (Numeric->Memory + up) ;
		}
		for (j = 0 ; j < deg ; j++)
		{
		    DEBUG4 (("  k "ID" col "ID" value", k, Pattern [j])) ;
		    EDEBUG4 (*xp) ;
		    DEBUG4 (("\n")) ;
#ifdef CONJUGATE_SOLVE
		    /* X [Pattern [j]] -= xk * conjugate (*xp) ; */
		    MULT_SUB_CONJ (X [Pattern [j]], xk, *xp) ;
#else
		    /* X [Pattern [j]] -= xk * (*xp) ; */
		    MULT_SUB (X [Pattern [j]], xk, *xp) ;
#endif
		    xp++ ;
		}
	    }
	}
	ASSERT (uhead == n) ;
    }

#ifndef NO_DIVIDE_BY_ZERO
    for (k = npiv ; k < n ; k++)
    {
	/* This is an *** intentional *** divide-by-zero, to get Inf or Nan,
	 * as appropriate.  It is not a bug. */
	ASSERT (IS_ZERO (D [k])) ;
	/* For conjugate solve, D [k] == conjugate (D [k]), in this case */
	/* xk = X [k] / D [k] ; */
	DIV (xk, X [k], D [k]) ;
	X [k] = xk ;
    }
#endif

#ifndef NDEBUG
    for (j = 0 ; j < n ; j++)
    {
	DEBUG4 (("Utsolve done "ID": ", j)) ;
	EDEBUG4 (X [j]) ;
	DEBUG4 (("\n")) ;
    }
    DEBUG4 (("Utsolve done.\n")) ;
#endif

    return (DIV_FLOPS * ((double) n) + MULTSUB_FLOPS * ((double) Numeric->unz));
}