
/* Homework 2, problem 4.  Compile with
 *
 *	make hw2.4
 *
 * Run with
 *
 *	hw2.4  N
 *
 * where N is the number of basis functions to use (default 10, at most 22).
 * See the on-line notes on this problem.
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "nrutil.h"

#define real double

/*--------------------------------------------------------------------------*/

/* NOTE...  Recursive C is compact, but *very* inefficient!
   DO NOT use these functions in general (or for large N)!! */

/* Recurrence relation:  H(n+1) = 2x H(n) - 2n H(n-1) */

/* Return the (physicists') Hermite polynomial H_n(x), built from the
 * three-term recurrence H(k) = 2x H(k-1) - 2(k-1) H(k-2), starting from
 * H(-1) = 0 and H(0) = 1.  Returns 0 for n < 0.
 *
 * Evaluated iteratively in O(n) operations.  A direct doubly-recursive
 * form recomputes the same lower-order polynomials an exponential number
 * of times.  Each loop step performs exactly the arithmetic of one
 * recursion level, so the results are the same.
 */
real hermite(int n, real x)		/* return Hn(x) */
{
    real h_prev = 0.0;		/* H(k-2); starts as H(-1) = 0 */
    real h_curr = 1.0;		/* H(k-1); starts as H(0) = 1 */
    int k;

    if (n < 0)
	return 0;

    for (k = 1; k <= n; k++) {
	real h_next = 2*x*h_curr - 2*(k-1)*h_prev;
	h_prev = h_curr;
	h_curr = h_next;
    }
    return h_curr;
}

/* Factorial n! as a real, formed as the ascending product 1*2*...*n.
 * Any n <= 1 yields 1.0. */
real fac(int n)
{
    real product = 1.0;
    int k;

    for (k = 2; k <= n; k++)
	product *= k;
    return product;
}

/* Power of two 2^n as a real, by repeated doubling.
 * Any n <= 0 yields 1.0. */
real two(int n)
{
    real power = 1.0;
    int k;

    for (k = 0; k < n; k++)
	power *= 2.0;
    return power;
}

/*--------------------------------------------------------------------------*/

/* M_PI is a POSIX/XSI extension, not ISO C; strict modes such as
   -std=c11 leave it undefined, so supply it when <math.h> did not. */
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

/* Normalized Hermite function
 *     psi_n(x) = H_n(x) exp(-x^2/2) * pi^(-1/4) / sqrt(2^n n!)
 */
real psi(int n, real x)			/* normalized wavefunction */
{
    return hermite(n, x) * exp(-0.5*x*x)
		* pow(M_PI, -0.25) / sqrt(two(n)*fac(n));
}

/* Integrand psi_m(x) * psi_n(x) for the trapezoid rule.  Both
   wavefunctions are re-evaluated from scratch on every call. */
real psi_product(int m, int n, real x)
{
    real left = psi(m, x);
    real right = psi(n, x);

    return left * right;
}

/* Kronecker delta: 1 if the indices are equal, 0 otherwise.
   (The == operator already yields an int 1 or 0.) */
int delta(int n, int m)
{
    return n == m;
}

/* Composite trapezoid rule: integrate f(m, n, x) over x in [xmin, xmax]
 * with npoints equal subintervals (npoints + 1 evaluations of f).
 *
 * The indices m and n are passed straight through to f.  Abscissae are
 * computed as xmin + i*h, not by repeated addition, so rounding error
 * does not accumulate across the grid.  An npoints below 1 would divide
 * by zero or produce a negative step, so it is treated as a single
 * trapezoid.
 */
real trap(int m, int n, real (*f) (int, int, real),
	  real xmin, real xmax, int npoints)		/* trapezoid rule */
{
    real h;
    real sum;
    int i;

    if (npoints < 1)
	npoints = 1;

    h = (xmax - xmin) / npoints;
    sum = 0.5*(f(m, n, xmin) + f(m, n, xmax));	/* endpoints, half weight */

    for (i = 1; i < npoints; i++)
	sum += f(m, n, xmin+i*h);
    sum *= h;

    return sum;
}

/* Return Vnm: the integral of psi_m(x) * psi_n(x) over
 * [-sqrt(2), sqrt(2)], computed with npoints trapezoid subintervals.
 * The integrand is symmetric in m and n, so index order does not matter.
 */
real v_int(int m, int n, int npoints)
{
    real edge = sqrt(2.0);

    return trap(m, n, psi_product, -edge, edge, npoints);
}

/* Return the matrix element Hnm for 0-based basis indices n and m.
 *
 * The delta-function terms are the analytic part (Knm is analytic, see
 * the handout).  The numerical part is 2 * v_int(n, m), which is
 * subtracted.  The terms are combined in this order:
 *     (n + 2.5) d(n,m) - 2 Vnm - 0.5 sqrt((n+1)(n+2)) d(n,m-2)
 *                             - 0.5 sqrt(n(n-1))     d(n,m+2)
 * npoints is forwarded to the quadrature in v_int.
 */
