# Solve the 1-D wave equation using one of several simple schemes,
# with a choice of initial conditions and boundary conditions u = 0 at
# the ends of the grid.  The wave speed is C = 1.
#
# Arguments:
#
#	-c step 	time step in units of the Courant limit [1.0]
#	-i initial	1 = gaussian			[default]
#			2 = sine-squared pulse on [-2, 2]
#			3 = gaussian-damped sine wave
#			4 = truncated sine wave on [-2, 2]
#			    (any other value also selects 4)
#	-j points	number of grid points		[201]
#	-s scheme	0 = FTCS			(not implemented)
#			1 = Lax				[default]
#			2 = Lax-Wendroff
#			    (any other value also selects Lax-Wendroff)

import sys
import math
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Spatial grid: J points spanning [XMIN, XMAX]; main() may override J (-j).
J = 201
XMIN = -10.0
XMAX = 10.0
DX = (XMAX-XMIN)/(J-1.0)            # grid spacing, recomputed in main()

# Time stepping.
TMAX = 10.0                         # integrate until this time
DT = 1.e-2                          # default step; main() uses COURANT*DX when scheme > 0

DTPRINT = 0.1                       # time between plot updates
COURANT = 1.0                       # default for -c (step as a fraction of DX)

# Defaults for the -i and -s options.
whichIC = 1                         # 1 = centered gaussian
SCHEME = 1                          # 1 = Lax

def initial(x):
    """Return the initial state ``(u, r, s)`` at position *x*.

    ``u`` is the displacement, ``r`` its spatial derivative u_x and ``s``
    its time derivative u_t.  The profile is selected by the module-level
    ``whichIC``:

        1  centered gaussian exp(-x**2)
        2  sin(pi*x/2)**2 on [-2, 2], zero elsewhere
        3  gaussian-damped sine exp(-x**2)*sin(pi*x)
        any other value: sin(pi*x/2) on [-2, 2], zero elsewhere (IC 4)
    """
    disp = 0.0
    grad = 0.0

    if whichIC == 1:                    # 1: centered gaussian
        disp = math.exp(-x**2)
        grad = -2*x*disp
    elif whichIC == 3:                  # 3: centered damped sine wave
        envelope = math.exp(-x**2)
        wave = math.sin(math.pi*x)
        disp = envelope*wave
        grad = envelope*(-2*x*wave + math.pi*math.cos(math.pi*x))
    elif -2 <= x <= 2:
        # ICs 2 and 4 are nonzero only on [-2, 2].
        if whichIC == 2:                # 2: centered sine-squared pulse
            half = math.sin(math.pi*x/2)
            disp = half**2
            grad = math.pi*half*math.cos(math.pi*x/2)
        else:                           # 4: truncated sine wave
            disp = math.sin(math.pi*x/2)
            grad = 0.5*math.pi*math.cos(math.pi*x/2)

    # Choice of the initial time derivative:
    #   s = -r  -> a single right-travelling wave
    #   s =  r  -> a single left-travelling wave   (used here)
    #   s =  0  -> two oppositely travelling waves
    return disp, grad, grad

def initialize(u, r, s):
    """Fill u, r and s in place with the initial state sampled on the grid.

    Grid point ``idx`` (0 .. J-1) sits at XMIN + idx*DX.  At each point the
    displacement u, its spatial derivative r and its time derivative s are
    taken from initial().
    """
    for idx in range(J):
        u[idx], r[idx], s[idx] = initial(XMIN + idx*DX)

def lax_step(u, r, s, dt, dx=None):
    """Advance the solution one time step with the Lax scheme, in place.

    The wave equation u_tt = u_xx is integrated as the first-order system
    r_t = s_x, s_t = r_x (r = u_x, s = u_t).  u is advanced by averaging the
    old and new s over the step (trapezoidal rule).

    Parameters:
        u, r, s -- equal-length arrays, modified in place.
        dt      -- time step.
        dx      -- grid spacing; defaults to the module-level DX.

    r and s are forced to zero at both end points and only the interior
    points 1 .. n-2 are updated, so both ends are treated the same way and
    u keeps its initial value there.
    """
    if dx is None:
        dx = DX
    n = len(r)
    alpha = 0.5*(dt/dx)

    r[0] = s[0] = r[n-1] = s[n-1] = 0.0
    # rr, ss carry the *old* r, s at j-1: r[j-1], s[j-1] have already been
    # overwritten with new values by the time point j is computed.
    rr = ss = 0.0
    for j in range(1, n-1):
        rj = r[j]
        sj = s[j]
        u[j] += 0.5*sj*dt                       # first half of the u update (old s)

        r[j] = 0.5*(rr+r[j+1]) + alpha*(s[j+1]-ss)
        s[j] = 0.5*(ss+s[j+1]) + alpha*(r[j+1]-rr)

        rr = rj
        ss = sj
        u[j] += 0.5*s[j]*dt                     # second half (new s)

