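# Create two colliding Gaussian wave packets and follow their
# evolution as they cross the computational domain, integrating the
# free-particle time-dependent Schroedinger equation with the
# Crank-Nicolson scheme and u = 0 at both ends of the domain.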

import sys
import getopt
import numpy as np
import matplotlib.pyplot as plt

# Physics:

HBAR = 1.0      # reduced Planck constant (units with hbar = 1)
M = 1.0         # particle mass
DELX = 2.5      # position spread (std. dev. of |psi|^2) of each packet

XBAR1 = -20.0   # initial centre of packet 1
PBAR1 = 5.0     # mean momentum of packet 1
XBAR2 = 20.0    # initial centre of packet 2
PBAR2 = -5.0    # mean momentum of packet 2

# Range and time span:

XMIN = -40.0    # left edge of the grid
XMAX = 40.0     # right edge of the grid
TMAX = 20.0     # total simulated time

# Numerics:

J = 1001                        # number of grid points, both edges included
DX = (XMAX-XMIN)/(J-1.0)        # uniform grid spacing
DT = 0.01                       # default time step (main accepts an override)
ALPHA = HBAR*DT/(2*M*DX**2)     # Crank-Nicolson coefficient; main recomputes it
DTPLOT = 0.1                    # simulated time between plotted frames
which = 0                       # plot mode: 0 -> |psi|^2, otherwise |psi|, Re, Im

I = 1j                          # physicist's square root of -1!

def wavepacket(x, xb, pb):
    """Return a Gaussian wave packet sampled at the points x.

    The packet is centred on xb and carries the plane-wave phase
    exp(i*pb*x/HBAR), i.e. mean momentum pb.  Its width is set by the
    module constant DELX, and the prefactor makes the integral of |psi|^2
    over the whole real line equal to one.
    """
    exponent = -(x-xb)**2/(4*DELX**2) + I*pb*x/HBAR
    norm = (2*np.pi*DELX**2)**0.25
    return np.exp(exponent)/norm

def initialize(x):
    """Return the initial state: packet 1 plus packet 2 on the grid x.

    The sum is not renormalized here; the caller rescales it.
    """
    psi1 = wavepacket(x, XBAR1, PBAR1)
    psi2 = wavepacket(x, XBAR2, PBAR2)
    return psi1 + psi2

def tridiag(a, b, c, r):
    """Solve a tridiagonal linear system with the Thomas algorithm.

    Row j of the system reads

        a[j]*u[j-1] + b[j]*u[j] + c[j]*u[j+1] = r[j],

    so a[0] and c[-1] are never referenced.

    Parameters
    ----------
    a, b, c : sequences of length n
        Sub-diagonal, main-diagonal and super-diagonal coefficients.
    r : sequence of length n
        Right-hand side.

    Returns
    -------
    numpy.ndarray
        Complex solution vector u of length n (empty when n == 0).

    Raises
    ------
    numpy.linalg.LinAlgError
        If a zero pivot is encountered.  The algorithm does no pivoting, so
        this can also happen for some non-singular matrices that are not
        diagonally dominant.
    """
    n = len(r)
    u = np.zeros(n, dtype=complex)
    if n == 0:
        return u
    gam = np.zeros(n, dtype=complex)

    bet = b[0]
    # Without this check, numpy complex division by zero only warns and
    # returns nan/inf, which would silently poison the whole solution.
    if bet == 0:
        raise np.linalg.LinAlgError("tridiag: zero pivot in row 0")
    u[0] = r[0]/bet

    # Forward sweep: eliminate the sub-diagonal.
    for j in range(1, n):
        gam[j] = c[j-1]/bet
        bet = b[j] - a[j]*gam[j]
        if bet == 0:
            raise np.linalg.LinAlgError(f"tridiag: zero pivot in row {j}")
        u[j] = (r[j]-a[j]*u[j-1])/bet

    # Back substitution.
    for j in range(n-2, -1, -1):
        u[j] -= gam[j+1]*u[j+1]

    return u

def cn_step(a, b, c, r, u):
    """Advance the wave function u by one Crank-Nicolson time step.

    The explicit half of the scheme is written into the interior of r (in
    place):

        r[j] = (i*ALPHA/2)*(u[j-1] + u[j+1]) + (1 - i*ALPHA)*u[j],

    while r[0] and r[-1] are left untouched so that boundary values chosen
    by the caller persist.  The implicit half is then solved with tridiag
    using the diagonals a, b, c.  ALPHA is read from module scope.
    """
    left, centre, right = u[:-2], u[1:-1], u[2:]
    r[1:-1] = 0.5*I*ALPHA*(left + right) + (1 - I*ALPHA)*centre
    # A dedicated tridiagonal solve is fast enough for a 1-D grid.
    return tridiag(a, b, c, r)

def display_data(x, u, t):
    """Draw the wave function u(x) at time t on the current matplotlib axes.

    Every frame after t = 0 first clears the axes.  The module-level
    ``which`` selects what is drawn:
      * 0          -> the probability density |u|^2;
      * any other  -> |u|, Re(u) and Im(u), with a legend.
    A short pause lets the figure refresh inside the time loop.
    """
    if t > 0.0:
        plt.cla()
    plt.title(f'time = {t:.2f}')
    plt.xlabel('x')
    if which == 0:
        plt.ylim(0.0, 0.15)
        # Raw strings: '\p' is an invalid escape sequence in a plain literal.
        plt.ylabel(r'$|\psi|^2$')
        plt.plot(x, np.abs(u)**2)
    else:
        plt.ylim(-0.5, 0.5)
        # Three curves share this axis, so label each and add a legend.
        plt.ylabel(r'$\psi$')
        plt.plot(x, np.abs(u), label=r'$|\psi|$')
        plt.plot(x, np.real(u), label=r'$\mathrm{Re}\,\psi$')
        plt.plot(x, np.imag(u), label=r'$\mathrm{Im}\,\psi$')
        plt.legend(loc='upper right')
    plt.pause(0.001)

def prob(u, dx):
    """Integrate |u|^2 over a uniform 1-D grid with the trapezoidal rule.

    The two end samples are weighted by dx/2 and every interior sample by
    dx.  For a wave function this is its total probability (its norm).
    """
    density = np.abs(u)**2
    first, last = density[0], density[-1]
    interior = np.sum(density[1:-1])
    return dx*(0.5*first + interior + 0.5*last)
    
def main(argv=None):
    """Evolve the two colliding packets with Crank-Nicolson and animate them.

    Parameters
    ----------
    argv : list of str, optional
        Command line in ``sys.argv`` form (defaults to ``sys.argv``).  If
        ``argv[1]`` is present it replaces the default time step DT.

    Prints the integral of |psi|^2 and the position of max |psi| before and
    after the run, so the norm-conservation error can be checked.  Exits
    with an error message if the time step is not a positive finite number;
    a time step <= 0 would otherwise make the time loop never terminate.
    """
    global ALPHA        # cn_step reads the module-level ALPHA

    if argv is None:
        argv = sys.argv

    dt = DT
    if len(argv) > 1:
        try:
            dt = float(argv[1])
        except ValueError:
            sys.exit(f'invalid time step {argv[1]!r}: expected a number')
    if not (np.isfinite(dt) and dt > 0.0):
        sys.exit(f'time step must be a positive finite number, got {dt}')
    ALPHA = HBAR*dt/(2*M*DX**2)

    dtplot = max(DTPLOT, dt)    # cannot plot more often than we step
    tplot = dtplot

    x = np.linspace(XMIN, XMAX, J)
    u = initialize(x)
    t = 0.0

    # Rescale so that the trapezoidal integral of |u|^2 equals 1.
    int0 = prob(u, DX)
    u /= int0**0.5

    display_data(x, u, t)

    # Implicit Crank-Nicolson matrix: (1 + i*alpha) on the diagonal and
    # -i*alpha/2 on both off-diagonals.
    a = -0.5*I*ALPHA*np.ones(J)
    b = (1+I*ALPHA)*np.ones(J)
    c = a.copy()
    r = np.zeros(J, dtype=complex)

    # Boundary conditions: the first and last rows reduce to u = 0 at the
    # domain edges.  cn_step only rewrites r[1:-1], so r[0] and r[-1] stay 0.

    b[0] = 1.0
    c[0] = 0.0          # BC b[0]*u[0] + c[0]*u[1] = r[0]
    r[0] = 0.0
    a[-1] = 0.0
    b[-1] = 1.0         # BC b[-1]*u[-1] + a[-1]*u[-2] = r[-1]
    r[-1] = 0.0

    # Note that the vectors a, b, and c are constant for fixed time step.

    int0 = prob(u, DX)
    print('initial integral =', int0)
    print('maximum x =', x[np.argmax(np.abs(u))])

    # Half-step tolerances guard against floating-point drift in t.
    while t < TMAX-0.5*dt:
        u = cn_step(a, b, c, r, u)
        t += dt
        if t > tplot-0.5*dt:
            display_data(x, u, t)
            tplot += dtplot

    int1 = prob(u, DX)
    print('final integral =', int1, 'error =', int1/int0-1.0)
    print('maximum x =', x[np.argmax(np.abs(u))])

    plt.show()

if __name__ == "__main__":
    # Run the simulation only when executed as a script, not on import.
    main(sys.argv)
