
from pylab import *
#from visual import *
from sys import *

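# To run: python qho.py [c0 [c1 [c2 [c3]]]]
# where cn (space-separated) is the amplitude of the n-th oscillator mode;
# omitted amplitudes default to 0, and all amplitudes are rescaled so that
# sum(cn**2) == 1. With no arguments the ground state (c0 = 1) is shown.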

def Hermite(X, n):
    """Return the n-th Hermite function, unnormalized, sampled at the points X.

    The value at each point x is H_n(x) * exp(-x**2 / 2), where H_n is the
    physicists' Hermite polynomial of degree n:
        H_0 = 1, H_1 = 2x, H_2 = 4x^2 - 2, H_3 = 8x^3 - 12x.
    No normalization factor is applied.

    X -- sequence of sample points.
    n -- mode index; only 0..3 are implemented.

    Returns a list with one value per point in X.  For any n outside 0..3
    the list contains zeros.
    """
    if n < 0 or n > 3:
        # Modes that are not implemented contribute nothing.
        return [0 for _ in X]
    if n == 0:
        return [exp(-x*x/2.) for x in X]
    if n == 1:
        return [2*x*exp(-x*x/2.) for x in X]
    if n == 2:
        return [(4*x**2-2)*exp(-x*x/2.) for x in X]
    if n == 3:
        return [(8*x**3-12*x)*exp(-x*x/2.) for x in X]

# Physical constants, all set to unity (natural units: hbar = m = omega0 = 1).
omega0, hbar, m = 1.0, 1.0, 1.0

# Graph parameters: x is sampled on [-xmax, xmax) with spacing dx; the
# animation advances in steps of dt while t < tmax.
xmax, tmax = 3., 40.
dx, dt = 0.1, .1

# Set up the sample grid and the (initially zero) wavefunction arrays.
# round() guards against 2*xmax/dx landing just below an integer because of
# floating-point error, which int() alone would truncate to one point fewer.
npts = int(round(2 * xmax / dx))
X = [-xmax + k * dx for k in range(npts)]
psi2 = [0. for _ in X]
psi_r = [0. for _ in X]
psi_i = [0. for _ in X]

# Set up the expansion coefficients C[n] and the basis functions H[n].
nmax = 4

# H[n] is the n-th Hermite function scaled to unit norm.  The raw function
# H_n(x) exp(-x^2/2) has squared norm 2**n * n! * sqrt(pi) over the real line.
# Dividing that out makes the basis orthonormal.  Both the coefficient
# normalization and the energy expectation value below rely on that.
H = []
for n in range(nmax):
    basis_norm_sq = sqrt(pi)
    for k in range(1, n + 1):
        basis_norm_sq *= 2. * k
    H.append([h / sqrt(basis_norm_sq) for h in Hermite(X, n)])

# Read the coefficients from the command line; default to the ground state.
C = [0.] * nmax
coeff_args = argv[1:]
if len(coeff_args) > nmax:
    raise SystemExit("usage: python qho.py [c0 [c1 ... [c%d]]]" % (nmax - 1))
if not coeff_args:
    C[0] = 1.
else:
    for i, arg in enumerate(coeff_args):
        try:
            C[i] = float(arg)
        except ValueError:
            raise SystemExit("qho.py: coefficient %r is not a number" % arg)

# Normalize the coefficients so that sum(c_n^2) == 1 (a normalized state).
coeff_norm_sq = sum([c**2 for c in C])
if coeff_norm_sq == 0:
    raise SystemExit("qho.py: at least one coefficient must be non-zero")
norm_factor = 1. / sqrt(coeff_norm_sq)
C = [c * norm_factor for c in C]

# Set up the plot.  The lines, in order, are:
#   0: |psi|^2
#   1: Re psi
#   2: Im psi
#   3: dashed line for K(x)
#   4: marker for <x>
ion()
figure(0)
xc = 0
myplot = plot(X, psi2, X, psi_r, X, psi_i, X, [0. for x in X], 'k--',
              [xc], [0], 'bo')

xlabel("x")
ylabel("|psi(x)|^2, Re psi(x), Im psi(x), K(x)")

# Choose the y-range.  |psi| never exceeds sum_n |c_n| * max|H_n|.
psimax = sum([abs(c) * max([abs(h) for h in Hn]) for c, Hn in zip(C, H)])
# Energy expectation value in units of hbar*omega0 (orthonormal basis).
E = sum([(n + 0.5) * c**2 for n, c in enumerate(C)])
# Plot the classical kinetic energy K = E - x^2/2 instead of the potential.
K = [E - .5 * x**2 for x in X]
myplot[3].set_ydata(K)
# Re/Im psi can reach psimax while |psi|^2 can reach psimax^2; show both.
ymax = max([psimax, psimax**2])
ylim(-ymax, ymax)

# Park the <x> marker near the top of the axes.
myplot[4].set_ydata([0.9 * ymax])

#tarr, xarr = [], []
#figure(1)
#timeplot=plot([0,tmax],[-1,1])
#xlim(0,tmax)
#ylim(-1,1)

# Animate psi(x, t) = sum_j C[j] H[j](x) exp(-i (j + 1/2) omega0 t)
# and the position expectation value <x>.
t = 0.
while t < tmax:
    # The time-dependent factors depend only on the mode, not on x, so
    # evaluate C[j]*cos(w_j t) and C[j]*sin(w_j t) once per frame.
    c_cos = [c * cos((j + 0.5) * omega0 * t) for j, c in enumerate(C)]
    c_sin = [c * sin((j + 0.5) * omega0 * t) for j, c in enumerate(C)]

    tot = 0.   # integral of |psi|^2 dx, used to normalize <x>
    xc = 0.    # integral of x |psi|^2 dx
    for i, xi in enumerate(X):
        acc_r, acc_i = 0., 0.
        for j in range(len(C)):
            acc_r += c_cos[j] * H[j][i]
            # exp(-i w t) = cos(w t) - i sin(w t): the sine term is subtracted.
            acc_i -= c_sin[j] * H[j][i]
        psi_r[i] = acc_r
        psi_i[i] = acc_i
        psi2[i] = acc_r**2 + acc_i**2
        tot += psi2[i] * dx
        xc += xi * psi2[i] * dx
    if tot > 0:
        xc = xc / tot

    # Update the figures
    figure(0)
    myplot[0].set_ydata(psi2)
    myplot[1].set_ydata(psi_r)
    myplot[2].set_ydata(psi_i)
    myplot[4].set_xdata([xc])
    # pause() redraws the figure and runs the GUI event loop briefly; draw()
    # alone does not let most interactive backends repaint inside a loop.
    pause(0.001)
    #figure(1)
    #tarr.append(t)
    #xarr.append(xc)
    #timeplot[0].set_data(tarr,xarr)
    #draw()
    t += dt

# In interactive mode show() returns immediately; turn interactive mode off
# so the final frame stays on screen until the user closes the window.
ioff()
show()