def lax_wendroff_step(u, r, s, dt, dx=None):
    """Advance the solution one time step with Lax-Wendroff, in place.

    Two-step form for the system r_t = s_x, s_t = r_x (r = u_x, s = u_t):
    a Lax half step of dt/2 gives r, s at j-1/2 and j+1/2, then a centred
    full step updates point j.  u is advanced by averaging the old and new
    s over the step (trapezoidal rule).

    Parameters:
        u, r, s -- equal-length arrays, modified in place.
        dt      -- time step.
        dx      -- grid spacing; defaults to the module-level DX.

    r and s are forced to zero at both end points and only the interior
    points 1 .. n-2 are updated, so both ends are treated the same way and
    u keeps its initial value there.
    """
    if dx is None:
        dx = DX
    n = len(r)
    alpha = dt/dx

    r[0] = s[0] = r[n-1] = s[n-1] = 0.0
    # rr, ss carry the *old* r, s at j-1 (already overwritten in r, s).
    rr = ss = 0.0
    for j in range(1, n-1):
        rj = r[j]
        sj = s[j]
        u[j] += 0.5*sj*dt                       # first half of the u update (old s)

        # Half-step values at j-1/2 ...
        rm12 = 0.5*(r[j]+rr) + 0.5*alpha*(s[j]-ss)
        sm12 = 0.5*(s[j]+ss) + 0.5*alpha*(r[j]-rr)

        # ... and at j+1/2 (r[j+1], s[j+1] are still old values).
        rp12 = 0.5*(r[j+1]+r[j]) + 0.5*alpha*(s[j+1]-s[j])
        sp12 = 0.5*(s[j+1]+s[j]) + 0.5*alpha*(r[j+1]-r[j])

        r[j] += alpha*(sp12-sm12)
        s[j] += alpha*(rp12-rm12)

        rr = rj
        ss = sj
        u[j] += 0.5*s[j]*dt                     # second half (new s)

def display_data(x, u, t, scheme, ymin, ymax):
    """Plot the numerical solution u(x) at time t over the exact solution.

    Parameters:
        x          -- grid coordinates.
        u          -- numerical solution on the grid; NOT modified.
        t          -- current time.  t == 0.0 labels the plot as the initial
                      conditions; any other value clears the axes first.
        scheme     -- 0 = FTCS, 1 = Lax, anything else = Lax-Wendroff
                      (used for the title only).
        ymin, ymax -- fixed vertical axis limits.
    """
    if scheme == 0:
        title = 'FTCS'
    elif scheme == 1:
        title = 'Lax'
    else:
        title = 'Lax-Wendroff'
    if t == 0.0:
        title += ': Initial conditions'
    else:
        title += ': time = '+'%.2f'%(t)
        plt.cla()
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('u')
    plt.ylim(ymin, ymax)

    # Clip a copy for display only, so that plotting never alters the
    # solver's state.
    shown = np.clip(u, -1.e10, 1.e10)

    # Exact travelling-wave solution for comparison.  initial() starts with
    # s = r, i.e. u_t = u_x, whose solution is u(x, t) = f(x + t): a
    # left-travelling wave.  Reflections at the grid ends are not modelled.
    exact = np.array([initial(xj + t)[0] for xj in x])

    plt.plot(x, shown, zorder=2)
    plt.plot(x, exact, zorder=1)
    plt.pause(0.001)

def main(argv):
    """Parse options, run the selected scheme up to TMAX and animate it.

    argv is the full argument vector in the form of sys.argv (argv[0] is the
    program name).  The options -c, -i, -j and -s are described in the file
    header.  The program exits with a message for an unknown option, fewer
    than two grid points, a non-positive time step, or the FTCS scheme
    (-s 0), which has no step routine.
    """
    global whichIC, J, DX

    t = 0.0
    dt = DT
    dtprint = DTPRINT
    tprint = dtprint
    scheme = SCHEME
    courant = COURANT

    try:
        opts, _ = getopt.getopt(argv[1:], "c:i:j:s:")
    except getopt.GetoptError as err:
        sys.exit(str(err))

    for o, a in opts:
        if o == "-c":
            courant = float(a)
        elif o == "-i":
            whichIC = int(a)
        elif o == "-j":
            J = int(a)
        elif o == "-s":
            scheme = int(a)
        else:
            print("unexpected argument", o)

    if scheme == 0:
        # There is no FTCS step routine; fail clearly instead of calling None.
        sys.exit("FTCS scheme (-s 0) is not implemented")
    elif scheme == 1:
        step = lax_step
    else:
        step = lax_wendroff_step

    if J < 2:
        sys.exit("-j: need at least 2 grid points")

    DX = (XMAX-XMIN)/(J-1.0)
    if scheme > 0:
        dt = courant*DX
    if dt <= 0:
        # t would never advance to TMAX.
        sys.exit("-c: time step must be positive")
    print('IC =', whichIC, ' scheme =', scheme, ' courant =', courant)

    # linspace gives exactly J points; np.arange with a float step may
    # include an extra end point and mismatch the length of u.
    x = np.linspace(XMIN, XMAX, J)
    u = np.zeros(J)
    r = np.zeros(J)
    s = np.zeros(J)
    initialize(u, r, s)

    ymax = 1.25
    ymin = -0.25
    if whichIC in (3, 4):
        ymin = -1.25

    display_data(x, u, t, scheme, ymin, ymax)

    while t < TMAX-0.5*dt:
        step(u, r, s, dt)
        t += dt
        # Half-step tolerance (of the step actually in use) absorbs the
        # round-off accumulated in t.
        if t > tprint-0.5*dt:
            display_data(x, u, t, scheme, ymin, ymax)
            tprint += dtprint

    plt.show()

if __name__ == "__main__" :
    main(sys.argv)