real h(int n, int m, int npoints)
{
    real diagonal  = (n+2.5)*delta(n, m);
    real potential = 2*v_int(n, m, npoints);
    real plus_two  = 0.5*sqrt((n+1.)*(n+2.))*delta(n, m-2);	/* m == n+2 */
    real minus_two = 0.5*sqrt(n*(n-1.))*delta(n, m+2);		/* m == n-2 */

    return diagonal - potential - plus_two - minus_two;
}

/*------------------------------------------------------------------------*/

/* Print the 1-based vector a[1..n] to stderr with four decimals,
 * preceded by a header line "id[n]:".
 *
 * Neither argument is modified, so both are const-qualified.  This lets
 * callers pass string literals such as "d" without discarding const.
 */
void vprint(const char *id, const real *a, int n)	/* print a real vector */
{
    int i;
    fprintf(stderr, "%s[%d]:\n", id, n);
    fprintf(stderr, "    ");
    for (i = 1; i <= n; i++)
	fprintf(stderr, "%.4f ", a[i]);
    fprintf(stderr, "\n");
}

/* Print the 1-based square matrix a[1..n][1..n] to stderr, one row per
 * line in "%6.3f" format, preceded by a header line "id[nxn]:".
 *
 * id is const-qualified because callers pass string literals.  a stays
 * "real **": C does not implicitly convert "double **" to
 * "const double *const *".
 */
void mprint(const char *id, real **a, int n)	/* print a real square matrix */
{
    int i, j;
    fprintf(stderr, "%s[%dx%d]:\n", id, n, n);
    for (i = 1; i <= n; i++) {
	fprintf(stderr, "    ");
	for (j = 1; j <= n; j++)
	    fprintf(stderr, "%6.3f ", a[i][j]);	/* (note specific format) */
	fprintf(stderr, "\n");
    }
}

/* Tuning constants.  Kept as macros (not enums) in case they are used in
   preprocessor conditionals. */
#define N	10	/* default number of basis functions to use */
#define NMAX	22	/* cap on the basis size: larger requests are clamped (run-time limit of the recursive helpers) */

#define NMIN	50	/* minimum number of trapezoid subintervals per Vnm integral */
#define XMIN	5.0	/* NOTE(review): not referenced in the code shown here -- confirm it is still needed */

/* Build the nbase x nbase Hamiltonian matrix H in the Hermite-function
 * basis, reduce it with tred2, diagonalize with tqli, and print the two
 * lowest eigenvalues to stderr.
 *
 * Usage:  hw2.4 [N]
 *   N defaults to the N macro and is clamped to NMAX.
 * Exit status is 0 on success and 1 if N is not a positive integer.
 */
int main(int argc, char** argv)
{
    int n, m, nbase = N;
    int npoints;

    real **H, *d, *e;
    real e1 = 1.e100, e2 = 1.e100;	/* "infinity" sentinels for the min search */

    if (argc > 1) {
	char *end;
	long requested = strtol(argv[1], &end, 10);

	/* atoi() silently maps junk (or "0") to 0, which would build an
	   empty basis.  Reject anything that is not a positive integer. */
	if (end == argv[1] || *end != '\0' || requested < 1) {
	    fprintf(stderr, "usage: %s [N]  (N = number of basis functions, >= 1)\n",
		    argv[0]);
	    return 1;
	}
	nbase = (requested > NMAX) ? NMAX : (int) requested;
    }
    if (nbase > NMAX) nbase = NMAX;	/* also caps the compiled-in default */

    H = dmatrix(1,nbase,1,nbase);
    d = dvector(1,nbase);
    e = dvector(1,nbase);

    npoints = 10*nbase;			/* 10 points per oscillation... */
    if (npoints < NMIN) npoints = NMIN;

    /* Create the Hamiltonian matrix H.  h(n, m) equals h(m, n), because
       delta() and the psi_m * psi_n integrand are symmetric in their
       indices.  So evaluate only the upper triangle and mirror it, which
       halves the number of quadratures. */

    for (n = 1; n <= nbase; n++)
	for (m = n; m <= nbase; m++)
	    H[n][m] = H[m][n] = h(n-1, m-1, npoints);

    if (nbase <= 6) mprint("H", H, nbase);

    /* Reduce to tridiagonal form. */

    tred2(H, nbase, d, e);

    /* Solve the tridiagonal system.  tqli's last argument accumulates the
       eigenvector rotations.  In Numerical Recipes it must be either the
       matrix just produced by tred2 or the identity.  Passing H (rather
       than a freshly allocated, uninitialized matrix) keeps tqli from
       reading garbage.  The eigenvectors themselves are not used here. */

    tqli(d, e, nbase, H);

    /* Print out the lowest two eigenvalues.  The search below does not
       assume d[] is sorted.  With nbase == 1, E2 keeps its 1e100
       sentinel. */

    if (nbase <= 6) vprint("d", d, nbase);

    for (n = 1; n <= nbase; n++) {
	if (d[n] < e1) {
	    e2 = e1;
	    e1 = d[n];
	} else if (d[n] < e2)
	    e2 = d[n];
    }
    fprintf(stderr, "%d basis functions:  E1 = %f, E2 = %f\n", nbase, e1, e2);
    return 0;
}

