from __future__ import print_function
from dtw import dtw
import numpy as np
import argparse

# Relative output directory used by write_output(). The trailing slash is
# required: write_output() prepends it to the file name by plain string
# concatenation, not os.path.join.
result_folder = 'results/csv/'

def shorten(a, ratio=1.0/3):
    """Trim ``ratio`` of the length of ``a`` from each end and return the middle.

    Prints the original and trimmed lengths as diagnostic output.  If the
    trimmed sequence would have fewer than 5 elements, ``a`` is returned
    unchanged instead.

    The default is written ``1.0/3`` rather than ``1/3``: this file only
    imports ``print_function`` from ``__future__``, so under Python 2 ``1/3``
    is integer division (== 0) and nothing would ever be trimmed.

    Args:
        a: any sliceable sequence (list, numpy array, ...).
        ratio: fraction of ``len(a)`` removed from *each* end.

    Returns:
        ``a[x:len(a)-x]`` with ``x = int(len(a) * ratio)``, or ``a`` itself
        when that slice has fewer than 5 elements.
    """
    x = int(len(a) * ratio)
    b = a[x:len(a) - x]
    print(len(a), len(b))
    if len(b) < 5:
        # Too little left to align against; fall back to the full sequence.
        return a
    return b

def euclid(a, b):
    """Return the absolute difference ``|a - b|`` (1-D Euclidean distance)."""
    delta = a - b
    return abs(delta)

def sqr_dst(a, b):
    """Return the squared difference ``(a - b) ** 2``."""
    delta = a - b
    return pow(delta, 2)

def seq_delta(x):
    """Return the list of first differences ``[x[1]-x[0], x[2]-x[1], ...]``.

    Sequences with fewer than two elements yield an empty list.
    """
    following = x[1:]
    return [nxt - cur for cur, nxt in zip(x, following)]

def rescale_meanmax(x):
    """Center ``x`` on its mean and scale so the largest absolute deviation is 1.

    Args:
        x: 1-D sequence of numbers (list or numpy array); converted to float.

    Returns:
        numpy float array ``(x - mean(x)) / max(|x - mean(x)|)``.  For a
        constant series the divisor would be 0 (giving NaN), so the all-zero
        centered series is returned instead.
    """
    arr = np.asarray(x, dtype=float)
    centered = arr - arr.mean()
    # Same as max(abs(min(centered)), abs(max(centered))), computed in numpy.
    peak = np.max(np.abs(centered))
    if peak == 0:
        return centered
    return centered / peak

def rescale_minmax(x):
    """Linearly rescale ``x`` into the range [0, 1] (min -> 0, max -> 1).

    Args:
        x: 1-D sequence of numbers (list or numpy array); converted to float.

    Returns:
        numpy float array ``(x - min(x)) / (max(x) - min(x))``.  For a
        constant series the range is 0 (giving NaN), so the all-zero shifted
        series is returned instead.
    """
    arr = np.asarray(x, dtype=float)
    shifted = arr - arr.min()
    span = shifted.max()
    if span == 0:
        return shifted
    return shifted / span

def rescale_meansd(x):
    """Standardize ``x`` to zero mean and unit (population) standard deviation.

    Args:
        x: 1-D sequence of numbers (list or numpy array); converted to float.

    Returns:
        numpy float array ``(x - mean(x)) / std(x)`` (``np.std`` default,
        ddof=0).  For a constant series the standard deviation is 0 (giving
        NaN), so the all-zero centered series is returned instead.
    """
    arr = np.asarray(x, dtype=float)
    centered = arr - arr.mean()
    sd = np.std(centered)
    if sd == 0:
        return centered
    return centered / sd

def compute_dtw(a, b, loc=False, nonlin=1, norm=None, dist_f=euclid):
    """Return only the DTW distance between sequences ``a`` and ``b``.

    Thin wrapper around the imported ``dtw`` function: ``dist_f`` is passed
    positionally as the point-distance function, and ``loc``/``nonlin``/
    ``norm`` are forwarded as the ``local``/``nonlinear``/``normalize``
    keywords.  NOTE(review): those keywords are not part of every published
    ``dtw`` package; this presumably relies on a local/modified version —
    confirm.

    ``dtw`` must return exactly four values (named distance, cost,
    accumulated cost and path in the original code); only the first is kept.
    """
    distance, _cost, _acc, _path = dtw(
        a, b, dist_f,
        local=loc, nonlinear=nonlin, normalize=norm,
    )
    return distance

def load_fns(ifile):
    """Read ``ifile`` and return each line split on whitespace.

    Each inner list holds the whitespace-separated tokens of one line (the
    script expects two file names per line).  The file is closed on return.
    """
    with open(ifile) as f:
        return [line.split() for line in f]

