import sys
import math
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Spatial grid: J equally spaced points covering [XMIN, XMAX], spacing DX.
J = 301
XMIN = -15.0
XMAX = 15.0
DX = (XMAX - XMIN) / (J - 1.0)

# Time integration: integrate up to TMAX using the default time step DT.
TMAX = 10.0
DT = 1.e-2

# Diffusion coefficient and the FTCS number D*dt/dx**2 for the default step.
# The explicit scheme is only stable for ALPHA <= 1/2; with these defaults
# ALPHA is about 1.
D = 1.0
ALPHA = D * DT / DX**2

# Simulation-time interval between plot refreshes.
DTPLOT = 0.1

def fund(x, t):
    """Return the heat kernel of u_t = D u_xx at position x and time t.

    Computes exp(-x**2 / (4*D*t)) / sqrt(4*pi*D*t) for scalar x and t,
    using the module-level diffusivity D. t must be positive; t == 0 raises
    ZeroDivisionError.
    """
    gaussian = math.exp(-x**2/(4*D*t))
    return gaussian/math.sqrt(4*math.pi*D*t)
    
def initial(x, t):
    """Return the reference profile at (x, t), i.e. the heat kernel fund(x, t)."""
    profile = fund(x, t)
    return profile

def initialize(u, t):
    """Overwrite u[0..J-1] in place with the initial profile at time t.

    Grid point idx is located at x = XMIN + idx*DX.
    """
    for idx in range(J):
        u[idx] = initial(XMIN + idx*DX, t)

def ftcs_step(u, dt, diffusivity=None, dx=None):
    """Advance u in place by one FTCS (forward-time, centred-space) step.

    Interior points get u_j += alpha*(u_{j+1} - 2*u_j + u_{j-1}) with
    alpha = diffusivity*dt/dx**2. Homogeneous Dirichlet boundaries are used:
    u[0] and u[-1] are set to 0 and left untouched by the update.

    Parameters
    ----------
    u : mutable sequence of float (e.g. a 1-D numpy array)
        Grid values; overwritten with the next time level. Its length
        defines the grid size. Empty input is a no-op.
    dt : float
        Time step. The explicit scheme is only stable for alpha <= 1/2.
    diffusivity : float, optional
        Diffusion coefficient; defaults to the module constant D.
    dx : float, optional
        Grid spacing; defaults to the module constant DX.
    """
    n = len(u)
    if n == 0:
        return
    if diffusivity is None:
        diffusivity = D
    if dx is None:
        dx = DX
    alpha = diffusivity*dt/dx**2

    u[0] = 0.0          # boundary conditions
    u[n-1] = 0.0

    # Sweep left to right. u[j-1] has already been overwritten when u[j] is
    # updated, so carry its old value in `prev`. Only interior points
    # (1..n-2) are updated, so the boundary values stay exactly zero.
    prev = u[0]
    for j in range(1, n-1):
        old = u[j]
        u[j] += alpha*(u[j+1] - 2*old + prev)
        prev = old

def display_data(x, u, t):
    """Redraw the numerical solution and the reference heat kernel at time t.

    Parameters
    ----------
    x : array of length >= J
        Grid coordinates used for both curves.
    u : mutable sequence of length >= J
        Numerical solution. Its first J entries are clamped in place to
        [-1e10, 1e10], so this call also alters the state the caller keeps
        integrating. This is presumably deliberate, to keep an unstable
        (blown-up) run finite.
    t : float
        Simulation time; exactly 0.0 selects the 'Initial conditions' title.

    NOTE(review): with t == 0.0 the loop below calls initial(x, 0.0), which
    divides by zero inside fund(); main() starts at t = 1.0, so it is not hit.
    """
    if t != 0.0:
        heading = 'FTCS explicit: time = ' + '%.2f' % t
        plt.cla()       # wipe the previous frame before drawing the new one
    else:
        heading = 'FTCS explicit: Initial conditions'
    plt.title(heading)
    plt.xlabel('x')
    plt.ylabel('u')
    plt.ylim(0.0, 0.3)

    exact = np.zeros(J)
    for j in range(J):
        u[j] = max(min(u[j], 1.e10), -1.e10)
        exact[j] = initial(x[j], t)

    plt.plot(x, u, zorder=2)        # numerical solution drawn on top
    plt.plot(x, exact, zorder=1)    # analytic reference underneath
    plt.pause(0.001)                # let the GUI event loop render the frame

def main(argv):
    """Run the FTCS diffusion demo, animating it against the exact solution.

    Parameters
    ----------
    argv : list of str
        Command-line vector (argv[0] is the program name). An optional
        argv[1] overrides the default time step DT.

    The run starts at t = 1.0 from the heat kernel, which is singular at
    t = 0. It advances with ftcs_step until TMAX and redraws about every
    DTPLOT time units. Exits with an error message if the time step is not
    positive, since t would then never reach TMAX.
    """
    t = 1.0
    dt = DT
    dtplot = DTPLOT

    # Use the argument vector we were handed rather than re-reading sys.argv.
    if len(argv) > 1:
        dt = float(argv[1])
    if not dt > 0.0:    # also rejects NaN
        sys.exit('dt must be positive, got %r' % (dt,))

    # FTCS is only stable for alpha <= 1/2; print it so the user can tell.
    alpha = D*dt/DX**2
    print('dt =', dt, 'ALPHA =', alpha)

    x = np.linspace(XMIN, XMAX, J)
    u = np.zeros(J)
    initialize(u, t)

    display_data(x, u, t)

    # Schedule the first redraw one plotting interval after the start time.
    # An absolute time of dtplot would lie before t = 1.0 and trigger a
    # redraw on every step until the schedule caught up.
    tplot = t + dtplot
    while t < TMAX - 0.5*dt:
        ftcs_step(u, dt)
        t += dt
        if t > tplot - 0.5*dt:
            display_data(x, u, t)
            tplot += dtplot

    plt.show()

if __name__ == "__main__":
    # Script entry point: hand the full argument vector to main().
    main(sys.argv)
