#!/usr/bin/env python

# Homework 2, problem 4.  Conventions and notation follow the C
# version, hw2.4.c, wherever possible.
#
# Run with
#
#	hw2.4.py [nbase]
#
# where the optional nbase is the number of basis functions (default 10).

import math
import sys

from numpy import zeros
from numpy.linalg import eig
from numpy.linalg import eigh

#--------------------------------------------------------------------------
# Functions below are all translated directly from the C version.
#--------------------------------------------------------------------------

# Recurrence relation:  H(n+1) = 2x H(n) - 2n H(n-1),  H(-1) = 0, H(0) = 1.
# Evaluating it by double recursion costs O(2^n) calls; walking the
# recurrence upward from H(0) gives the same values in O(n) time.

def hermite(n, x):			# return Hn(x)
  """Return the (physicists') Hermite polynomial H_n evaluated at x.

  Uses H(k+1) = 2x H(k) - 2k H(k-1) iteratively, seeded with
  H(-1) = 0 and H(0) = 1.  Returns 0 for n < 0 and 1 for n == 0.
  With integer x the result is an exact int.
  """
  if n < 0:
    return 0
  h_prev, h_curr = 0, 1			# H(-1), H(0)
  for k in range(n):
    # Same expression (and operand order) as the recursive form with n = k+1.
    h_prev, h_curr = h_curr, 2*x*h_curr - 2*k*h_prev
  return h_curr

def fac(n):				# return n!
  """Return n! as an exact int; any n that is not > 1 gives 1."""
  if not n > 1:
    return 1
  product = 1
  for factor in range(n, 1, -1):	# n, n-1, ..., 2
    product *= factor
  return product

def two(n):				# return 2^n
  """Return 2**n as an int.

  For a non-negative int n the result is computed exactly with a bit
  shift, so it is correct for any size of n (math.pow goes through a
  double and raises OverflowError for n >= 1024).  Any other input keeps
  the original behaviour, int(math.pow(2, n)) -- e.g. 0 for n < 0.
  """
  if isinstance(n, int) and n >= 0:
    return 1 << n
  return int(math.pow(2, n))

#--------------------------------------------------------------------------

def psi(n, x):				# normalized wavefunction
  """Return the n-th normalized Hermite function at x:

      psi_n(x) = H_n(x) exp(-x^2/2) pi^(-1/4) / sqrt(2^n n!)
  """
  poly = hermite(n, x)
  gaussian = math.exp(-0.5*x*x)
  pi_factor = math.pow(math.pi, -0.25)
  norm = math.sqrt(two(n)*fac(n))
  return poly * gaussian * pi_factor / norm

def psi_product(m, n, x):
  """Return psi(m, x) * psi(n, x); v_int() hands this to trap() as the integrand."""
  # Each psi() call rebuilds its polynomial and normalization from scratch
  # (very inefficient...).
  first = psi(m, x)
  second = psi(n, x)
  return first * second

def delta(n, m):
  """Kronecker delta: return 1 when n == m, otherwise 0."""
  return 1 if n == m else 0

def trap(m, n, f, xmin, xmax, npoints):	# trapezoid rule
  """Integrate f(m, n, x) over x in [xmin, xmax] with the trapezoid rule.

  m, n     -- passed through unchanged as the first two arguments of f.
  f        -- callable f(m, n, x) returning a number.
  xmin, xmax -- integration limits.
  npoints  -- number of equal sub-intervals; must be >= 1.

  Returns the trapezoid-rule estimate of the integral.
  Raises ValueError if npoints < 1.
  """
  if npoints < 1:
    raise ValueError("trap: npoints must be >= 1, got %r" % (npoints,))
  # float() avoids Python 2 integer division when the limits are ints.
  step = float(xmax - xmin) / npoints
  total = 0.5*(f(m, n, xmin) + f(m, n, xmax))	# end points, half weight
  for i in range(1, npoints):			# interior points, full weight
    total += f(m, n, xmin + i*step)
  return total*step

def v_int(m, n, npoints):		# return Vnm
  """Return Vnm: the integral of psi_n(x) psi_m(x) over [-sqrt(2), sqrt(2)],
  evaluated with trap() using npoints sub-intervals.
  """
  edge = math.sqrt(2.0)
  return trap(n, m, psi_product, -edge, edge, npoints)

def h(n, m, npoints):			# return Hnm
  """Return the matrix element Hnm used to fill the matrix H.

  The delta-function terms are the analytic part (Knm -- see the handout):
      (n + 2.5) delta(n, m)
      - 0.5 sqrt((n+1)(n+2)) delta(n, m-2)
      - 0.5 sqrt(n(n-1))     delta(n, m+2)
  and the numerical part is -2 * v_int(n, m, npoints).
  """
  diagonal = (n+2.5)*delta(n, m)
  potential = 2*v_int(n, m, npoints)
  upper = 0.5*math.sqrt((n+1.)*(n+2.))*delta(n, m-2)
  lower = 0.5*math.sqrt(n*(n-1.))*delta(n, m+2)
  return diagonal - potential - upper - lower

#--------------------------------------------------------------------------

def vprint(id, a, n):			# print a real vector
  """Print the first n entries of a under the header "id[n]:".

  The entries go on one indented line, each formatted with %9.6f.  This
  uses Python 2 print statements: the trailing commas suppress the
  newline, so all entries land on the same line, and the final print ""
  ends it.  The parameter name `id` shadows the builtin id() inside this
  function; it is kept as-is because it is part of the signature.
  """
  print "%s[%d]:" % (id, n)
  print "   ",
  for i in range(n):
    print "%9.6f" % (a[i]),
  print ""

def mprint(id, a, n):			# print a real square matrix
  """Print the leading n x n block of a under the header "id[nxn]:".

  a is indexed as a[i, j] (e.g. a 2-D NumPy array).  Each row is printed
  on its own indented line, each entry formatted with %9.6f.  The
  Python 2 trailing-comma prints keep a row on one line, and print ""
  terminates it.  `id` shadows the builtin id() but is part of the
  signature, so it is left unchanged.
  """
  print "%s[%dx%d]:" % (id, n, n)
  for i in range(n):
    print "   ",
    for j in range(n):
      print "%9.6f" % (a[i,j]),
    print ""

#--------------------------------------------------------------------------

N = 10			# default number of basis functions to use
NMAX = 22		# recursive stuff is too slow above this number!
NMIN = 50
XMIN = 5.0

nbase = N
if len(sys.argv) > 1: nbase = int(sys.argv[1])
if nbase > NMAX: nbase = NMAX

npoints = 10*nbase	# 10 points per oscillation...
if npoints < NMIN: npoints = NMIN

# Create the matrix H.

H = zeros((nbase,nbase))
for n in range(nbase):
  for m in range(nbase):
    H[n,m] = h(n, m, npoints)

if nbase <= 6: mprint("H", H, nbase)

d,v = eig(H)

# Print out the lowest two eigenvalues. */

if nbase <= 6: vprint("d", d, nbase)

e1 = 1.e100
e2 = 1.e100

for n in range(nbase):
  if (d[n] < e1):
    e2 = e1
    e1 = d[n]
  elif d[n] < e2:
    e2 = d[n]

print "%d basis functions:  E1 = %f, E2 = %f" % (nbase, e1, e2)