def load_means(fn):
    """Read all whitespace-separated numbers in file ``fn`` into a float array.

    Line structure is ignored; every token must parse as ``float`` (otherwise
    ``ValueError`` is raised).  The file is closed on return.
    """
    with open(fn) as f:
        return np.array([float(tok) for tok in f.read().split()])

def write_output(fn, vals):
    """Write each value of ``vals`` on its own line to ``result_folder + fn``.

    The file is overwritten if it exists.  ``result_folder`` must already exist
    and end with a path separator, since it is joined by plain concatenation.
    """
    with open(result_folder + fn, 'w') as f:
        # Plain loop: writing is a side effect, so no throwaway list is built.
        for x in vals:
            print(x, file=f)

# Test table. Entry k describes tests 2k+1 (on good_fits.txt) and 2k+2 (on
# bad_fits.txt); results go to '<good|bad>_<suffix>.csv' in result_folder.
# Columns: (rescale function or None, point distance, DTW normalization,
#           gloloc, output suffix)
# gloloc=True means glob-loc alignment: A is the reference, B is the sample;
# B is trimmed with shorten() and DTW is run with loc=True.
EXPERIMENTS = [
    ######  EUCLID DIST  ######
    (None,            euclid,  None,   False, 'euclid'),                     # 1, 2
    (rescale_minmax,  euclid,  None,   False, 'minmax_euclid'),              # 3, 4
    (rescale_meanmax, euclid,  None,   False, 'meanmax_euclid'),             # 5, 6
    (rescale_meansd,  euclid,  None,   False, 'meansd_euclid'),              # 7, 8
    ######  SQUARE DIST  #######
    (None,            sqr_dst, None,   False, 'sqrdts'),                     # 9, 10
    (rescale_minmax,  sqr_dst, None,   False, 'minmax_sqrdts'),              # 11, 12
    (rescale_meanmax, sqr_dst, None,   False, 'meanmax_sqrdts'),             # 13, 14
    (rescale_meansd,  sqr_dst, None,   False, 'meansd_sqrdts'),              # 15, 16
    # postprocessing tests (DTW normalization)
    (rescale_meansd,  euclid,  'min',  False, 'meansd_euclid_min'),          # 17, 18
    (rescale_meansd,  euclid,  'max',  False, 'meansd_euclid_max'),          # 19, 20
    (rescale_meansd,  euclid,  'path', False, 'meansd_euclid_path'),         # 21, 22
    (rescale_meansd,  euclid,  'sum',  False, 'meansd_euclid_sum'),          # 23, 24
    # glob-loc alignment (A is ref, B is sample)
    (rescale_meansd,  euclid,  'path', True,  'meansd_euclid_gloloc_path'),  # 25, 26
    (rescale_meansd,  euclid,  'max',  True,  'meansd_euclid_gloloc_max'),   # 27, 28
    (rescale_meansd,  euclid,  'min',  True,  'meansd_euclid_gloloc_min'),   # 29, 30
    (rescale_meansd,  euclid,  None,   True,  'meansd_euclid_gloloc'),       # 31, 32
]


def run_test(test_id):
    """Run numbered test ``test_id`` (1..2*len(EXPERIMENTS)) and write its CSV.

    Odd ids use 'good_fits.txt', even ids 'bad_fits.txt'. Each line of that
    file names two series files (A, B); one DTW distance per line is written.
    """
    rescale, dist_f, norm, gloloc, suffix = EXPERIMENTS[(test_id - 1) // 2]
    subset = 'good' if test_id % 2 == 1 else 'bad'

    fns = load_fns(subset + '_fits.txt')
    means = [[load_means(fn) for fn in pair] for pair in fns]
    if rescale is not None:
        means = [[rescale(series) for series in pair] for pair in means]
    if gloloc:
        # Only the sample (B) is trimmed; the reference (A) is kept whole.
        means = [(a, shorten(b)) for a, b in means]

    result = [compute_dtw(a, b, loc=gloloc, norm=norm, dist_f=dist_f)
              for a, b in means]
    write_output(subset + '_' + suffix + '.csv', result)


def main():
    """Parse the test number from the command line and run that test."""
    n_tests = 2 * len(EXPERIMENTS)
    parser = argparse.ArgumentParser()
    # `choices` rejects unknown test numbers instead of silently doing nothing.
    parser.add_argument('i', type=int, choices=range(1, n_tests + 1),
                        metavar='i',
                        help='Test selection (1-%d)' % n_tests)
    args = parser.parse_args()
    run_test(args.i)


if __name__ == '__main__':
    main()
