import sys
import math
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Spatial grid: J points uniformly spaced on [XMIN, XMAX] with spacing DX.
J = 201
XMIN, XMAX = -10.0, 10.0
DX = (XMAX - XMIN) / (J - 1.0)

# Time integration: the main loop runs until TMAX; DT is the default step.
TMAX = 10.0
DT = 0.1    # large!

# Coefficient D appearing in fund(), and the ratio D*DT/DX^2 for the
# default time step.
D = 1.0
ALPHA = D * DT / DX**2

# Simulated time between successive plot refreshes.
DTPLOT = 0.1

def fund(x, t):
    """Return exp(-x^2 / (4 D t)) / sqrt(4 pi D t) for scalar x and t.

    D is the module-level constant.  The expression is evaluated with
    the scalar ``math`` routines, so x and t must be plain numbers.
    """
    gaussian = math.exp(-x**2/(4*D*t))
    normalisation = math.sqrt(4*math.pi*D*t)
    return gaussian/normalisation
    
def initial(x, t):
    """Return the reference profile at position x and time t.

    This simply forwards to fund(); it is the single place to change if
    a different starting/reference profile is wanted.
    """
    profile_value = fund(x, t)
    return profile_value

def initialize(u, t):
    """Fill u[0..J-1] in place with initial(x_j, t), where x_j = XMIN + j*DX.

    Nothing is returned; the caller's array is overwritten.
    """
    for j in range(J):
        u[j] = initial(XMIN + j*DX, t)

# Tridiagonal matrix solver.  Nonzero diagonals of the matrix are
# represented by arrays a, b, and c (see Numerical Recipes).  Array r
# is the right-hand side.  The solution vector is returned as a newly
# allocated array.  All arrays start at index 0.

def tridiag(a, b, c, r, n):
    """Solve an n-by-n tridiagonal linear system by forward elimination
    and back-substitution.

    Row j of the system reads

        a[j]*u[j-1] + b[j]*u[j] + c[j]*u[j+1] = r[j]

    so a is the sub-diagonal, b the main diagonal and c the
    super-diagonal.  a[0] and c[n-1] are never referenced.  None of the
    input arrays is modified.

    Parameters:
        a, b, c: indexable sequences of length >= n holding the diagonals.
        r:       indexable sequence of length >= n, the right-hand side.
        n:       number of equations.

    Returns:
        A new float array of length n containing the solution.  For
        n == 0 an empty array is returned.

    Raises:
        ZeroDivisionError: if a pivot becomes exactly zero during the
            elimination.  No pivoting is performed, so the system cannot
            be solved by this routine in that case.
    """
    u = np.zeros(n)
    if n == 0:
        return u

    gam = np.zeros(n)
    bet = b[0]
    # Check the pivot explicitly: with NumPy scalars a division by zero
    # would only emit a warning and silently fill the result with inf/nan.
    if bet == 0:
        raise ZeroDivisionError("tridiag: zero pivot in row 0 (b[0] == 0)")
    u[0] = r[0] / bet

    # Forward sweep: eliminate the sub-diagonal.
    for j in range(1, n):
        gam[j] = c[j-1]/bet
        bet = b[j] - a[j]*gam[j]
        if bet == 0:
            raise ZeroDivisionError("tridiag: zero pivot in row %d" % j)
        u[j] = (r[j]-a[j]*u[j-1]) / bet

    # Back-substitution.
    for j in range(n-2, -1, -1):
        u[j] -= gam[j+1]*u[j+1]

    return u

######################################
def ftcs_step(a, b, c, u, dt):
    """Advance the solution by one step via a tridiagonal solve.

    The end values u[0] and u[-1] are set to zero IN PLACE (boundary
    conditions), then the system with diagonals a, b, c and a copy of u
    as right-hand side is handed to tridiag() for J unknowns.

    dt is accepted for interface symmetry but is not used here; any
    dependence on the time step is carried by a, b and c.

    Returns the new solution array produced by tridiag().
    """
    u[0] = u[-1] = 0.0              # boundary conditions
    rhs = u.copy()
    return tridiag(a, b, c, rhs, J)
######################################

def display_data(x, u, t):
    """Draw the current solution together with the reference profile.

    The axes are cleared first unless t is exactly 0.0.  Each of the
    first J entries of u is clamped IN PLACE to [-1e10, 1e10] before
    plotting, and initial(x[j], t) is drawn underneath as a comparison
    curve.  The figure is refreshed with a short plt.pause().
    """
    if t == 0.0:
        heading = 'FTCS' + ': Initial conditions'
    else:
        heading = 'FTCS' + ': time = '+'%.2f'%(t)
        plt.cla()

    plt.title(heading)
    plt.xlabel('x')
    plt.ylabel('u')
    plt.ylim(0.0, 0.3)

    reference = np.zeros(J)
    for j in range(J):
        # Clamp from above first, then from below.
        u[j] = max(min(u[j], 1.e10), -1.e10)
        reference[j] = initial(x[j], t)

    plt.plot(x, u, zorder=2)
    plt.plot(x, reference, zorder=1)
    plt.pause(0.001)

def main(argv):
    """Integrate the diffusion problem from t = 1 to TMAX and animate it.

    Parameters:
        argv: command-line argument list (argv[0] is the program name).
              If argv[1] is present it is parsed as a float and used as
              the time step instead of DT.

    The solution is started from initial(x, 1.0) and advanced with
    ftcs_step(), which solves a tridiagonal system whose right-hand side
    is the current solution.  The matching matrix is the fully implicit
    (backward-time) discretisation

        -alpha*u[j-1] + (1 + 2*alpha)*u[j] - alpha*u[j+1] = u_old[j]

    with alpha = D*dt/DX**2 on interior rows, and identity rows at both
    ends so that the boundary values stay at zero.
    """
    t = 1.0
    dt = DT
    dtplot = DTPLOT

    # Optional override of the time step from the command line.
    if len(argv) > 1:
        dt = float(argv[1])
    alpha = D*dt/DX**2

    # Schedule the first refresh one plot interval after the start time.
    tplot = t + dtplot

    x = np.linspace(XMIN, XMAX, J)
    u = np.zeros(J)
    initialize(u, t)

    display_data(x, u, t)

    ######################################
    ### Set up a, b, c arrays of length J.

    a = np.zeros(J)
    b = np.zeros(J)
    c = np.zeros(J)

    # Interior rows.  Leaving these at zero would make the matrix
    # singular and the solve meaningless.
    a[1:J-1] = -alpha
    b[1:J-1] = 1.0 + 2.0*alpha
    c[1:J-1] = -alpha

    # Apply boundary conditions: first and last rows are identity rows.
    b[0] = 1.0
    c[0] = 0.0
    a[J-1] = 0.0
    b[J-1] = 1.0

    # Note that the vectors a, b, and c are CONSTANT for fixed time step.
    ######################################

    # The half-step tolerances guard against accumulated round-off in t.
    while t < TMAX-0.5*dt:
        u = ftcs_step(a, b, c, u, dt)
        t += dt
        if t > tplot-0.5*dt:
            display_data(x, u, t)
            tplot += dtplot

    plt.show()

# Run the simulation only when this file is executed as a script; the
# full argument list is passed so main() can read an optional time step.
if __name__ == "__main__" :
    main(sys.argv)
