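# Solve the 1-D wave equation using one of several simple schemes,
# with a choice of initial conditions and boundary conditions u = 0 at
# the ends of the grid.  The wave speed is C = 1.
#
# Numpy version, with optional smoothing of the initial conditions
# (option 4 only; toggled with -S).
#
# Arguments:
#
#   -c step      time step in units of the Courant limit [1.0]
#   -i initial   1 = gaussian                           [default]
#                2 = sine^2 pulse
#                3 = damped sine wave
#                4 = smoothly truncated wave pulse
#                    (0 and any other value also select this pulse)
#   -j points    number of grid points                  [201]
#   -s scheme    0 = FTCS (no stepper is implemented in this file)
#                1 = Lax                                [default]
#                2 = Lax-Wendroff (any value other than 0 or 1)
#   -S           toggle smoothing of the edges of initial condition 4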

import sys
import getopt
import numpy as np
import matplotlib.pyplot as plt
from mysmooth import *

# Spatial grid: J equally spaced points spanning [XMIN, XMAX].
J = 201
XMIN, XMAX = -10.0, 10.0
DX = (XMAX - XMIN) / (J - 1.0)

# Time stepping: run until TMAX.  DT is the default step; for schemes > 0
# main() replaces it with COURANT*DX.
TMAX = 10.0
DT = 0.1

# Interval between plotted frames, and the default Courant factor.
DTPRINT = 0.1
COURANT = 1.0

# Run-time defaults.  main() lets -i change whichIC, -s override SCHEME and
# -S toggle SMOOTH; DIRECTION has no command-line option.
whichIC = 1         # initial condition: gaussian
SCHEME = 1          # scheme: Lax
DIRECTION = 1       # > 0 right-moving, < 0 left-moving, 0 both directions
SMOOTH = False      # smooth the edges of the pulse initial condition (4)

def initialize(x):
    """Return the initial arrays (u, r, s) on the grid points x.

    u is the displacement, r = du/dx and s = du/dt.  The profile is chosen
    by the module-level whichIC:

        1      centered gaussian exp(-x**2)
        2      sin(pi*x/2)**2 for |x| <= 2, zero elsewhere
        3      centered damped sine exp(-x**2)*sin(pi*x)
        other  sin(pi*x/2) for |x| <= 4, zero elsewhere ("smoothly truncated
               wave pulse"); its edges are passed to mysmooth when SMOOTH
               is true

    s is derived from r using the module-level DIRECTION:
    s = -r for DIRECTION > 0 (right-traveling wave), s = r for
    DIRECTION < 0 (left-traveling wave), s = 0 for DIRECTION == 0
    (two oppositely traveling waves).

    The output arrays have the same shape as x.
    """
    x = np.asarray(x)
    # Size from x rather than the global J so any grid works.
    u = np.zeros(x.shape)
    r = np.zeros(x.shape)

    if whichIC == 1:                    # 1: centered gaussian
        u = np.exp(-x**2)
        r = -2*x*u

    elif whichIC == 2:                  # 2: centered sine^2 wave
        # Boolean mask includes BOTH end points of |x| <= 2; an empty mask
        # (grid not reaching the pulse) simply leaves zeros.
        inside = np.abs(x) <= 2
        sx = np.sin(np.pi*x[inside]/2)
        u[inside] = sx**2
        r[inside] = np.pi*sx*np.cos(np.pi*x[inside]/2)

    elif whichIC == 3:                  # 3: centered damped sine wave
        ex = np.exp(-x**2)
        sx = np.sin(np.pi*x)
        u = ex*sx
        r = ex*(-2*x*sx + np.pi*np.cos(np.pi*x))

    else:                               # 4: "smoothly truncated wave pulse"
        xc = 0.0
        wid2 = 4.0
        xmin = xc - wid2
        xmax = xc + wid2
        # Mask (not a first:last slice) so the last point with
        # |x - xc| <= wid2 is included, keeping r symmetric about xc.
        inside = np.abs(x - xc) <= wid2
        u[inside] = np.sin(np.pi*x[inside]/2)
        r[inside] = 0.5*np.pi*np.cos(np.pi*x[inside]/2)

        # Smooth the edges.

        if SMOOTH:
            delx = 0.4
            n = 3
            mysmooth(x, u, r, xmin, xmax, delx, n)

    # Initial s = -r for a right traveling wave.
    # Initial s = r for a left traveling wave.
    # Initial s = 0 for two oppositely traveling waves.
    # Use copies!  Otherwise s = r just returns a reference to r.

    if DIRECTION > 0:
        s = -r.copy()
    elif DIRECTION < 0:
        s = r.copy()
    else:
        s = 0*r

    return u, r, s

def lax_step(u, r, s, dt):
    """Advance (u, r, s) by one time step dt with the Lax scheme, in place.

    r = du/dx and s = du/dt obey r_t = s_x, s_t = r_x.  Each interior value
    is replaced by the average of its two neighbours plus a centered
    difference term; every right-hand side uses the values from before the
    step.  u is advanced with the trapezoidal rule in the old and new s.

    r and s are zeroed at both end points, and only the interior points
    1 .. len(r)-2 are advanced, so the end values of u are not changed.
    """
    r[0] = s[0] = r[-1] = s[-1] = 0.0      # boundary conditions
    alpha = 0.5*(dt/DX)

    # First half of the trapezoidal update of u, with the old s.
    u[1:-1] += 0.5*s[1:-1]*dt

    # Build the new interior values from old values only, then store them.
    r_new = 0.5*(r[:-2] + r[2:]) + alpha*(s[2:] - s[:-2])
    s_new = 0.5*(s[:-2] + s[2:]) + alpha*(r[2:] - r[:-2])
    r[1:-1] = r_new
    s[1:-1] = s_new

    # Second half of the trapezoidal update of u, with the new s.
    u[1:-1] += 0.5*s[1:-1]*dt

def lax_wendroff_step(u, r, s, dt):
    """Advance (u, r, s) one time step dt with two-step Lax-Wendroff, in place.

    Explicit-loop form of the update done by lax_wendroff_step_np.
    r = du/dx and s = du/dt obey r_t = s_x, s_t = r_x.  Half-step values at
    j-1/2 and j+1/2 are formed with a Lax step and then used for a centered
    full step.  u is advanced with the trapezoidal rule in the old and new s.

    r and s are zeroed at both end points, and only the interior points
    1 .. len(r)-2 are advanced, so the end values of u are not changed.
    """
    r[0] = s[0] = r[-1] = s[-1] = 0.0      # boundary conditions
    alpha = dt/DX

    # rr, ss hold r[j-1], s[j-1] from BEFORE this step (the array entries
    # have already been overwritten by the time they are needed).
    rr, ss = r[0], s[0]
    for j in range(1, len(r)-1):
        rj = r[j]
        sj = s[j]
        u[j] += 0.5*sj*dt

        rm12 = 0.5*(rj+rr) + 0.5*alpha*(sj-ss)
        sm12 = 0.5*(sj+ss) + 0.5*alpha*(rj-rr)

        rp12 = 0.5*(r[j+1]+rj) + 0.5*alpha*(s[j+1]-sj)
        sp12 = 0.5*(s[j+1]+sj) + 0.5*alpha*(r[j+1]-rj)

        r[j] += alpha*(sp12-sm12)
        s[j] += alpha*(rp12-rm12)

        rr = rj
        ss = sj
        u[j] += 0.5*s[j]*dt

def lax_wendroff_step_np(u, r, s, dt):
    """Advance (u, r, s) one time step dt with two-step Lax-Wendroff, in place.

    Vectorized over the interior points 1 .. len(r)-2.  The system
    r_t = s_x, s_t = r_x is written in conservation form r_t + fr_x = 0,
    s_t + fs_x = 0 with fluxes fr = -s, fs = -r.  u is advanced with the
    trapezoidal rule in the old and new s.  r and s are zeroed at both end
    points; the end values of u are not changed.
    """
    r[0] = s[0] = r[-1] = s[-1] = 0.0      # boundary conditions
    alpha = dt/DX

    # Fluxes, taken after the boundary values are applied.  Negation already
    # creates new arrays, so later in-place updates of r and s don't alter them.
    fr = -s
    fs = -r

    # First half of the trapezoidal update of u, with the old s.
    u[1:-1] += 0.5*s[1:-1]*dt

    # Half-step (Lax) values at j-1/2 and j+1/2.
    rm12 = 0.5*(r[1:-1]+r[:-2]) - 0.5*alpha*(fr[1:-1]-fr[:-2])
    rp12 = 0.5*(r[2:]+r[1:-1]) - 0.5*alpha*(fr[2:]-fr[1:-1])

    sm12 = 0.5*(s[1:-1]+s[:-2]) - 0.5*alpha*(fs[1:-1]-fs[:-2])
    sp12 = 0.5*(s[2:]+s[1:-1]) - 0.5*alpha*(fs[2:]-fs[1:-1])

    # Full step: -alpha*(flux(j+1/2) - flux(j-1/2)) with fr = -s, fs = -r.
    r[1:-1] += alpha*(sp12-sm12)
    s[1:-1] += alpha*(rp12-rm12)

    # Second half of the trapezoidal update of u, with the new s.
    u[1:-1] += 0.5*s[1:-1]*dt

def display_data(x, u, t, scheme, ymin, ymax):
    """Draw one animation frame: the numerical u(x) and a reference curve.

    The title names the scheme (0 FTCS, 1 Lax, anything else Lax-Wendroff)
    and the time.  The axes are cleared before every frame except the
    initial one (t == 0).  u is clipped to +/-1e10 before plotting.  The
    reference is the profile from initialize(x) shifted by t: to the right
    for DIRECTION > 0, to the left for DIRECTION < 0, and at half
    amplitude in both directions otherwise.  The initial frame pauses for
    1 s, later frames for 1 ms.
    """
    scheme_names = {0: 'FTCS', 1: 'Lax'}
    heading = scheme_names.get(scheme, 'Lax-Wendroff')
    initial_frame = (t == 0.0)

    if initial_frame:
        heading += ': Initial conditions'
    else:
        heading += ': time = ' + '{:.2f}'.format(t)
        plt.cla()

    plt.title(heading)
    plt.xlabel('x')
    plt.ylabel('u')
    plt.xlim(XMIN, XMAX)
    plt.ylim(ymin, ymax)

    # Keep a blown-up (unstable) solution finite so it can still be drawn.
    plt.plot(x, np.clip(u, -1.e10, 1.e10), zorder=2)

    profile = initialize(x)[0]
    if DIRECTION > 0:
        reference = [(x+t, profile)]
    elif DIRECTION < 0:
        reference = [(x-t, profile)]
    else:
        half = 0.5*profile
        reference = [(x+t, half), (x-t, half)]
    for xs, ys in reference:
        plt.plot(xs, ys, zorder=1)

    plt.pause(1.0 if initial_frame else 0.001)

def main(argv):
    """Parse options from argv, integrate to TMAX and animate the solution.

    argv is the full argument list including the program name (argv[0]),
    as passed by the __main__ guard.  Options: -c courant factor,
    -i initial condition, -j number of grid points, -s scheme,
    -S toggle smoothing of the pulse initial condition.  Invalid options
    or settings end the program with an error message via sys.exit.
    """
    global whichIC, J, DX, SMOOTH

    t = 0.0
    dt = DT
    dtprint = DTPRINT
    tprint = dtprint
    scheme = SCHEME
    courant = COURANT

    # Parse the argv we were given rather than re-reading sys.argv.
    try:
        opts, _args = getopt.getopt(argv[1:], "c:i:j:s:S")
    except getopt.GetoptError as err:
        sys.exit(str(err))

    for o, a in opts:
        if o == "-c":
            courant = float(a)
        elif o == "-i":
            whichIC = int(a)
        elif o == "-j":
            J = int(a)
        elif o == "-s":
            scheme = int(a)
        elif o == "-S":
            SMOOTH = not SMOOTH
        else:
            print("unexpected argument", o)

    if scheme == 0:
        # No FTCS stepper exists in this file; stop with a message instead
        # of calling None inside the time loop.
        sys.exit('FTCS scheme (-s 0) is not implemented')
    elif scheme == 1:
        step = lax_step
    else:
        step = lax_wendroff_step_np

    if J < 2:
        sys.exit('need at least 2 grid points (-j)')

    DX = (XMAX-XMIN)/(J-1.0)
    if scheme > 0: dt = courant*DX
    if dt <= 0:
        # A non-positive step would never reach TMAX.
        sys.exit('time step must be positive (check -c)')
    print('IC =', whichIC, ' scheme =', scheme, ' courant =', courant)

    x = np.linspace(XMIN, XMAX, J)
    u, r, s = initialize(x)

    ymax = 1.25
    ymin = -0.25
    if whichIC == 3 or whichIC == 4: ymin = -1.25

    display_data(x, u, t, scheme, ymin, ymax)

    while t < TMAX-0.5*dt:
        step(u, r, s, dt)
        t += dt
        # Tolerance is half the ACTUAL step, not the default DT.
        if t > tprint-0.5*dt:
            display_data(x, u, t, scheme, ymin, ymax)
            tprint += dtprint

    plt.show()

if __name__ == '__main__':
    main(sys.argv)
