
# Solve the 1-D advection equation u_t + u_x = 0 (unit speed) using the
# Lax-Wendroff differencing scheme.  Select initial conditions with the
# -i argument (1 = Gaussian, 2 = sine wave, 3 = square wave; default 2).
#
# Usage:	python ex6.2.py -i [123]

import sys
import math
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Spatial grid: J equally spaced points covering [XMIN, XMAX].
J = 201
XMIN = -10.0
XMAX = 10.0
DX = (XMAX - XMIN) / (J - 1.0)      # grid spacing (0.1)

# Time stepping.  The Courant number DT/DX is 0.99; the Lax-Wendroff
# scheme is stable only while that ratio does not exceed 1.
TMAX = 10.0                         # time at which the run stops
DT = 0.099                          # time step

DTPRINT = 0.1                       # simulated time between plot updates
ic = 2                              # initial-condition selector; main() lets -i override it

def initial(x, which=None):
    """Return the initial profile u(x, 0) for the chosen initial condition.

    x      -- position at which to evaluate the profile.
    which  -- initial-condition selector:
                1         -> Gaussian exp(-(x + 7)**2)
                2         -> sin(pi*x) on [-2, 0], zero elsewhere
                otherwise -> unit square pulse on [-2, 0], zero elsewhere
              Defaults to the module-level ``ic`` (set by the -i option),
              looked up at call time so later changes to ``ic`` are seen.

    Always returns a float.
    """
    if which is None:
        which = ic                  # read-only access: no ``global`` needed
    if which == 1:                  # 1: gaussian
        return math.exp(-(x + 7)**2)
    outside = x < -2 or x > 0       # both pulse shapes live on [-2, 0]
    if which == 2:                  # 2: sine wave
        return 0.0 if outside else math.sin(math.pi*x)
    return 0.0 if outside else 1.0  # 3: square wave

def initialize(u):
    """Load the initial profile into the first J entries of u, in place.

    Grid point j sits at x = XMIN + j*DX; each value comes from initial().
    Nothing is returned.
    """
    grid_positions = (XMIN + j*DX for j in range(J))
    for index, position in enumerate(grid_positions):
        u[index] = initial(position)

def lax_wendroff_step(u, dt, dx=None):
    """Advance u by one Lax-Wendroff step of u_t + u_x = 0, in place.

    u   -- 1-D sequence of grid values (normally a float numpy array);
           modified in place, nothing is returned.
    dt  -- time step.
    dx  -- grid spacing; defaults to the module-level DX.

    Boundary conditions are Dirichlet: u[0] and u[-1] are set to zero and
    stay zero.  Only the interior points 1..len(u)-2 are updated, using the
    conservative flux form with Courant number a = dt/dx:

        g[j+1/2] = (u[j+1] + u[j]) - a*(u[j+1] - u[j])
        u[j]    <- u[j] - (a/2)*(g[j+1/2] - g[j-1/2])

    which expands to the usual
        u[j] - (a/2)(u[j+1] - u[j-1]) + (a^2/2)(u[j+1] - 2u[j] + u[j-1]).
    """
    if dx is None:
        dx = DX
    u[0] = 0.0                          # boundary conditions
    u[-1] = 0.0

    alpha = dt/dx
    # For a float ndarray this is u itself (no copy).  Every right-hand side
    # below is evaluated into a fresh array before being written back, so
    # all updates use old values.
    v = np.asarray(u, dtype=float)
    g = (v[1:] + v[:-1]) - alpha*(v[1:] - v[:-1])     # 2x interface fluxes
    # Interior only: the old loop also rewrote u[0], clobbering the boundary.
    u[1:-1] = v[1:-1] - 0.5*alpha*(g[1:] - g[:-1])

def display_data(x, u, t):
    """Plot the numerical solution u and the exact solution at time t.

    x  -- grid positions, same length as u.
    u  -- numerical solution; it is NOT modified (values are clipped to
          +/-1e10 on a copy, for plotting only).
    t  -- current simulation time.

    The exact solution of u_t + u_x = 0 is the initial profile shifted
    right by t, i.e. initial(x - t).  At t == 0 the axes are drawn as they
    are.  At later times they are cleared first, so each frame replaces the
    previous one.
    """
    title = 'Lax-Wendroff'
    if t == 0.0:
        title += ': Initial conditions'
    else:
        title += ': time = '+'%.2f'%(t)
        plt.cla()
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('u')
    plt.ylim(-1.25, 1.25)
    # Clip a copy: the old code clamped the caller's solution array in place,
    # silently altering the simulation state from inside a plotting routine.
    shown = np.clip(u, -1.e10, 1.e10)
    exact = np.array([initial(xj - t) for xj in x])
    plt.plot(x, shown, zorder=2)
    plt.plot(x, exact, zorder=1)
    plt.pause(0.001)

def main(argv):
    """Parse the command line, run the simulation to TMAX and animate it.

    argv -- full command line, program name first (as sys.argv).  The only
            option is -i N, which selects the initial condition (see
            initial()).  An unknown option or a non-integer N exits with a
            usage message.
    """
    global ic
    usage = 'usage: python %s -i [123]' % (argv[0] if argv else 'ex6.2.py')
    t = 0.0
    dt = DT
    dtprint = DTPRINT
    tprint = dtprint

    try:
        opts, _ = getopt.getopt(argv[1:], "i:")
    except getopt.GetoptError as err:
        sys.exit('%s\n%s' % (err, usage))

    # getopt only returns options listed in "i:", so every entry here is -i.
    for _, value in opts:
        try:
            ic = int(value)
        except ValueError:
            sys.exit('invalid -i value %r\n%s' % (value, usage))

    print('IC =', ic)

    # linspace yields exactly J points (matching u) ending at XMAX.  arange
    # with a float step can round to J+1 points and break plotting.
    x = np.linspace(XMIN, XMAX, J)
    u = np.zeros(J)
    initialize(u)

    display_data(x, u, t)

    # The half-step slack absorbs round-off accumulated by t += dt.
    while t < TMAX-0.5*dt:
        lax_wendroff_step(u, dt)
        t += dt
        if t > tprint-0.5*dt:
            display_data(x, u, t)
            tprint += dtprint

    plt.show()

if __name__ == '__main__':
    # Script entry point: pass the full command line through to main().
    main(sys.argv)
