import sys
import math
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import matplotlib.cm as cm

K = 9e9        # Coulomb constant, approx. 1/(4*pi*eps0) in SI units
DS = 1e-2      # step length used when tracing an equipotential line
R_MAX = 50     # equipotential tracing stops once the path leaves this radius

# Calculate the potential at point (x, y).

def potential(x, y, q, xq, yq):		# numpy version

    dr = np.sqrt((x-xq)**2+(y-yq)**2)

    phi = -q[dr>0]*np.log(dr[dr>0])	# 2-D
    #phi = q[dr>0]/dr[dr>0]			# 3-D

    return K*phi.sum()

# Calculate the field at point (x, y).

def field(x, y, q, xq, yq):		# numpy version

    dx = x-xq
    dy = y-yq
    dr2 = dx**2+dy**2

    Ex = q[dr2>0]*dx[dr2>0]/dr2[dr2>0]	# 2-D
    Ey = q[dr2>0]*dy[dr2>0]/dr2[dr2>0]

    return K*Ex.sum(),K*Ey.sum()

# Semi-axes of the elliptical region the charges are confined to.
A = 1.0
B = 1.0


def ellipse(x, y):
    """Return (x/A)**2 + (y/B)**2: below 1 inside, 1 on, above 1 outside."""
    u = x/A
    v = y/B
    return u**2 + v**2


def inside(x, y):
    """True where (x, y) is in the closed ellipse (boundary included)."""
    return 1 >= ellipse(x, y)


def well_inside(x, y):
    """True where (x, y) has an ellipse value of at most 0.95."""
    return 0.95 >= ellipse(x, y)

def update(q, xq, yq, phi, i, delta):
    """Propose a move of charge i along the local field (modifies xq, yq).

    The charge is displaced along the unit field direction by a random
    distance drawn from uniform(0, delta).  If that would take it outside the
    ellipse, the step is shortened so the charge stays just inside.

    Parameters
    ----------
    q, xq, yq : ndarray
        Charges and coordinates; xq[i] and yq[i] are overwritten in place.
    phi : ndarray
        Stored potentials; only phi[i] is read (returned as phi0).
    i : int
        Index of the charge to move.
    delta : float
        Upper bound on the step length.

    Returns
    -------
    x0, y0, phi0
        Exact position and stored potential of charge i before the move, so
        the caller can undo a rejected move.
    """
    # Keep the exact old coordinates: recomputing them as (x0 + dx) - dx can
    # differ by rounding, which would shift a charge on a "rejected" move.
    x0 = xq[i]
    y0 = yq[i]
    phi0 = phi[i]

    Ex, Ey = field(x0, y0, q, xq, yq)
    E = math.sqrt(Ex**2 + Ey**2)
    if E == 0:
        # No net force, so there is no direction to move in (and dividing by
        # E would raise).  Leave the charge where it is.
        return x0, y0, phi0

    ds = random.uniform(0, delta)/E
    dx = ds*Ex
    dy = ds*Ey
    x1 = x0 + dx
    y1 = y0 + dy

    if not inside(x1, y1):
        # Pull back along the step.  ellipse() is a convex quadratic along
        # the segment, so it lies below the chord from e0 (<= 1) to e1 (> 1).
        # The chord reaches 1 at lam = (1-e0)/(e1-e0), which is therefore
        # inside; the 0.999 factor keeps a small margin from the boundary.
        e0 = ellipse(x0, y0)
        e1 = ellipse(x1, y1)
        lam = 0.999*(1 - e0)/(e1 - e0)
        x1 = x0 + lam*dx
        y1 = y0 + lam*dy

    xq[i] = x1
    yq[i] = y1
    return x0, y0, phi0

