#!/usr/bin/env python

# Homework 2, problem 3.  Conventions and notation follow the C
# version, hw2.3.c, wherever possible.
#
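# Run with
#
#	python hw2.2.py [m ...] < hw2.2.dat
#
# where each optional argument m is a number of fit parameters to try
# (default: a single fit with m = 7).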

import sys

from numpy import zeros
from numpy.linalg import svd
from pylab import plot, legend, show

def get_funcs(x, f, m):		# return [1, x, x^2, ..., x^{m-1}]

    # Fill the first m entries of f, in place, with successive powers
    # of x, and return f so that the result can also be used directly
    # in an expression.  f must be indexable and hold at least m
    # entries.  Nothing is written when m <= 0.

    if m <= 0:
        return f

    f[0] = 1
    for i in range(1, m):
        f[i] = x*f[i-1]		# x^i = x * x^(i-1)

    return f

def svbksb(u, w, vt, n, m, b):

    # SVD back-substitution, after the Numerical Recipes routine of
    # the same name, except that the transpose vt is passed in place
    # of v and the roles of m and n are exchanged.  Implicit
    # dimensions are: u(n,m), w(m), vt(m,m), b(n).  Returns a new
    # vector of length m.

    # Stage 1: the parenthetical term in NR Eq. 15.4.17 (i-->j), i.e.
    # (column j of u).b / w[j].  Columns whose singular value is
    # exactly zero keep the zero that zeros() put there.

    scaled = zeros(m)
    for col in range(m):
        if w[col] == 0:
            continue
        acc = 0.
        for row in range(n):
            acc += u[row,col]*b[row]
        scaled[col] = acc/w[col]

    # Stage 2: finish the sum by multiplying by v; recall that
    # v[k,j] = vt[j,k].

    coeffs = zeros(m)
    for k in range(m):
        acc = 0.
        for j, term in enumerate(scaled):
            acc += term*vt[j,k]
        coeffs[k] = acc

    return coeffs

def svdfit(x, y, sig, n, m):

    # Weighted linear least-squares fit, by singular value
    # decomposition, of the m basis functions supplied by get_funcs
    # to the n data points (x[i], y[i]) with uncertainties sig[i].
    #
    # Data x, y, sig indices run from 0 to n-1
    # Fitting coefficients c indices run from 0 to m-1
    # For fitting here, expect n >> m

    # Return the fitting coefficients as a vector of length m.

    # This version of svdfit is based on the Numerical Recipes version
    # (Sec. 15.4).  Note that the NR naming conventions are different
    # at different places in the book.  The discussions here and in
    # Chapter 15 refer to data arrays of dimension nxm, where we
    # expect n (data) >> m (fit).  We use the index i to run over n
    # and j for m.  Note that, in Chapter 2 (and in class), the
    # discussion of SVD itself was cast in terms of mxn matrices!

    A = zeros((n,m))			# rows = data, columns = fit
    b = zeros(n)
    f = zeros(m)

    # Create the design matrix A; equation to solve is A.c = b.  Each
    # row (and the matching entry of b) is weighted by 1/sig[i].

    for i in range(n):
        get_funcs(x[i], f, m)		# f is filled in place
        A[i,:] = f/sig[i]
        b[i] = y[i]/sig[i]

    # SV Decompose: A = u.w.vt.  Note that there are some differences
    # between the NR and numpy implementations:
    #
    # 1. NR function svdcmp returns v, while the numpy function
    #    returns vt (its transpose).
    #
    # 2. NR defines u as an nxm "column orthogonal" matrix and w as an
    #    mxm diagonal matrix.  numpy returns w as a vector of length m
    #    holding the diagonal elements and, by default
    #    (full_matrices=True), u as an nxn orthogonal matrix: its
    #    leftmost m columns are the same as in the NR version, and
    #    the matrix is expanded on the right by n-m columns defined
    #    to make the square matrix orthogonal.  Those extra columns
    #    are never read by svbksb, so request the reduced form
    #    (full_matrices=False), which returns u as nxm, as in NR, and
    #    avoids building an nxn matrix.

    u,w,vt = svd(A, full_matrices=False)

    # Throw away "nearly singular" elements: any singular value
    # smaller than TOL times the largest one is set to zero, which
    # makes svbksb drop the corresponding term from the solution.

    TOL = 1.e-5

    thresh = TOL*max(w) if m > 0 else 0.0
    w[w < thresh] = 0.0

    c = svbksb(u, w, vt, n, m, b)
    return c

def poly(x, m, c):		# return c[0] + c[1] x + ... + c[m-1] x^{m-1}

    # Horner's rule: begin with the highest-order coefficient c[m-1],
    # then repeatedly multiply by x and add the next lower coefficient
    # until c[0] has been folded in.  Only the first m entries of c
    # are used.

    acc = c[m-1]
    for k in range(m-2, -1, -1):
        acc = acc*x + c[k]
    return acc

#---------------------------------------------------------------------
#
# Read the data.

# Each non-blank input line holds three whitespace-separated numbers:
# x, y, and the uncertainty sigma on y.  The number of data points n
# is taken from the input itself rather than fixed in advance, so
# that shorter or longer data files (or files containing blank
# lines) are handled correctly.

xs = []
ys = []
sigs = []
for line in sys.stdin:
    l = line.split()
    if not l:
        continue			# ignore blank lines
    if len(l) < 3:
        raise ValueError("expected 'x y sigma' on each line, got %r" % line)
    xs.append(float(l[0]))
    ys.append(float(l[1]))
    sigs.append(float(l[2]))		# any extra columns are ignored

n = len(xs)
if n == 0:
    sys.exit("no data points read from standard input")

x = zeros(n)			# x, y, and sig are arrays of length n
y = zeros(n)
sig = zeros(n)
x[:] = xs
y[:] = ys
sig[:] = sigs

# The numbers of fit parameters to try are given on the command line;
# with no arguments, do a single 7-parameter fit.

mlist = [int(arg) for arg in sys.argv[1:]]
if not mlist:
    mlist = [7]

# The print calls below take a single pre-formatted string so that
# they produce the same output under both Python 2 and Python 3.

for m in mlist:

    # Fit the data.

    c = svdfit(x, y, sig, n, m)	# c is an array of length m

    # Print the results.

    print("N = %d   m = %d" % (n, m))
    for j in range(m):
        print("a%d = %f" % (j, c[j]))

    # Compute chi^2.

    yfit = zeros(n)
    chi2 = 0
    for i in range(n):
        yfit[i] = poly(x[i], m, c)
        error = (y[i] - yfit[i]) / sig[i]
        chi2 += error*error

    print("chi^2 = %s \n" % chi2)

    # Plot the results:

    plot(x, y, 'bo', x, yfit, 'r.')
    legend(('data', str(m)+'-parameter fit'), loc='best')

print('Kill the graphics window to exit the program')
show()


