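# Solve the 1-D wave equation using one of several simple schemes, with a
# choice of initial conditions and of boundary conditions (see -b).  The
# wave speed is C = 1.
#
# Arguments:
#
#   -b bc        1 = absorbing                  [default]
#                2 = reflecting
#                3 = periodic
#   -c step      time step in units of the Courant limit [1.0]
#                (ignored by FTCS, which always uses the fixed step DT)
#   -d dir       initial propagation direction:
#                 1 = right-moving wave          [default]
#                -1 = left-moving wave
#                 0 = two oppositely moving waves
#   -i initial   1 = gaussian
#                2 = sine^2 pulse
#                3 = damped sine wave           [default]
#                4 = smoothly truncated wave pulse (any other value also
#                    selects this pulse)
#   -j J         number of grid points [201]
#   -n           toggle between the NumPy-vectorized steppers (default)
#                and the explicit-loop steppers
#   -s scheme    0 = FTCS
#                1 = Lax
#                2 = Lax-Wendroff               [default]
#   -S           toggle edge smoothing of the wave pulse (on by default)
#   -t tmax      integration time [10]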

import sys
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Spatial grid: J equally spaced points spanning [XMIN, XMAX].
# main() may override J (and recompute DX) from the command line.
J = 201
XMIN, XMAX = -10.0, 10.0
DX = (XMAX - XMIN)/(J - 1.0)

# Time integration.  main() uses DT as the step for FTCS; the other
# schemes use COURANT*DX instead.
TMAX = 10.0
DT = 0.1

# Interval between plotted frames, and the default -c value.
DTPLOT = 0.1
COURANT = 1.0

# Default run configuration (overridable on the command line).
whichIC = 3                 # initial condition selector, see initialize()
SCHEME = 2                  # 0 = FTCS, 1 = Lax, 2 = Lax-Wendroff
whichBC = 1                 # 1 = absorbing, 2 = reflecting, 3 = periodic
DIRECTION = 1               # 1 = right, -1 = left, 0 = none
SMOOTH = True               # taper the edges of the truncated wave pulse

def initialize(x):
    """Return the initial fields (u, r, s) sampled on the grid x.

    u is the displacement, r = du/dx and s = du/dt.  The profile is chosen
    by the module-level ``whichIC``:

        1      centered gaussian exp(-x**2)
        2      sin(pi*x/2)**2 restricted to |x| <= 2
        3      damped sine wave exp(-x**2)*sin(pi*x)
        other  sin(pi*x/2) restricted to |x| <= 4 ("smoothly truncated
               wave pulse"); its edges are tapered by mysmooth() when the
               module-level ``SMOOTH`` is true

    The initial s is derived from r according to the module-level
    ``DIRECTION`` (see the comments near the end).  The returned arrays
    always have the same length as x.
    """
    x = np.asarray(x, dtype=float)

    # Size the work arrays from x itself rather than the global J, so the
    # result always matches the grid that was passed in.
    u = np.zeros_like(x)
    r = np.zeros_like(x)

    if whichIC == 1:            # 1: centered gaussian
        u = np.exp(-x**2)
        r = -2*x*u

    elif whichIC == 2:          # 2: centered sine^2 wave
        # A boolean mask includes every grid point with |x| <= 2 (the old
        # slice dropped the last one) and is simply empty -- instead of
        # raising -- when no grid point falls inside the window.
        inside = np.abs(x) <= 2
        sx = np.sin(np.pi*x[inside]/2)
        u[inside] = sx**2
        r[inside] = np.pi*sx*np.cos(np.pi*x[inside]/2)

    elif whichIC == 3:          # 3: centered damped sine wave
        ex = np.exp(-x**2)
        sx = np.sin(np.pi*x)
        u = ex*sx
        r = ex*(-2*x*sx + np.pi*np.cos(np.pi*x))

    else:                       # 4: "smoothly truncated wave pulse"
        xc = 0.0
        wid2 = 4.0
        xmin = xc - wid2
        xmax = xc + wid2
        # Mask keeps both ends of the window, so the two edges are treated
        # symmetrically.
        inside = np.abs(x - xc) <= wid2
        u[inside] = np.sin(np.pi*x[inside]/2)
        r[inside] = 0.5*np.pi*np.cos(np.pi*x[inside]/2)

        # Smooth the edges.
        if SMOOTH:
            delx = 0.4
            n = 3
            mysmooth(x, u, r, xmin, xmax, delx, n)

    # Initial s = -r for a right traveling wave.
    # Initial s = r for a left traveling wave.
    # Initial s = 0 for two oppositely traveling waves.
    # s must not alias r: unary negation already builds a new array, while
    # the left-moving case needs an explicit copy.
    if DIRECTION > 0:
        s = -r
    elif DIRECTION < 0:
        s = r.copy()
    else:
        s = np.zeros_like(r)

    return u, r, s

def pfit(x0, x1, u0, u1, n):
    """Fit u(x) = x**n * (a + b*x) through the points (x0, u0) and (x1, u1).

    Dividing by x**n turns the fit into the straight line a + b*x through
    (x0, u0/x0**n) and (x1, u1/x1**n).  Returns the tuple (a, b).
    """
    y0 = u0/x0**n
    y1 = u1/x1**n
    slope = (y0 - y1)/(x0 - x1)
    return y0 - slope*x0, slope

def peval(x, a, b, n):
    """Evaluate x**n * (a + b*x); works elementwise on NumPy arrays."""
    linear_part = a + b*x
    return x**n*linear_part
    
def pevalp(x, a, b, n):
    """Evaluate d/dx [x**n * (a + b*x)] = x**(n-1) * (n*a + (n+1)*b*x)."""
    factor = n*a + (n+1)*b*x
    return x**(n-1)*factor

def smooth1(x, u, r, xlim, delx, n, which):
    """Replace u and r near one edge with a polynomial taper, in place.

    The grid points within 1.001*delx of xlim form a window (the small
    margin presumably absorbs rounding in the grid coordinates).  With
    x_edge the window's outer end, the curve
    u = d**n * (a + b*d), d = x - x_edge, is fitted through the two
    innermost window points (so it vanishes at x_edge for n >= 1).  u and
    its x-derivative r are then overwritten across the whole window.

    which == 0 treats xlim as a left edge (outer end = leftmost window
    point); any other value treats it as a right edge.
    """
    window = np.where(np.abs(x - xlim) <= 1.001*delx)[0]
    first, last = window[0], window[-1]

    if which == 0:
        anchor = first
        p, q = last, last - 1
    else:
        anchor = last
        p, q = first, first + 1
    x_edge = x[anchor]
    a, b = pfit(x[p] - x_edge, x[q] - x_edge, u[p], u[q], n)

    span = slice(first, last + 1)
    offset = x[span] - x_edge
    u[span] = peval(offset, a, b, n)
    r[span] = pevalp(offset, a, b, n)
        
def mysmooth(x, u, r, xmin, xmax, delx, n):
    """Taper both edges of a pulse spanning [xmin, xmax] in place.

    The left edge (xmin) is processed first, then the right edge (xmax),
    each by smooth1() with window half-width delx and polynomial order n.
    """
    for edge, side in ((xmin, 0), (xmax, 1)):
        smooth1(x, u, r, edge, delx, n, side)

def apply_boundary_conditions(r, s):
    """Set the end values of r and s in place according to ``whichBC``.

    1 ("absorbing"):   r and s are zeroed at both ends.
    2 ("reflecting"):  r is zeroed at both ends and s is copied from its
                       nearest interior neighbour.
    other ("periodic"): each end takes the value of the interior point
                       next to the opposite end.

    Negative indices are used, so the routine works for arrays of any
    length instead of assuming the module-level grid size J.
    """
    if whichBC == 1:            # absorbing
        r[0] = r[-1] = s[0] = s[-1] = 0.0

    elif whichBC == 2:          # reflecting
        r[0] = 0.0
        r[-1] = 0.0
        s[0] = s[1]
        s[-1] = s[-2]

    else:                       # periodic
        r[0] = r[-2]
        r[-1] = r[1]
        s[0] = s[-2]
        s[-1] = s[1]

def flux(r, s):
    """Return the fluxes (F_r, F_s) = (-s, -r) used by the steppers.

    The steppers update r with -dF_r/dx and s with -dF_s/dx, i.e.
    r_t = s_x and s_t = r_x.  Unary negation already allocates new arrays,
    so the results never alias r or s and no extra .copy() is needed
    (this also lets plain scalars be passed).
    """
    return -s, -r

def ftcs_step(u, r, s, dt):
    """Advance (u, r, s) in place by one FTCS step of size dt (loop form).

    Boundary values of r and s are set first.  Then, at each interior point
    1..J-2 (J is the module-level grid size), r and s get a centred flux
    difference, and u is advanced half a step with the old s and half a
    step with the new s.
    """
    apply_boundary_conditions(r, s)
    coef = 0.5*dt/DX
    flux_r, flux_s = flux(r, s)     # fresh arrays holding pre-step fluxes

    for k in range(1, J - 1):
        u[k] += 0.5*s[k]*dt         # first half of the u update
        r[k] -= coef*(flux_r[k+1] - flux_r[k-1])
        s[k] -= coef*(flux_s[k+1] - flux_s[k-1])
        u[k] += 0.5*s[k]*dt         # second half, with the updated s

def ftcs_step_np(u, r, s, dt):
    """Advance (u, r, s) in place by one FTCS step of size dt (vectorized).

    Same update as ftcs_step, applied to every point except the first and
    last: centred flux differences for r and s, with u advanced half a
    step before and half a step after the s update.
    """
    apply_boundary_conditions(r, s)
    coef = 0.5*dt/DX
    flux_r, flux_s = flux(r, s)

    inner = slice(1, -1)            # points 1 .. len-2
    right = slice(2, None)          # their right neighbours
    left = slice(None, -2)          # their left neighbours

    u[inner] += 0.5*s[inner]*dt
    r[inner] -= coef*(flux_r[right] - flux_r[left])
    s[inner] -= coef*(flux_s[right] - flux_s[left])
    u[inner] += 0.5*s[inner]*dt

def lax_step(u, r, s, dt):
    """Advance (u, r, s) in place by one Lax step of size dt (loop form).

    Each interior value of r and s (points 1..J-2) is replaced by the
    average of its two neighbours minus a centred flux difference; u is
    advanced half a step before and half a step after the s update.
    """
    apply_boundary_conditions(r, s)
    coef = 0.5*dt/DX
    flux_r, flux_s = flux(r, s)

    # r[k-1] and s[k-1] have already been overwritten when point k is
    # processed, so their pre-update values are carried along here.
    prev_r = r[0]
    prev_s = s[0]
    for k in range(1, J - 1):
        old_r = r[k]
        old_s = s[k]
        u[k] += 0.5*s[k]*dt

        r[k] = 0.5*(prev_r + r[k+1]) - coef*(flux_r[k+1] - flux_r[k-1])
        s[k] = 0.5*(prev_s + s[k+1]) - coef*(flux_s[k+1] - flux_s[k-1])

        prev_r = old_r
        prev_s = old_s
        u[k] += 0.5*s[k]*dt

def lax_step_np(u, r, s, dt):
    """Advance (u, r, s) in place by one Lax step of size dt (vectorized).

    Interior r and s become the neighbour average minus a centred flux
    difference; u is advanced half a step before and after the s update.
    """
    apply_boundary_conditions(r, s)
    coef = 0.5*dt/DX
    flux_r, flux_s = flux(r, s)

    u[1:-1] += 0.5*s[1:-1]*dt

    # Neighbour averages are formed from the pre-update arrays before any
    # interior value is overwritten.
    r_avg = 0.5*(r[2:] + r[:-2])
    s_avg = 0.5*(s[2:] + s[:-2])
    r[1:-1] = r_avg - coef*(flux_r[2:] - flux_r[:-2])
    s[1:-1] = s_avg - coef*(flux_s[2:] - flux_s[:-2])

    u[1:-1] += 0.5*s[1:-1]*dt

def lax_wendroff_step(u, r, s, dt):
    """Advance (u, r, s) in place by one two-step Lax-Wendroff step (loop form).

    For each interior point k (1..J-2) a predictor gives r and s at the
    half-points k-1/2 and k+1/2; the corrector then updates r[k] and s[k]
    from those half-point values (F_r = -s, F_s = -r).  u is advanced half a
    step before and half a step after the s update.
    """
    apply_boundary_conditions(r, s)
    ratio = dt/DX
    flux_r, flux_s = flux(r, s)

    # Pre-update values of the point to the left (already overwritten).
    prev_r, prev_s = r[0], s[0]
    for k in range(1, J - 1):
        cur_r, cur_s = r[k], s[k]
        u[k] += 0.5*s[k]*dt

        # Predictor: values at the half-points k-1/2 (lo) and k+1/2 (hi).
        r_lo = 0.5*(r[k] + prev_r) - 0.5*ratio*(flux_r[k] - flux_r[k-1])
        r_hi = 0.5*(r[k+1] + r[k]) - 0.5*ratio*(flux_r[k+1] - flux_r[k])
        s_lo = 0.5*(s[k] + prev_s) - 0.5*ratio*(flux_s[k] - flux_s[k-1])
        s_hi = 0.5*(s[k+1] + s[k]) - 0.5*ratio*(flux_s[k+1] - flux_s[k])

        # Corrector using the half-point values.
        r[k] += ratio*(s_hi - s_lo)
        s[k] += ratio*(r_hi - r_lo)

        prev_r, prev_s = cur_r, cur_s
        u[k] += 0.5*s[k]*dt

def lax_wendroff_step_np(u, r, s, dt):
    """Advance (u, r, s) in place by one two-step Lax-Wendroff step (vectorized).

    Same predictor/corrector as lax_wendroff_step, applied to all points
    except the first and last.
    """
    apply_boundary_conditions(r, s)
    ratio = dt/DX
    flux_r, flux_s = flux(r, s)

    u[1:-1] += 0.5*s[1:-1]*dt

    def half_points(q, fq):
        """Predictor values of q at j-1/2 and j+1/2 for every interior j."""
        lo = 0.5*(q[1:-1] + q[:-2]) - 0.5*ratio*(fq[1:-1] - fq[:-2])
        hi = 0.5*(q[2:] + q[1:-1]) - 0.5*ratio*(fq[2:] - fq[1:-1])
        return lo, hi

    r_lo, r_hi = half_points(r, flux_r)
    s_lo, s_hi = half_points(s, flux_s)

    r[1:-1] += ratio*(s_hi - s_lo)
    s[1:-1] += ratio*(r_hi - r_lo)

    u[1:-1] += 0.5*s[1:-1]*dt

def display_data(x, u, t, scheme, ymin, ymax):
    """Draw one animation frame of the numerical solution at time t.

    The numerical u(x) is overlaid on the initial profile translated by t
    (right, left, or halved and shifted both ways, depending on the
    module-level DIRECTION).  The axes are cleared for every frame except
    the t == 0 one.  u is clipped to +/-1e10 before plotting, presumably so
    a blown-up solution can still be drawn.  The frame is held for 1 s at
    t == 0 and 1 ms otherwise.
    """
    scheme_names = {0: 'FTCS', 1: 'Lax'}
    title = scheme_names.get(scheme, 'Lax-Wendroff')
    first_frame = t == 0.0
    if first_frame:
        title += ': Initial conditions'
    else:
        title += ': time = '+'%.2f'%(t)
        plt.cla()

    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('u')
    plt.xlim(XMIN, XMAX)
    plt.ylim(ymin, ymax)

    shown = np.clip(u, -1.e10, 1.e10)
    plt.plot(x, shown, zorder=2)

    profile = initialize(x)[0]
    if DIRECTION == 0:
        half = 0.5*profile
        plt.plot(x+t, half, zorder=1)
        plt.plot(x-t, half, zorder=1)
    elif DIRECTION > 0:
        plt.plot(x+t, profile, zorder=1)
    else:
        plt.plot(x-t, profile, zorder=1)

    plt.pause(1.0 if first_frame else 0.001)

def _select_step(scheme, use_numpy):
    """Return the step function for scheme (0 FTCS, 1 Lax, else Lax-Wendroff)."""
    if scheme == 0:
        return ftcs_step_np if use_numpy else ftcs_step
    if scheme == 1:
        return lax_step_np if use_numpy else lax_step
    # The loop variant is lax_wendroff_step (the name lax_wendroff_step2
    # used previously does not exist and raised NameError).
    return lax_wendroff_step_np if use_numpy else lax_wendroff_step


def main(argv):
    """Parse command-line options, then run and animate the simulation.

    argv is the full argument vector (argv[0] is the program name).
    Options: -b BC, -c Courant factor, -d direction, -i initial condition,
    -j grid points, -n toggle NumPy steppers, -s scheme, -S toggle pulse
    smoothing, -t final time.  Updates the module-level settings whichIC,
    whichBC, J, DX, DIRECTION and SMOOTH.  Exits with status 2 on an
    unrecognized option.
    """
    global whichIC, whichBC, J, DX, DIRECTION, SMOOTH

    try:
        opts, _ = getopt.getopt(argv[1:], "b:c:d:i:j:ns:St:")
    except getopt.GetoptError as err:
        print(err)
        sys.exit(2)

    scheme = SCHEME
    courant = COURANT
    tmax = TMAX
    use_numpy = True
    DIRECTION = 1
    SMOOTH = True

    for o, a in opts:
        if o == "-b":
            whichBC = int(a)
        elif o == "-c":
            courant = float(a)
        elif o == "-d":
            DIRECTION = int(a)
        elif o == "-i":
            whichIC = int(a)
        elif o == "-j":
            J = int(a)
        elif o == "-n":
            use_numpy = not use_numpy
        elif o == "-s":
            scheme = int(a)
        elif o == "-S":
            SMOOTH = not SMOOTH
        elif o == "-t":
            tmax = float(a)
        else:
            print("unexpected argument", o)

    step = _select_step(scheme, use_numpy)

    print('J =', J, ' BC =', whichBC, ' IC =', whichIC, ' scheme =', scheme,
          ' courant =', courant, ' numpy =', use_numpy)

    DX = (XMAX-XMIN)/(J-1.0)
    x = np.linspace(XMIN, XMAX, J)
    u, r, s = initialize(x)

    t = 0.0
    dt = DT                         # FTCS keeps the fixed step DT
    if scheme > 0:
        dt = courant*DX
    dtplot = min(DTPLOT, dt)
    tplot = dtplot

    # ICs 1 and 2 are non-negative; every other value selects a profile
    # that swings negative (initialize() maps unknown values to the pulse).
    ymax = 1.25
    ymin = -0.25 if whichIC in (1, 2) else -1.25

    display_data(x, u, t, scheme, ymin, ymax)

    while t < tmax - 0.5*dt:
        # Shorten the last step so the run ends just past the requested
        # tmax.  This must use tmax (from -t), not the default TMAX:
        # otherwise, for tmax > TMAX, dt becomes negative after t passes TMAX.
        if t + dt > tmax:
            dt = tmax + 0.0001*dt - t
        step(u, r, s, dt)
        t += dt
        if t > tplot - 0.5*dt or t >= tmax:
            display_data(x, u, t, scheme, ymin, ymax)
            while t > tplot - 0.5*dt:
                tplot += dtplot

    plt.show()

if __name__ == '__main__':
    # Script entry point: options are parsed from sys.argv by main().
    main(sys.argv)
