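# Solve the 2-D wave equation using a simple FTCS (-s 0), Lax (-s 1), or
# Lax-Wendroff (-s 2; DEFAULT) differencing scheme.  Select initial
# conditions with the -i argument (1, 2, 3, 4; default 2).  Use the
# Courant factor (-c, default 0.9) to determine the time step in the
# Lax/Lax-Wendroff case; FTCS uses the fixed step DT.  Other options:
# -d (plot interval), -j (grid points per side), -n (use pure-Python
# loops instead of NumPy), -r (ripple mode), -t (end time).
#
# Usage:	python wave2d.py -i [1234] -s [012] -c [cour]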

import sys
import math
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Grid: J x J nodes spanning [XMIN, XMAX] x [YMIN, YMAX].
J = 101                     # grid points per side (overridable with -j)
XMIN = -10.0
XMAX = 10.0
DX = (XMAX-XMIN)/(J-1.0)
YMIN = -10.0
YMAX = 10.0
DY = (YMAX-YMIN)/(J-1.0)

X0 = 4.0                    # centre of the circular initial profiles (IC 1, 2)
Y0 = -2.0
TMAX = 12.0                 # end time (overridable with -t)
DT = 1.e-2                  # fixed step, used only by FTCS (-s 0)

whichIC = 2                 # initial condition 1-4 (overridable with -i)
SCHEME = 2                  # 0 FTCS, 1 Lax, otherwise Lax-Wendroff

DTPLOT = 0.1                # time between plots (overridable with -d)
COURANT = 0.9               # Courant factor for the Lax schemes; keep <= 1 (CFL)

# Ripple (driven-boundary) mode.  Defined here so initial() and the step
# functions that read it also work when the module is imported rather than
# run through main(), which rebinds it from the -r option.
ripple = False

def initial(x, y):
    """Return the initial state (u, rx, ry, s) at the point (x, y).

    rx and ry are the spatial derivatives du/dx and du/dy; s is du/dt.
    The profile is selected by the module-level ``whichIC``:

      1  circular gaussian exp(-r**2/4) centred on (X0, Y0), s = 0
      2  circular sine wave sin(pi*r) for r <= 2 about (X0, Y0), s = 0
      3  plane sine wave sin(pi*x) on -4 <= x <= 0, with s = -rx
      4  plane gaussian exp(-(x+5)**2/2), with s = -rx

    Any other ``whichIC`` yields all zeros.  When the module-level
    ``ripple`` flag is true, every field starts at zero.
    """
    u = rx = ry = s = 0.0

    if ripple:
        return u, rx, ry, s

    r2 = (x-X0)**2 + (y-Y0)**2
    r = math.sqrt(r2)

    if whichIC == 1:                    # 1: circular gaussian
        u = math.exp(-r2/4)
        # Gradient of exp(-r2/4) about the centre (X0, Y0).
        rx = -0.5*(x-X0)*u
        ry = -0.5*(y-Y0)*u
    elif whichIC == 2:                  # 2: circular sine wave
        if r <= 2:
            u = math.sin(math.pi*r)
            if r > 0:                   # avoid dividing by r at the centre
                rx = math.pi*(x-X0)*math.cos(math.pi*r)/r
                ry = math.pi*(y-Y0)*math.cos(math.pi*r)/r
    elif whichIC == 3:                  # 3: plane sine wave
        if abs(x+2) <= 2:
            u = math.sin(math.pi*x)
            rx = math.pi*math.cos(math.pi*x)
            s = -rx
    elif whichIC == 4:                  # 4: plane gaussian
        u = math.exp(-(x+5)**2/2)
        rx = -(x+5)*u
        s = -rx

    return u, rx, ry, s
        
def initialize(u, rx, ry, s):
    """Fill u, rx, ry and s in place by sampling initial() on the grid.

    Node (j, k) lies at x = XMIN + j*DX, y = YMIN + k*DY for 0 <= j, k < J,
    so each array must be indexable up to [J-1, J-1].
    """
    ys = [YMIN + k*DY for k in range(J)]
    for j in range(J):
        xj = XMIN + j*DX
        for k, yk in enumerate(ys):
            u[j, k], rx[j, k], ry[j, k], s[j, k] = initial(xj, yk)

def ftcs_step(u, rx, ry, s, t, dt):
    """Advance the fields by one forward-time, centred-space (FTCS) step.

    Pure-Python (explicit double loop) version.  The outer ring of every
    array is zeroed, then each interior point is updated in place from
    old-level snapshots of rx, ry and s.  u gets two half-steps: first with
    the old s, then with the freshly updated s.  ``t`` is accepted for
    interface uniformity and is unused.
    """
    last = J - 1
    for field in (u, rx, ry, s):            # fixed (zero) boundaries
        field[0, :], field[last, :], field[:, 0], field[:, last] = 4*[0.0]

    # Old time level: the loop below overwrites rx, ry and s point by point.
    rx_old, ry_old, s_old = np.copy(rx), np.copy(ry), np.copy(s)

    cx = 0.5*(dt/DX)
    cy = 0.5*(dt/DY)
    interior = range(1, last)
    for j in interior:
        for k in interior:
            u[j, k] += 0.5*s_old[j, k]*dt          # first half-step of u

            rx[j, k] += cx*(s_old[j+1, k] - s_old[j-1, k])
            ry[j, k] += cy*(s_old[j, k+1] - s_old[j, k-1])
            s[j, k] += (cx*(rx_old[j+1, k] - rx_old[j-1, k])
                        + cy*(ry_old[j, k+1] - ry_old[j, k-1]))

            u[j, k] += 0.5*s[j, k]*dt              # second half-step of u

def ftcs_step_np(u, rx, ry, s, t, dt):
    """Vectorised FTCS step: the same update as ftcs_step, on whole slices.

    The outer ring of every array is zeroed, then the interior is advanced
    in place.  u gets two half-steps of s (old s, then new s).  rx and ry
    are advanced by centred differences of the old s, and s by centred
    differences of the old rx and ry.  ``t`` is unused.
    """
    for field in (u, rx, ry, s):
        field[0, :], field[J-1, :], field[:, 0], field[:, J-1] = 4*[0.0]

    cx = 0.5*(dt/DX)
    cy = 0.5*(dt/DY)

    # rx and ry are both overwritten before the s update reads them.
    rx_old = np.copy(rx)
    ry_old = np.copy(ry)

    # Index windows: the interior and its x+-1 / y+-1 neighbours.
    mid = (slice(1, -1), slice(1, -1))
    xp = (slice(2, None), slice(1, -1))
    xm = (slice(None, -2), slice(1, -1))
    yp = (slice(1, -1), slice(2, None))
    ym = (slice(1, -1), slice(None, -2))

    u[mid] += 0.5*s[mid]*dt                 # first half-step of u (old s)

    # s is still at the old level here, so it needs no snapshot.
    rx[mid] += cx*(s[xp] - s[xm])
    ry[mid] += cy*(s[yp] - s[ym])
    s[mid] += cx*(rx_old[xp] - rx_old[xm]) + cy*(ry_old[yp] - ry_old[ym])

    u[mid] += 0.5*s[mid]*dt                 # second half-step of u (new s)

def average(u, j, k):
    """Return the mean of the four nearest neighbours of u[j, k]."""
    neighbour_sum = u[j-1, k] + u[j+1, k] + u[j, k-1] + u[j, k+1]
    return 0.25*neighbour_sum

def lax_step(u, rx, ry, s, t, dt):
    """Advance the fields one step with the Lax scheme (pure-Python loops).

    Like FTCS, except that each of rx, ry and s is rebuilt from the
    four-neighbour average of its old value (via average()) rather than
    incremented from the point's own old value.  The outer ring of every
    array is zeroed first; all arrays are modified in place.  u gets two
    half-steps (old s, then new s).  ``t`` is unused.
    """
    edge = J - 1
    for field in (u, rx, ry, s):
        field[0, :], field[edge, :], field[:, 0], field[:, edge] = 4*[0.0]

    # Old time level: the loop overwrites rx, ry and s point by point.
    rx_old, ry_old, s_old = np.copy(rx), np.copy(ry), np.copy(s)

    cx = 0.5*(dt/DX)
    cy = 0.5*(dt/DY)
    for j in range(1, edge):
        for k in range(1, edge):
            u[j, k] += 0.5*s_old[j, k]*dt          # first half-step of u

            rx[j, k] = average(rx_old, j, k) + cx*(s_old[j+1, k] - s_old[j-1, k])
            ry[j, k] = average(ry_old, j, k) + cy*(s_old[j, k+1] - s_old[j, k-1])
            s[j, k] = (average(s_old, j, k)
                       + cx*(rx_old[j+1, k] - rx_old[j-1, k])
                       + cy*(ry_old[j, k+1] - ry_old[j, k-1]))

            u[j, k] += 0.5*s[j, k]*dt              # second half-step of u

def average_np(u):
    """Return the four-neighbour mean for every interior point of u.

    The result has the shape of u[1:-1, 1:-1]; entry [j-1, k-1] is the
    average of u[j-1, k], u[j+1, k], u[j, k-1] and u[j, k+1].
    """
    lower_j, upper_j = u[:-2, 1:-1], u[2:, 1:-1]
    lower_k, upper_k = u[1:-1, :-2], u[1:-1, 2:]
    return 0.25*(lower_j + upper_j + lower_k + upper_k)

def lax_step_np(u, rx, ry, s, t, dt):
    """Vectorised Lax step: the same update as lax_step, on whole slices.

    After the outer ring of every array is zeroed, the interior of rx, ry
    and s is rebuilt as the four-neighbour average (average_np) of the old
    field plus centred differences.  u gets two half-steps of s (old, then
    new).  Arrays are modified in place; ``t`` is unused.
    """
    for field in (u, rx, ry, s):
        field[0, :], field[J-1, :], field[:, 0], field[:, J-1] = 4*[0.0]

    # rx and ry are overwritten before the s update reads them.
    rx_old = np.copy(rx)
    ry_old = np.copy(ry)

    cx = 0.5*(dt/DX)
    cy = 0.5*(dt/DY)

    mid = (slice(1, -1), slice(1, -1))
    xp = (slice(2, None), slice(1, -1))
    xm = (slice(None, -2), slice(1, -1))
    yp = (slice(1, -1), slice(2, None))
    ym = (slice(1, -1), slice(None, -2))

    u[mid] += 0.5*s[mid]*dt                 # first half-step of u (old s)

    rx[mid] = average_np(rx) + cx*(s[xp] - s[xm])
    ry[mid] = average_np(ry) + cy*(s[yp] - s[ym])
    s[mid] = (average_np(s)
              + cx*(rx_old[xp] - rx_old[xm])
              + cy*(ry_old[yp] - ry_old[ym]))

    u[mid] += 0.5*s[mid]*dt                 # second half-step of u (new s)

def lax_wendroff_step(u, rx, ry, s, t, dt):
    """Advance the fields one two-step Lax-Wendroff step (pure-Python loops).

    Loop counterpart of lax_wendroff_step_np.  Half-step values of rx, ry
    and s are formed on the x faces (j +- 1/2) and y faces (k +- 1/2) from
    the old time level and differenced to update each interior point.  The
    mixed-derivative terms enter through diagonal four-point stencils.

    The arrays are modified in place.  Their outer ring is zeroed first.  In
    ripple mode (module-level ``ripple`` true) the j = 0 row is then driven
    with u = sin(pi*t), rx = ry = 0, s = pi*cos(pi*t), as in the vectorised
    version, so the loop and NumPy variants agree in that mode too.
    """
    u[0,:],u[J-1,:],u[:,0],u[:,J-1] = 4*[0.0]      # boundary conditions
    rx[0,:],rx[J-1,:],rx[:,0],rx[:,J-1] = 4*[0.0]
    ry[0,:],ry[J-1,:],ry[:,0],ry[:,J-1] = 4*[0.0]
    s[0,:],s[J-1,:],s[:,0],s[:,J-1] = 4*[0.0]

    if ripple:                      # driven boundary row, matching the NumPy step
        u[0,:] = np.sin(np.pi*t)
        rx[0,:] = 0
        ry[0,:] = 0
        s[0,:] = np.pi*np.cos(np.pi*t)

    # Old time level: the loop overwrites rx, ry and s point by point.
    rx0 = np.copy(rx)
    ry0 = np.copy(ry)
    s0 = np.copy(s)

    alphax = dt/DX
    alphay = dt/DY
    for j in range(1,J-1):
        for k in range(1,J-1):
            u[j,k] += 0.5*s0[j,k]*dt        # first half time integration of u

            # Half-step values on the x faces (j-1/2, j+1/2) ...
            rxm12 = 0.5*(rx0[j,k]+rx0[j-1,k]) + 0.5*alphax*(s0[j,k]-s0[j-1,k])
            rxp12 = 0.5*(rx0[j+1,k]+rx0[j,k]) + 0.5*alphax*(s0[j+1,k]-s0[j,k])

            # ... and on the y faces (k-1/2, k+1/2).
            rym12 = 0.5*(ry0[j,k]+ry0[j,k-1]) + 0.5*alphay*(s0[j,k]-s0[j,k-1])
            ryp12 = 0.5*(ry0[j,k+1]+ry0[j,k]) + 0.5*alphay*(s0[j,k+1]-s0[j,k])

            smx12 = 0.5*(s0[j,k]+s0[j-1,k]) + 0.5*alphax*(rx0[j,k]-rx0[j-1,k])
            spx12 = 0.5*(s0[j+1,k]+s0[j,k]) + 0.5*alphax*(rx0[j+1,k]-rx0[j,k])

            smy12 = 0.5*(s0[j,k]+s0[j,k-1]) + 0.5*alphay*(ry0[j,k]-ry0[j,k-1])
            spy12 = 0.5*(s0[j,k+1]+s0[j,k]) + 0.5*alphay*(ry0[j,k+1]-ry0[j,k])

            # Face differences plus the diagonal (mixed-derivative) stencil.
            rx[j,k] += alphax*(spx12-smx12) \
                       + 0.125*alphax*alphay*(ry0[j+1,k+1] - ry0[j-1,k+1]
                                              - ry0[j+1,k-1] + ry0[j-1,k-1])
            ry[j,k] += alphay*(spy12-smy12) \
                       + 0.125*alphax*alphay*(rx0[j+1,k+1] - rx0[j-1,k+1]
                                              - rx0[j+1,k-1] + rx0[j-1,k-1])
            s[j,k] += alphax*(rxp12-rxm12) + alphay*(ryp12-rym12)

            u[j,k] += 0.5*s[j,k]*dt         # second half time integration of u

def lax_wendroff_step_np(u, rx, ry, s, t, dt):
    """Vectorised two-step Lax-Wendroff update of (u, rx, ry, s), in place.

    Order of operations:
      1. zero the outer ring of every array; in ripple mode, then drive the
         j = 0 row with u = sin(pi*t), rx = ry = 0, s = pi*cos(pi*t);
      2. advance u by half a step using the old s;
      3. form half-step values of rx, ry, s on the x faces (j +- 1/2) and
         y faces (k +- 1/2) from the current (old-level) arrays;
      4. difference them to update rx, ry, s, adding the diagonal
         four-point stencils for the mixed-derivative terms;
      5. advance u by the second half-step using the new s.
    ``t`` is only read in ripple mode.
    """
    for field in (u, rx, ry, s):
        field[0, :], field[J-1, :], field[:, 0], field[:, J-1] = 4*[0.0]

    if ripple:
        u[0, :] = np.sin(np.pi*t)
        rx[0, :] = 0
        ry[0, :] = 0
        s[0, :] = np.pi*np.cos(np.pi*t)

    # rx is updated before ry's cross term reads it, so keep old-level
    # copies (the ry copy is not strictly needed: rx's update runs first).
    rx_old = np.copy(rx)
    ry_old = np.copy(ry)

    ax = dt/DX
    ay = dt/DY
    cross = 0.125*ax*ay

    # Index windows: first index is x (j), second is y (k).
    c = (slice(1, -1), slice(1, -1))            # (j,   k)
    xp = (slice(2, None), slice(1, -1))         # (j+1, k)
    xm = (slice(None, -2), slice(1, -1))        # (j-1, k)
    yp = (slice(1, -1), slice(2, None))         # (j,   k+1)
    ym = (slice(1, -1), slice(None, -2))        # (j,   k-1)
    pp = (slice(2, None), slice(2, None))       # (j+1, k+1)
    mp = (slice(None, -2), slice(2, None))      # (j-1, k+1)
    pm = (slice(2, None), slice(None, -2))      # (j+1, k-1)
    mm = (slice(None, -2), slice(None, -2))     # (j-1, k-1)

    u[c] += 0.5*s[c]*dt                         # first half-step of u

    # Half-step values: all read rx, ry, s before any of them is updated.
    rx_lo = 0.5*(rx[c] + rx[xm]) + 0.5*ax*(s[c] - s[xm])
    rx_hi = 0.5*(rx[xp] + rx[c]) + 0.5*ax*(s[xp] - s[c])
    ry_lo = 0.5*(ry[c] + ry[ym]) + 0.5*ay*(s[c] - s[ym])
    ry_hi = 0.5*(ry[yp] + ry[c]) + 0.5*ay*(s[yp] - s[c])
    s_xlo = 0.5*(s[c] + s[xm]) + 0.5*ax*(rx[c] - rx[xm])
    s_xhi = 0.5*(s[xp] + s[c]) + 0.5*ax*(rx[xp] - rx[c])
    s_ylo = 0.5*(s[c] + s[ym]) + 0.5*ay*(ry[c] - ry[ym])
    s_yhi = 0.5*(s[yp] + s[c]) + 0.5*ay*(ry[yp] - ry[c])

    rx[c] += ax*(s_xhi - s_xlo) \
        + cross*(ry_old[pp] - ry_old[mp] - ry_old[pm] + ry_old[mm])
    ry[c] += ay*(s_yhi - s_ylo) \
        + cross*(rx_old[pp] - rx_old[mp] - rx_old[pm] + rx_old[mm])
    s[c] += ax*(rx_hi - rx_lo) + ay*(ry_hi - ry_lo)

    u[c] += 0.5*s[c]*dt                         # second half-step of u

def timestep(rho, mom, dx, courant, cs2=None):
    """Return a CFL-limited time step.

    dt = courant*dx / sqrt(max(cs2(rho) + (mom/rho)**2)), i.e. the grid
    spacing over the largest signal-plus-flow speed, scaled by the Courant
    factor.

    Parameters
    ----------
    rho, mom : array_like
        Presumably density and momentum (names look inherited from a fluid
        code; mom/rho is treated as a velocity).  Not used by main().
    dx : float
        Grid spacing.
    courant : float
        Fraction of the CFL limit to use.
    cs2 : callable, optional
        Maps rho to the squared signal speed.  Defaults to a unit speed,
        the propagation speed of the wave equation advanced in this file.
        Previously the body referenced an undefined global ``cs2``, so every
        call raised NameError.
    """
    if cs2 is None:
        def cs2(density):
            return np.ones_like(density, dtype=float)
    rho = np.asarray(rho, dtype=float)
    mom = np.asarray(mom, dtype=float)
    return courant*dx/math.sqrt(np.max(cs2(rho) + (mom/rho)**2))

def display_data(x, y, u, t, scheme):
    """Draw u as an image on the current matplotlib axes, then pause briefly.

    x and y are the node coordinates; only their min/max are used, for the
    axis limits and image extent.  u[j, k] is the field at (x[j], y[k]).
    The title names the scheme (0 FTCS, 1 Lax, otherwise Lax-Wendroff) and
    either "Initial conditions" (t == 0) or the current time.  The first
    call (t == 0) adds a colorbar; later calls clear the axes first so
    images do not pile up.
    """
    names = {0: 'FTCS', 1: 'Lax'}
    title = names.get(scheme, 'Lax-Wendroff')
    if t == 0.0:
        title += ': Initial conditions'
    else:
        title += ': time = %.2f' % t
        plt.cla()
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.xlim(x.min(), x.max())
    plt.ylim(y.min(), y.max())

    # imshow() treats the first index as the row, drawn from the top.
    # Flipping k and transposing puts y on the vertical axis with the
    # largest y at the top and x along the horizontal axis.
    cax = plt.imshow(np.transpose(np.flip(u, axis=1)),
                     extent=(x.min(), x.max(), y.min(), y.max()),
                     vmin=-1.0, vmax=1.0)
    if t == 0.0:
        plt.colorbar(cax)
    plt.pause(0.001)

def main(argv):
    """Parse the command line, build the grid, and run the simulation.

    argv is the full argument vector, program name first (the __main__
    guard passes sys.argv).  Recognised options:

      -a AMP     stored in the module global AMP (not used by this solver)
      -c cour    Courant factor setting dt for the Lax/Lax-Wendroff schemes
      -d dtplot  time between plots
      -g GAMMA   stored in the module global GAMMA (not used by this solver)
      -i ic      initial condition, 1-4
      -j J       grid points per side
      -n         toggle between the NumPy and pure-Python step functions
      -r         toggle ripple mode
      -s scheme  0 FTCS, 1 Lax, anything else Lax-Wendroff
      -t tmax    end time

    Module globals (J, DX, DY, whichIC, TMAX, DTPLOT, ripple, ...) are
    rebound here because the helper functions read them directly.  Prints
    the error and exits with status 2 on an unrecognised option.
    """
    global AMP, COURANT, DTPLOT, GAMMA, whichIC, J, TMAX, ripple
    global DX, DY, K, CS

    try:
        opts, _ = getopt.getopt(argv[1:], "a:c:d:g:i:j:nrs:t:")
    except getopt.GetoptError as err:
        print(err)
        print('usage:', argv[0], '-i [1234] -s [012] -c [cour]')
        sys.exit(2)

    scheme = SCHEME
    courant = COURANT
    use_numpy = True
    ripple = False

    for o, a in opts:
        if o == '-a':
            AMP = float(a)
        elif o == '-c':
            courant = float(a)
        elif o == '-d':
            DTPLOT = float(a)
        elif o == '-g':
            GAMMA = float(a)
        elif o == '-i':
            whichIC = int(a)
        elif o == '-j':
            J = int(a)
        elif o == '-n':
            use_numpy = not use_numpy
        elif o == '-r':
            ripple = not ripple
        elif o == '-s':
            scheme = int(a)
        elif o == '-t':
            TMAX = float(a)
        else:
            print("unexpected argument", o)

    # (scheme, vectorised?) -> step function; other schemes use Lax-Wendroff.
    steppers = {
        (0, True): ftcs_step_np,
        (0, False): ftcs_step,
        (1, True): lax_step_np,
        (1, False): lax_step,
    }
    fallback = lax_wendroff_step_np if use_numpy else lax_wendroff_step
    step = steppers.get((scheme, use_numpy), fallback)

    print('J =', J, ' IC =', whichIC, ' scheme =', scheme,
          ' courant =', courant, ' numpy =', use_numpy)

    DX = (XMAX-XMIN)/(J-1.0)
    DY = (YMAX-YMIN)/(J-1.0)
    x = np.linspace(XMIN, XMAX, J)
    y = np.linspace(YMIN, YMAX, J)      # nodes at YMIN + k*DY, as in initialize()
    u = np.zeros((J, J))
    rx = np.zeros((J, J))
    ry = np.zeros((J, J))
    s = np.zeros((J, J))
    initialize(u, rx, ry, s)

    t = 0.0
    dt = DT
    if scheme > 0:
        # CFL-limited step for unit wave speed on a square grid (DX == DY).
        dt = courant*DX/math.sqrt(2.)
    dtplot = DTPLOT
    if dtplot < dt:                     # cannot plot more often than every step
        dtplot = dt
    tplot = dtplot

    display_data(x, y, u, t, scheme)

    while t < TMAX-0.5*dt:
        if t+dt > TMAX:
            dt = TMAX + 0.0001*dt - t   # last step lands (just past) TMAX
        step(u, rx, ry, s, t, dt)
        t += dt
        if t > tplot-0.5*dt or t >= TMAX:
            display_data(x, y, u, t, scheme)
            while t > tplot-0.5*dt:
                tplot += dtplot

    plt.show()

# Run the simulation only when executed as a script, not when imported.
if __name__ == "__main__" :
    main(sys.argv)
