import sys
import math
import random
import numpy as np
import matplotlib.pyplot as plt

# Default run parameters (N and seed can be overridden from the command line).
N = 20      # number of path segments; the path has N + 1 points
seed = 42   # seed for the `random` module, for reproducible runs
C = 1.0     # divisor in transit(): each segment contributes n*ds/C
xmin, xmax = 0.0, 1.0   # path endpoints, used for both the x and y coordinates

def refr(x, y):
    """Return the refractive index of the medium at point (x, y).

    The medium is split at x = 0.5: the index is 1.0 for x <= 0.5 and
    1.5 otherwise. The y coordinate is accepted but does not affect the
    result.
    """
    return 1.0 if x <= 0.5 else 1.5

def transit(x, y):
    """Return the transit time along a piecewise-linear path.

    The path is the polyline through the points (x[k], y[k]). Each segment
    contributes n*ds/C, where ds is the segment length and n is refr()
    evaluated at the segment midpoint.

    Args:
        x, y: sequences (lists or 1-D arrays) of point coordinates.

    Returns:
        The summed transit time as a float (0.0 for fewer than two points).

    Raises:
        ValueError: if x and y have different lengths.
    """
    if len(x) != len(y):
        # Previously a short y raised IndexError and a long y was silently
        # truncated; reject the mismatch explicitly instead.
        raise ValueError('x and y must have the same length')
    t = 0.0
    # Walk consecutive point pairs (start -> end of each segment).
    for xa, xb, ya, yb in zip(x[:-1], x[1:], y[:-1], y[1:]):
        ni = refr(0.5*(xa + xb), 0.5*(ya + yb))
        t += ni*math.hypot(xb - xa, yb - ya)/C
    return t

# Optional command-line overrides: argv[1] = N (segments), argv[2] = seed.
if len(sys.argv) > 1:
    N = int(sys.argv[1])
if len(sys.argv) > 2:
    seed = int(sys.argv[2])

# The search moves interior points y[1]..y[N-1]; with N < 2 there are none,
# and random.randint(1, N-1) would fail with an opaque ValueError (np.linspace
# also rejects a negative point count). Fail early with a clear message.
if N < 2:
    sys.exit('N must be at least 2, got %d' % N)

random.seed(seed)

# Initialize the path as the straight line from (xmin, xmin) to (xmax, xmax).
# x stays fixed; only the y coordinates are perturbed below.
x = np.linspace(xmin, xmax, N+1)
y = np.linspace(xmin, xmax, N+1)
t = transit(x, y)
print('initial t =', t)
tmin = t

# Random hill-climbing: perturb one interior y value at a time, keep the
# change only if it strictly lowers the transit time, and stop once
# `patience` consecutive perturbations have been rejected.
deltay = 0.1     # max perturbation size (1./N would scale it with resolution)
patience = 1000  # stall limit (100*N would scale it with resolution)
trials = 0       # total perturbations attempted
count = 0        # consecutive rejections since the last improvement
while count < patience:
    trials += 1

    # y[0] and y[N] are the fixed endpoints; pick an interior point.
    i = random.randint(1, N-1)

    y0 = y[i]
    y[i] += random.uniform(-deltay, deltay)
    t = transit(x, y)

    if t < tmin:    # accept the improvement
        count = 0
        tmin = t
    else:           # reject: restore the previous value
        y[i] = y0
        count += 1

print('tmin =', tmin, 'trials =', trials, 'count =', count)
plt.plot(x, y)
plt.show()