def setup_charges(nq):
    """Place nq equal charges at uniformly random points inside the ellipse.

    Points are drawn uniformly in the bounding box [-A, A] x [-B, B] and
    kept if inside(); batches of 10*nq points are drawn until at least nq
    have been kept, so the coordinate arrays always match q in length.

    Returns
    -------
    q, xq, yq : ndarray
        Charges (all 1e-6) and their coordinates, each of length nq.
    """
    # uniform() with equal bounds yields a constant array; it is kept rather
    # than np.full so the random stream (and hence seeded runs) is unchanged.
    q = np.random.uniform(1.e-6, 1.e-6, nq)
    xq = np.empty(0)
    yq = np.empty(0)
    while True:
        xx = np.random.uniform(-A, A, 10*nq)
        yy = np.random.uniform(-B, B, 10*nq)
        ii = inside(xx, yy)
        xq = np.concatenate((xq, xx[ii]))
        yq = np.concatenate((yq, yy[ii]))
        # A single batch is normally enough; more are drawn only if the
        # acceptance rate was unlucky (or the ellipse fills little of the box).
        if len(xq) >= nq:
            break
    return q, xq[:nq], yq[:nq]

if __name__ == "__main__":

    # Command line: [nq [seed]] -- number of charges and random seed.
    nq = 1000
    seed = 3

    if len(sys.argv) > 1: nq = int(sys.argv[1])
    if len(sys.argv) > 2: seed = int(sys.argv[2])

    random.seed(seed)
    np.random.seed(seed)
    q, xq, yq = setup_charges(nq)

    # Starting configuration, shown as the first animation frame.
    xinit = xq.copy()
    yinit = yq.copy()

    # Per-charge potential; phi[i] is (re)computed right before each trial
    # move of charge i, so no up-front O(nq^2) pass is needed.
    phi = np.zeros(nq)

    delta = 0.05        # maximum trial step length
    tol = 1.e-4         # a move must lower phi[i] by more than this

    ntot = 0            # trial moves so far
    nlast = 0           # value of ntot at the most recent accepted move
    nacc = 0            # accepted moves
    i = 0               # charges are visited cyclically
    # Animation frames, one [x, y] pair each, starting with the initial state.
    points = [[xinit, yinit]]

    # Relax until 20*nq consecutive trial moves have been rejected.
    while ntot-nlast < 20*nq:

        # Other charges may have moved since phi[i] was last set, so refresh
        # it at the current configuration; update() returns it as phi0.
        phi[i] = potential(xq[i], yq[i], q, xq, yq)
        x0,y0,phi0 = update(q, xq, yq, phi, i, delta)
        ntot += 1

        phi1 = potential(xq[i], yq[i], q, xq, yq)

        if phi1 >= phi0 - tol:
            # Reject: put the charge back.
            xq[i] = x0
            yq[i] = y0
            phi[i] = phi0
        else:
            phi[i] = phi1
            nlast = ntot
            nacc += 1
            if nacc%100 == 0:
                points.append([xq.copy(), yq.copy()])

        i += 1
        if i >= nq: i = 0

    print('ntot =', ntot, 'nacc =', nacc)

    def compute_equipotential(xi, yi, q, xq, yq, direction):
        """Trace the equipotential line through (xi, yi).

        Each step has length DS, perpendicular to the field (the sign of
        `direction` picks the way round).  It uses the average of the field
        at the current point and at a trial first-order endpoint.  One linear
        correction along E then pulls the point back to the starting
        potential.  Tracing stops:
          - at the first local minimum of the distance from the start that
            lies within 0.1 of it (the loop has closed);
          - when the path leaves the circle of radius R_MAX;
          - after 10000 steps.

        Returns
        -------
        xplot, yplot : list
            Coordinates of the traced points, beginning with (xi, yi).
        """
        x_start, y_start = xi, yi
        phi_start = potential(xi, yi, q, xq, yq)
        d2_before = 0.
        d2_here = 0.
        d2_close = 0.01             # closing point must be within 0.1 of start
        nsteps = 0
        rmax2 = R_MAX*R_MAX

        xplot = [xi]
        yplot = [yi]

        while xi*xi+yi*yi < rmax2 and nsteps < 10000:
            ds = DS

            # Trial (first-order) step perpendicular to the local field.
            Ex0, Ey0 = field(xi, yi, q, xq, yq)
            E0 = math.sqrt(Ex0**2 + Ey0**2)
            xt = xi + -direction*ds*Ey0/E0
            yt = yi + direction*ds*Ex0/E0

            # Second-order step: average the fields at both ends.
            Ex1, Ey1 = field(xt, yt, q, xq, yq)
            Exm = 0.5*(Ex0+Ex1)
            Eym = 0.5*(Ey0+Ey1)
            Em = math.sqrt(Exm**2 + Eym**2)
            xi += -direction*ds*Eym/Em
            yi += direction*ds*Exm/Em

            # Corrector: step of (phi - phi_start)*E/|E|^2 back onto the
            # starting equipotential (E = -grad phi).
            dphi = potential(xi, yi, q, xq, yq) - phi_start
            if dphi != 0:
                Exc, Eyc = field(xi, yi, q, xq, yq)
                E2 = Exc**2 + Eyc**2
                xi += dphi*Exc/E2
                yi += dphi*Eyc/E2

            # Stop at the first close-enough minimum in |x - x_start|; the
            # minimum point itself was appended on the previous pass.
            d2_next = (xi-x_start)**2 + (yi-y_start)**2
            if d2_here < d2_close and d2_here < d2_before and d2_next > d2_here:
                break

            d2_before, d2_here = d2_here, d2_next
            nsteps += 1

            xplot.append(xi)
            yplot.append(yi)

        return xplot,yplot

    def color(phi):
        """Map a potential value to a rainbow color.

        The value is rescaled to s = 2*(phi - phimean)/(phimax - phimin).
        |s| is clamped to [1e-3, 1] and placed on a 3-decade log scale:
          - s >= 0 maps to colormap positions in [0, 0.5], with 0 at s = 1;
          - s < 0 maps to [0.5, 1], with 1 at s = -1.
        phimin, phimax and phimean are read from the enclosing script scope.
        """
        s = 2*(phi - phimean)/(phimax-phimin)
        shift = 0.5*math.log10(max(1.e-3, min(1., abs(s))))/3
        if s >= 0:
            return cm.rainbow(-shift)
        return cm.rainbow(1 + shift)

    def onClick(event):
        """Mouse-press handler: report the potential and draw its equipotential.

        Clicks outside any axes are ignored.  The equipotential is traced in
        direction 1 only and plotted in a color chosen by color().
        """
        if event.inaxes is None:
            return
        x_click, y_click = event.xdata, event.ydata
        phi = potential(x_click, y_click, q, xq, yq)
        print('potential =', phi)
        xline, yline = compute_equipotential(x_click, y_click, q, xq, yq, 1)
        plt.plot(xline, yline, color=color(phi))

    def plot(ii):
        """FuncAnimation callback: display stored frame ii, if there is one.

        Frame indices past the end of `points` leave the display unchanged.
        Returns the tuple of artists that were updated.
        """
        if ii >= len(points):
            return p,
        frame = points[ii]
        p.set_data(frame[0], frame[1])
        plt.title(str(ii))
        return p,

    # Sample the potential on a 20x20 grid over the ellipse's bounding box;
    # color() normalizes clicked potentials against this range.  A separate
    # name is used so the per-charge array `phi` is not overwritten.
    phigrid = [potential(x, y, q, xq, yq)
               for x in np.linspace(-A, A, 20)
               for y in np.linspace(-B, B, 20)]
    phimin = min(phigrid)
    phimax = max(phigrid)
    print('phimin, phimax =', phimin, phimax)
    phimean = 0.5*(phimin+phimax)

    fig = plt.figure()
    # Leave a 10% margin around the ellipse.
    plt.xlim(-1.1*A, 1.1*A)
    plt.ylim(-1.1*B, 1.1*B)
    ax = plt.gca()
    ax.set_aspect('equal')
    p, = plt.plot(xinit, yinit, 'b.', markersize=1)
    fig.canvas.mpl_connect('button_press_event', onClick)
    # Hold a reference: Matplotlib animations stop if the Animation object
    # is garbage-collected.
    ani = anim.FuncAnimation(fig, plot, blit=False, interval=1)
    plt.show()
