
# Tridiagonal matrix solver.  The nonzero diagonals of the matrix are
# represented by numpy arrays a (sub-diagonal), b (main diagonal), and
# c (super-diagonal); see Numerical Recipes.  Numpy array r is the
# right-hand side.  The solution vector is returned.  All arrays start
# at index 0, so a[0] and c[n-1] are never referenced.

import sys
import numpy as np

def err_exit(n):
    """Report a tridiag failure code on standard output and terminate.

    The message has the form ``Error <n> in tridiag.``; the interpreter
    is then asked to exit with status ``n`` by raising ``SystemExit``.
    """
    message = ' '.join(('Error', str(n), 'in tridiag.'))
    print(message)
    # Equivalent to sys.exit(n): the exit request is a SystemExit exception.
    raise SystemExit(n)

def tridiag(a, b, c, r):
    """Solve a tridiagonal linear system and return the solution vector.

    Row ``j`` of the system is::

        a[j]*u[j-1] + b[j]*u[j] + c[j]*u[j+1] = r[j]

    so ``a`` holds the sub-diagonal, ``b`` the main diagonal and ``c``
    the super-diagonal.  ``a[0]`` and ``c[n-1]`` are never referenced.

    Parameters
    ----------
    a, b, c : indexable sequences of numbers
        The three diagonals, each indexed from 0 and aligned with the
        row index as shown above.
    r : indexable sequence of numbers
        Right-hand side; its length ``n`` sets the size of the system.

    Returns
    -------
    numpy.ndarray
        Float array of length ``n`` holding the solution ``u``.  An
        empty ``r`` yields an empty array.

    Notes
    -----
    No pivoting is performed.  If a zero pivot is met the routine calls
    ``err_exit(1)`` (``b[0]`` is zero) or ``err_exit(2)`` (a pivot
    produced during elimination is zero) instead of dividing by it.
    """
    n = len(r)
    u = np.zeros(n)

    # A 0x0 system has the empty solution; there is no b[0] to look at.
    if n == 0:
        return u

    gam = np.zeros(n)

    # The pivot must be validated *before* it is used as a divisor,
    # otherwise a zero b[0] raises or silently produces inf/nan.
    bet = b[0]
    if bet == 0.0:
        err_exit(1)
    u[0] = r[0] / bet

    # Reduce to upper triangular (forward elimination).
    for j in range(1, n):
        gam[j] = c[j-1] / bet
        bet = b[j] - a[j]*gam[j]
        if bet == 0.0:
            err_exit(2)
        u[j] = (r[j] - a[j]*u[j-1]) / bet

    # Solve by back-substitution.
    for j in range(n-2, -1, -1):
        u[j] -= gam[j+1]*u[j+1]

    return u
