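# Solve the 1-D advection equation using a simple FTCS (-s 0), Lax (-s 1;
# default), or Lax-Wendroff (-s 2) differencing scheme.  Select the initial
# conditions with the -i argument (1: Gaussian, 2: sine wave, 3: square
# wave; default 2).  Use the Courant factor (-c, default 1.0) to determine
# the time step in the Lax/Lax-Wendroff case.  Set the number of grid
# points with -j (default 201); -n switches between the numpy (default)
# and pure-Python implementations.
#
# Usage:	python advection1d.py -i [123] -s [012] -c [cour] -j [npts] -n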

import sys
import math
import getopt
import numpy as np
import matplotlib.pyplot as plt

# ---- Settable parameters ---------------------------------------------------

J       = 201       # number of grid points (main() accepts -j to override)
COURANT = 1.0       # Courant factor for the Lax/Lax-Wendroff time step (-c)
whichIC = 2         # initial-condition selector read by initial() (-i)
SCHEME  = 1         # 0 = FTCS, 1 = Lax, anything else = Lax-Wendroff (-s)
NUMPY   = True      # numpy (True) vs pure-Python steppers; confirm main() reads it

# Spatial grid: J uniformly spaced points covering [XMIN, XMAX].
XMIN = -10.0
XMAX = 10.0
DX   = (XMAX - XMIN)/(J - 1.0)

# Time stepping and plotting.
TMAX   = 10.0       # final time
DT     = 0.1        # FTCS time step; the Lax schemes use courant*DX instead
DTPLOT = 0.1        # interval between plot refreshes (capped at dt by main)

def initial(x, ic=None):
    """Return the initial profile u(x, t=0) at the point x.

    ic selects the profile.  When omitted, the module global whichIC is
    used (looked up at call time, so main's -i override applies):
      1      Gaussian exp(-(x+7)^2)
      2      sin(pi*x) on [-2, 0], zero elsewhere
      3      square wave: 1 on [-8, -6], zero elsewhere
      other  sin(pi*x/2) over the whole domain
    """
    if ic is None:
        ic = whichIC

    if ic == 1:					# 1: gaussian
        return math.exp(-(x+7)**2)
    if ic == 2:					# 2: sine wave
        return math.sin(math.pi*x) if -2 <= x <= 0 else 0.0
    if ic == 3:					# 3: square wave
        # Both interval ends must hold; an `or` of the two tests is true
        # for every x and would make u = 1 across the whole domain.
        return 1.0 if -8 <= x <= -6 else 0.0
    # Any other selector: a full-domain sine (main() gives IC 4 a y-range
    # down to -1.25 to fit its negative half-periods).
    # NOTE(review): if a localised pulse on [-8, -6] was intended, restrict
    # x here the same way as case 3.
    return math.sin(math.pi*x/2)
        
def initialize(u):
    """Fill u[0..J-1] in place with the initial profile on the uniform grid.

    Grid point j sits at x = XMIN + j*DX; each value comes from initial().
    """
    grid = [XMIN + j*DX for j in range(J)]
    for j, xj in enumerate(grid):
        u[j] = initial(xj)

def v(u):
    """Return the amplitude-dependent advection speed 1 + 0.05*u.

    Used by the pure-Python Lax and Lax-Wendroff steppers; works
    elementwise when u is a numpy array.
    """
    slope = 0.05
    return slope*u + 1.0

def ftcs_step(u, dt):
    """Advance u by one FTCS step in place (pure-Python loop).

    Forward-in-time, centred-in-space update of u_t + u_x = 0 on the
    interior points 1..J-2; both ends are pinned to zero first.  Reads the
    module globals J and DX.
    """
    coef = 0.5*dt/DX

    u[0] = 0.0				# boundary conditions
    u[J-1] = 0.0

    old = list(u)			# time-level-n values, so updates don't feed forward
    for j in range(1, J-1):
        u[j] -= coef*(old[j+1] - old[j-1])

def ftcs_step_np(u, dt):
    """Vectorised FTCS step in place; same update and globals as ftcs_step."""
    u[0] = u[J-1] = 0.0			# boundary conditions

    centred = u[2:] - u[:-2]		# u[j+1] - u[j-1] for every interior j
    u[1:-1] -= (0.5*dt/DX)*centred

def lax_step(u, dt):
    """Advance u by one Lax step in place (pure-Python loop).

    Each interior point becomes the average of its two neighbours minus a
    centred-difference term scaled by v(u[j])*dt/DX, where u[j] is the
    time-level-n value at the point itself.  Both ends are pinned to zero
    first.  Reads the module globals J and DX and calls v().

    NOTE(review): the speed here is v(u) = 1 + 0.05*u, whereas lax_step_np
    uses a constant speed of 1, so -n switches between two different
    equations.  Confirm which one is intended.
    """
    ratio = dt/DX

    u[0] = 0.0				# boundary conditions
    u[J-1] = 0.0

    old = list(u)			# time-level-n snapshot
    for j in range(1, J-1):
        left, centre, right = old[j-1], old[j], old[j+1]
        u[j] = 0.5*(left+right) - 0.5*(v(centre)*ratio)*(right-left)

def lax_step_np(u, dt):
    """Vectorised Lax step for unit speed, in place.

    u_j <- 0.5*(1-alpha)*u_{j+1} + 0.5*(1+alpha)*u_{j-1} with alpha = dt/DX
    on the interior points; both ends are pinned to zero first.  Reads the
    module globals J and DX.
    """
    u[0] = u[J-1] = 0.0			# boundary conditions

    ratio = dt/DX
    w_right = 0.5*(1-ratio)		# weight on u[j+1]
    w_left = 0.5*(1+ratio)		# weight on u[j-1]
    u[1:-1] = w_right*u[2:] + w_left*u[:-2]

def lax_wendroff_step(u, dt):
    """Advance u by one two-step Lax-Wendroff step in place (pure-Python loop).

    For each interior point, Lax-averaged half-step values at j-1/2 and
    j+1/2 are formed and then differenced for the update.  The local
    Courant number is v(u[j-1])*dt/DX, using the time-level-n value of the
    left neighbour.  Both ends are pinned to zero first.  Reads the module
    globals J and DX and calls v().

    NOTE(review): lax_wendroff_step_np uses a constant speed of 1 instead
    of v(u), and lax_step evaluates v at u[j] rather than u[j-1].  Confirm
    which form is intended.
    """
    u[0] = 0.0				# boundary conditions
    u[J-1] = 0.0

    old = list(u)			# time-level-n snapshot
    for j in range(1, J-1):
        left, centre, right = old[j-1], old[j], old[j+1]
        c = v(left)*dt/DX

        half_minus = 0.5*(left+centre) - 0.5*c*(centre-left)
        half_plus = 0.5*(centre+right) - 0.5*c*(right-centre)
        u[j] -= c*(half_plus-half_minus)

def lax_wendroff_step_np(u, dt):
    """Vectorised Lax-Wendroff step for constant unit speed, in place.

    u_j <- u_j - (alpha/2)*[(u_{j+1} - u_{j-1})
                            - alpha*(u_{j+1} - 2u_j + u_{j-1})]
    with alpha = dt/DX on the interior points; both ends are pinned to
    zero first.  Reads the module globals J and DX.
    """
    u[0] = u[J-1] = 0.0			# boundary conditions

    ratio = dt/DX
    centred = u[2:] - u[:-2]			# u[j+1] - u[j-1]
    second = u[2:] - 2*u[1:-1] + u[:-2]		# u[j+1] - 2u[j] + u[j-1]
    u[1:-1] -= 0.5*ratio*(centred - ratio*second)

def display_data(x, u, t, alpha, scheme, ymin, ymax):
    """Redraw the numerical solution u(x) alongside the exact solution.

    The exact curve is initial(x - t): the initial profile shifted right by
    t.  The title names the scheme and alpha, marks whether the numpy
    steppers are active (module global `numpy`, set by main), and shows
    either "Initial conditions" (t == 0) or the current time.  u is clipped
    to +/-1e10 before plotting.  Reads the module global J.
    """
    names = {0: 'FTCS', 1: 'Lax'}
    parts = [names.get(scheme, 'Lax-Wendroff'), ' (alpha = %.2f)' % alpha]
    if numpy:
        parts.append(' (numpy)')
    if t == 0.0:
        parts.append(': Initial conditions')
    else:
        parts.append(': time = ' + '%.2f' % t)
        plt.cla()			# later frames replace the previous curves

    plt.title(''.join(parts))
    plt.xlabel('x')
    plt.ylabel('u')
    plt.ylim(ymin, ymax)

    clipped = np.clip(u, -1.e10, 1.e10)
    exact = np.fromiter((initial(x[j]-t) for j in range(J)),
                        dtype=float, count=J)
    plt.plot(x, clipped, zorder=2)
    plt.plot(x, exact, zorder=1)
    plt.pause(0.001)

def main(argv):
    """Parse options from argv, run the chosen scheme and animate the result.

    argv is the full argument vector with the program name first (the
    script passes sys.argv).  Recognised options:
      -c cour    Courant factor; sets dt = cour*DX for the Lax schemes
      -i ic      initial-condition selector used by initial()
      -j npts    number of grid points J (at least 3)
      -n         toggle from the NUMPY default to the other stepper family
      -s scheme  0 = FTCS, 1 = Lax, anything else = Lax-Wendroff

    Rebinds the module globals whichIC, J, DX and numpy, which the stepping
    and plotting functions read.  Exits with status 2 on an unknown option,
    a malformed value, too few grid points, or a non-positive time step.
    """
    global whichIC, J, DX, numpy

    scheme = SCHEME
    courant = COURANT
    numpy = NUMPY			# honour the module-level default

    try:
        # Parse the argv we were handed rather than reaching for sys.argv.
        opts, _ = getopt.getopt(argv[1:], "c:i:j:ns:")
        for o, a in opts:
            if o == "-c":
                courant = float(a)
            elif o == "-i":
                whichIC = int(a)
            elif o == "-j":
                J = int(a)
            elif o == "-n":
                numpy = not numpy
            elif o == "-s":
                scheme = int(a)
    except (getopt.GetoptError, ValueError) as err:
        print(err)
        print('usage: -i [123] -s [012] -c [cour] -j [npts] -n')
        sys.exit(2)

    if J < 3:
        # DX divides by J-1, and the schemes need at least one interior point.
        print('need at least 3 grid points, got J =', J)
        sys.exit(2)

    # (numpy stepper, pure-Python stepper) per scheme; any scheme other
    # than 0 or 1 falls through to Lax-Wendroff.
    steppers = {0: (ftcs_step_np, ftcs_step),
                1: (lax_step_np, lax_step)}
    vectorised, looped = steppers.get(
        scheme, (lax_wendroff_step_np, lax_wendroff_step))
    step = vectorised if numpy else looped

    print('J =', J, ' IC =', whichIC, ' scheme =', scheme,
          ' courant =', courant, ' numpy =', numpy)

    DX = (XMAX-XMIN)/(J-1.0)
    x = np.linspace(XMIN, XMAX, J)
    u = np.zeros(J)
    initialize(u)

    t = 0.0
    dt = DT
    if scheme > 0:
        dt = courant*DX			# Lax/Lax-Wendroff: dt from the Courant factor
    if dt <= 0.0:
        # With dt <= 0 the time loop below could never reach TMAX.
        print('time step must be positive, got dt =', dt)
        sys.exit(2)
    dtplot = min(DTPLOT, dt)		# never plot less often than once per step
    tplot = dtplot

    ymax = 1.25
    ymin = -1.25 if whichIC in (2, 4) else -0.25

    alpha = dt/DX
    display_data(x, u, t, alpha, scheme, ymin, ymax)

    while t < TMAX-0.5*dt:
        # Shorten the final step so the run ends at TMAX; the 1e-4*dt margin
        # presumably ensures t >= TMAX after rounding so the last frame draws.
        if t+dt > TMAX:
            dt = TMAX + 0.0001*dt - t
        step(u, dt)
        t += dt
        if t > tplot-0.5*dt or t >= TMAX:
            display_data(x, u, t, alpha, scheme, ymin, ymax)
            while t > tplot-0.5*dt:
                tplot += dtplot

    plt.show()

if __name__ == "__main__":
    # Script entry point: hand the full argument vector to main().
    main(sys.argv)
