#!/usr/bin/env python

import sys
import math
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Coulomb constant k = 1/(4*pi*eps0), rounded (exact SI value ~8.99e9 N*m^2/C^2).
K = 9e9
# A field line stops being traced once |potential| reaches this value,
# which normally happens close to a charge.
PHI_0 = 1e6
# Approximate upper bound on the potential change per integration step
# (the step length is capped at D_PHI/|E|).
D_PHI = 100.0
# Maximum integration step length.
DS = 1e-2
# A field line stops once it leaves the disc of this radius about the origin.
R_MAX = 50

def field(x, y, q, xq, yq):
    """Return the electric field (Ex, Ey) at (x, y) produced by point charges.

    Sums Coulomb's law, E = K*q*r_vec/|r_vec|**3, over all charges, where
    r_vec points from the charge to (x, y).

    Parameters
    ----------
    x, y : float
        Point at which the field is evaluated.
    q : sequence of float
        Charge values.
    xq, yq : sequence of float
        Charge positions; must have the same length as ``q``.

    Returns
    -------
    tuple
        ``(Ex, Ey)``.

    Raises
    ------
    ValueError
        If ``q``, ``xq`` and ``yq`` do not all have the same length.

    The field is singular at a charge position: with plain Python floats
    this raises ZeroDivisionError; numpy scalars yield inf/nan instead.
    """
    if not (len(q) == len(xq) == len(yq)):
        raise ValueError("q, xq and yq must have the same length")

    Ex = 0.
    Ey = 0.
    for qc, xc, yc in zip(q, xq, yq):
        dx = x - xc
        dy = y - yc
        dr2 = dx*dx + dy*dy
        # K*q/|r|**3: multiplying by dx, dy gives the field components.
        dr3i = K*qc/(dr2*math.sqrt(dr2))
        Ex += dx*dr3i
        Ey += dy*dr3i

    return Ex, Ey

def potential(x, y, q, xq, yq):
    """Return the electrostatic potential at (x, y) produced by point charges.

    Sums K*q/r over all charges, where r is the distance from the charge
    to (x, y).

    Parameters
    ----------
    x, y : float
        Point at which the potential is evaluated.
    q : sequence of float
        Charge values.
    xq, yq : sequence of float
        Charge positions; must have the same length as ``q``.

    Returns
    -------
    float
        The summed potential.

    Raises
    ------
    ValueError
        If ``q``, ``xq`` and ``yq`` do not all have the same length.

    The potential is singular at a charge position: with plain Python floats
    this raises ZeroDivisionError; numpy scalars yield inf/nan instead.
    """
    if not (len(q) == len(xq) == len(yq)):
        raise ValueError("q, xq and yq must have the same length")

    phi = 0.
    for qc, xc, yc in zip(q, xq, yq):
        dx = x - xc
        dy = y - yc
        dr = math.sqrt(dx*dx + dy*dy)
        phi += K*qc/dr

    return phi

