
# Solve the advection equation du/dt + du/dx = 0 using a simple Lax
# differencing scheme.  Select initial conditions with the -i argument
# (1 = Gaussian, 2 = sine pulse, 3 = square pulse; default 2).
#
# Usage:        python ex6.2.py -i [123]

import sys
import math
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Spatial grid: J equally spaced points covering [XMIN, XMAX].
J = 201
XMIN = -10.0
XMAX = 10.0
DX = (XMAX - XMIN) / (J - 1.0)      # grid spacing (0.1)

# Time stepping: integrate from t = 0 up to TMAX in steps of DT.
# NOTE(review): DT/DX = 1.01 exceeds the Courant limit (dt/dx <= 1 for
# unit advection speed), so the Lax scheme is unstable with these values --
# presumably deliberate for this exercise; confirm before changing.
TMAX = 10.0
DT = 1.01e-1

# Target interval between plotted frames.
DTPRINT = 0.1

# Initial-condition selector (see initial()); overridden by the -i option.
ic = 2

def initial(x):
    """Return the initial profile u(x, t=0) for the selected initial condition.

    The profile is chosen by the module-level ``ic``:

    * 1 -- Gaussian exp(-(x + 7)**2), centred at x = -7;
    * 2 -- one period of sin(pi*x) on -2 <= x <= 0, zero elsewhere;
    * any other value -- unit square pulse on -2 <= x <= 0, zero elsewhere.
    """
    if ic == 1:
        offset = x + 7
        return math.exp(-(offset ** 2))

    # Both pulse profiles vanish outside [-2, 0].
    outside = x < -2 or x > 0
    if outside:
        return 0

    if ic == 2:
        return math.sin(math.pi * x)
    return 1

def initialize(u):
    """Sample the initial profile onto the grid, writing u[0] .. u[J-1] in place.

    Grid point k is located at x = XMIN + k*DX.
    """
    grid_points = (XMIN + k * DX for k in range(J))
    for k, xk in enumerate(grid_points):
        u[k] = initial(xk)

def lax_step(u, dt, dx=None):
    """Advance u by one time step of the Lax scheme, in place.

    Every interior point is updated with

        u[j] <- (u[j-1] + u[j+1]) / 2 - (dt / (2*dx)) * (u[j+1] - u[j-1])

    which discretises du/dt + du/dx = 0.  Both end points are held at zero
    (Dirichlet boundary conditions).

    Args:
        u: mutable sequence of grid values (e.g. the numpy array built in
            main); overwritten with the values at the next time level.
        dt: time step.
        dx: grid spacing; defaults to the module-level DX.
    """
    if dx is None:
        dx = DX
    last = len(u) - 1

    u[0] = 0.0                          # boundary conditions
    u[last] = 0.0

    half_courant = 0.5 * (dt / dx)
    # The sweep overwrites u left to right, so the *old* u[j-1] is carried
    # in `prev`.  Start at j = 1: starting at 0 would overwrite the boundary
    # value u[0] just set above.
    prev = u[0]
    for j in range(1, last):
        old = u[j]
        u[j] = 0.5 * (prev + u[j + 1]) - half_courant * (u[j + 1] - prev)
        prev = old

def display_data(x, u, t):
    """Redraw the numerical solution u together with the exact solution at t.

    The exact curve is the initial profile translated right by t,
    initial(x - t), i.e. the solution for unit advection speed.  The axes are
    cleared before every frame except the first (t == 0) one.

    NOTE(review): u is clamped to [-1e10, 1e10] *in place*, so this plotting
    routine also modifies the simulation state once the solution blows up.
    """
    if t == 0.0:
        label = 'Lax: Initial conditions'
    else:
        label = 'Lax: time = %.2f' % t
        plt.cla()

    plt.title(label)
    plt.xlabel('x')
    plt.ylabel('u')
    plt.ylim(-1.25, 1.25)

    # Keep runaway (unstable) values finite before plotting.
    for j in range(J):
        u[j] = max(min(u[j], 1.e10), -1.e10)
    exact = np.array([initial(x[j] - t) for j in range(J)], dtype=float)

    plt.plot(x, u, zorder=2)
    plt.plot(x, exact, zorder=1)
    plt.pause(0.001)

def main(argv):
    """Parse options, integrate with the Lax scheme up to TMAX and animate.

    Args:
        argv: full command line (argv[0] is the program name).  The only
            option is ``-i N``, which sets the module-level ``ic`` read by
            initial().
    """
    global ic
    t = 0
    dt = DT
    dtprint = DTPRINT
    tprint = dtprint

    try:
        # Use the argv we were given rather than reaching for sys.argv.
        opts, _ = getopt.getopt(argv[1:], "i:")
    except getopt.GetoptError as err:
        print(str(err))
        sys.exit(1)

    for o, a in opts:
        if o == "-i":
            try:
                ic = int(a)
            except ValueError:
                print("option -i requires an integer argument, got", repr(a))
                sys.exit(1)
        else:
            # getopt rejects options not in the spec, so this is defensive.
            print("unexpected argument", o)

    print('IC =', ic)

    # linspace guarantees exactly J points with spacing DX (the same grid as
    # initialize()); np.arange with a float step can gain or lose an
    # endpoint through round-off.
    x = np.linspace(XMIN, XMAX, J)
    u = np.zeros(J)
    initialize(u)

    display_data(x, u, t)

    # Compare against half a step so round-off in the accumulated t does not
    # add or drop a step (or a frame).
    while t < TMAX - 0.5*dt:
        lax_step(u, dt)
        t += dt
        if t > tprint - 0.5*dt:
            display_data(x, u, t)
            tprint += dtprint

    plt.show()

if __name__ == "__main__":
    # Run the simulation only when executed as a script, not on import.
    main(argv=sys.argv)