def compute_field_line(xi, yi, q, xq, yq, direction, *,
                       second_order=True, max_steps=None):
    """Trace one electric field line starting at (xi, yi).

    The point is advanced in steps along the unit field direction, scaled by
    ``direction``. Tracing continues while the point lies inside the disc of
    radius R_MAX about the origin and |potential| < PHI_0. It also stops early
    at a point where the field vanishes, or after ``max_steps`` steps.

    Each step has length min(DS, D_PHI/|E|). Since |dphi| ~= |E|*ds along a
    field line, the potential changes by at most about D_PHI per step, which
    refines the path where the field is strong.

    Parameters
    ----------
    xi, yi : float
        Starting point.
    q, xq, yq : sequences of float
        Charge values and positions (as for ``field``/``potential``).
    direction : float
        Multiplies every step; normally +1 to follow E or -1 to follow -E.
    second_order : bool, keyword-only
        If True (default), correct each first-order step with the mean of the
        field at the start point and at the first-order trial point.
    max_steps : int or None, keyword-only
        Maximum number of steps; None (default) means no limit.

    Returns
    -------
    tuple of list
        ``(xplot, yplot)``: coordinates of the visited points, starting with
        the start point. If the start point already fails the loop
        conditions, only the start point is returned.
    """
    xplot = [xi]
    yplot = [yi]
    steps = 0

    while xi*xi + yi*yi < R_MAX*R_MAX \
            and abs(potential(xi, yi, q, xq, yq)) < PHI_0:

        if max_steps is not None and steps >= max_steps:
            break

        Ex, Ey = field(xi, yi, q, xq, yq)
        E = math.sqrt(Ex*Ex + Ey*Ey)
        if E == 0.:
            # Field null point: the direction is undefined (and D_PHI/E would
            # divide by zero), so the line cannot be continued.
            break

        ds = min(DS, D_PHI/E)

        # First-order step along the unit field direction.
        dx = direction*ds*Ex/E
        dy = direction*ds*Ey/E

        if second_order:
            # Second-order correction: step along the mean of the field at the
            # start point and at the first-order trial point.
            Exx, Eyy = field(xi + dx, yi + dy, q, xq, yq)
            Exmean = 0.5*(Ex + Exx)
            Eymean = 0.5*(Ey + Eyy)
            Emean = math.sqrt(Exmean*Exmean + Eymean*Eymean)
            if Emean != 0.:
                dx = direction*ds*Exmean/Emean
                dy = direction*ds*Eymean/Emean
            # Otherwise keep the first-order step instead of dividing by zero.

        xi += dx
        yi += dy

        xplot.append(xi)
        yplot.append(yi)
        steps += 1

    return xplot, yplot

def plot_field_line(xi, yi, q, xq, yq, direction, c):
    """Trace the field line that starts at (xi, yi) and draw it with colour c.

    The tracing is delegated to ``compute_field_line`` (same arguments). The
    line is drawn at zorder 1, beneath artists that use a higher zorder.
    """
    line_x, line_y = compute_field_line(xi, yi, q, xq, yq, direction)
    plt.plot(line_x, line_y, zorder=1, color=c)

import numpy as np

def setup_charges():
    """Return the fixed four-charge configuration used by the script.

    Returns
    -------
    q : numpy.ndarray of float64
        Charge values 2e-6, -1e-6, -2e-6, -1e-6 (presumably coulombs, given
        that K is the SI Coulomb constant).
    xq, yq : numpy.ndarray of float64
        Charge positions (0, 0), (-1, 1), (1, 1), (0, -0.73).
    c : list of str
        One matplotlib colour name per charge.
    """
    # numpy arrays (not lists) so callers can select charges with boolean
    # masks such as xq[(q>=0)].
    q = np.array([2.e-6, -1.e-6, -2.e-6, -1.e-6])
    xq = np.array([0., -1., 1., 0.])
    yq = np.array([0., 1., 1., -0.73])
    c = ['xkcd:pink', 'xkcd:mint', 'xkcd:light blue', 'xkcd:cyan']
    return q, xq, yq, c

if __name__ == "__main__":
    # Usage: script [nth [lim]]
    #   nth  number of field lines started around each charge (default 5)
    #   lim  half-width of the square plot region (default 5)
    nth = int(sys.argv[1]) if len(sys.argv) > 1 else 5
    lim = float(sys.argv[2]) if len(sys.argv) > 2 else 5

    _, ax = plt.subplots()
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('field lines')
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)
    ax.set_aspect('equal')

    q, xq, yq, c = setup_charges()
    # Non-negative charges in red, negative in blue, drawn above the lines.
    plt.scatter(xq[(q>=0)], yq[(q>=0)], c='r', s=10, zorder=2)
    plt.scatter(xq[(q<0)], yq[(q<0)], c='b', s=10, zorder=2)

    for qc, xc, yc, color in zip(q, xq, yq, c):
        if qc == 0:
            # A neutral charge sources no field lines; dr would be 0 and the
            # trace would start exactly on a singular point.
            continue
        # Start radius where this charge alone gives |phi| = PHI_0/1.1, i.e.
        # just outside the |phi| >= PHI_0 stopping threshold.
        dr = 1.1*K*abs(qc)/PHI_0
        # Follow E away from positive charges and -E away from negative ones,
        # so every line leaves the charge it starts at.
        direction = math.copysign(1, qc)
        for ith in range(nth):
            theta = 2*ith*math.pi/nth
            plot_field_line(xc + dr*math.cos(theta), yc + dr*math.sin(theta),
                            q, xq, yq, direction, color)

    plt.show()
