#! /usr/bin/env python

"""\
%(prog)s [-o outfile] <datafile1> <datafile2>

Compare analysis objects between two YODA-readable data files.
"""

from __future__ import print_function
from math import sqrt

import yoda, sys, argparse

## Command-line interface.
## Every optional argument is described once in the table below as
## (option strings, add_argument keyword arguments). The options are registered
## in table order, which is also the order argparse lists them in --help.
## NOTE(review): OUTPUT_FILE is parsed but never read in the visible part of
## this script -- confirm whether -o is meant to be honoured.
_OPTION_SPECS = [
    (("-o", "--output"),
     dict(default="-", dest="OUTPUT_FILE",
          help="write output to the given file (default: stdout)")),
    (("-t", "--tol"),
     dict(type=float, default=1e-5, dest="TOL",
          help="relative tolerance of numerical difference permitted before complaining (default: %(default)s)")),
    (("-l", "--list"),
     dict(action="store_true", default=False, dest="LIST",
          help="only print paths of mismatching objects, skip diff details (default: %(default)s)")),
    (("-a", "--annotations"),
     dict(action="store_true", default=False, dest="ANNOTATIONS",
          help="also compare annotations (not done by default)")),
    (("-e", "--errors"),
     dict(action="store_true", default=False, dest="BREAKDOWN",
          help="also compare error breakdown if applicable (not done by default)")),
    (("-d", "--dev-histo"),
     dict(default=None, dest="DEV_HISTO_FILE",
          help="write deviation histo to the given file (default: %(default)s)")),
    (("--normalize",),
     dict(action="store_true", default=False, dest="NORMALIZE",
          help="normalize objects before comparing (if possible)")),
    (("-m", "--match"),
     dict(dest="MATCH", action="append", metavar="PATT", default=[],
          help="only write out histograms whose path matches this regex")),
    (("-M", "--unmatch"),
     dict(dest="UNMATCH", action="append", metavar="PATT", default=[],
          help="exclude histograms whose path matches this regex (e.g. -M RAW)")),
    (("--ignore-missing",),
     dict(action="store_true", default=False, dest="IGNORE_MISSING",
          help="don't complain if an object in file #1 is not found in file #2")),
    (("--ignore-new",),
     dict(action="store_true", default=False, dest="IGNORE_NEW",
          help="don't complain if an object in file #2 is not found in file #1")),
    (("--ignore-types",),
     dict(action="store_true", default=False, dest="IGNORE_TYPES",
          help="don't complain if an object in file #2 is not of the same type as in file #1")),
    (("-q", "--quiet"),
     dict(action="store_true", default=False, dest="QUIET",
          help="print nothing, only express comparison result via return code (default: %(default)s)")),
]

parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument("ARGS", nargs=2, help="<datafile1> <datafile2>")
for _optnames, _optkwargs in _OPTION_SPECS:
    parser.add_argument(*_optnames, **_optkwargs)
args = parser.parse_args()

## The two positional arguments are the files to compare. The parser is
## declared with nargs=2, so this length check is a belt-and-braces guard.
filenames = args.ARGS
if len(filenames) != 2:
    print("ERROR! Please supply *two* YODA files for comparison")
    sys.exit(7)


## Get data objects: both files are read with the same -m/-M path patterns
_read_opts = dict(patterns=args.MATCH, unpatterns=args.UNMATCH)
aodict1 = yoda.read(filenames[0], **_read_opts)
aodict2 = yoda.read(filenames[1], **_read_opts)

## Global verdict: flips to False as soon as any mismatch is recorded
CLEAN = True
def log(msg):
    """Print a diagnostic message unless quiet (-q) or list (-l) mode is active."""
    if args.QUIET or args.LIST:
        return
    print(msg)

## Check number of data objects in each file
## This coarse check only flags a failure when neither --ignore-missing nor
## --ignore-new was requested. Absent or extra paths are examined again, one by
## one, in the per-path comparison loop, which applies each ignore flag
## individually. Marking the run unclean here regardless of those flags would
## make the exit status non-zero for differences the user asked to be ignored.
if not (args.IGNORE_MISSING or args.IGNORE_NEW):
    if len(aodict1) != len(aodict2):
        CLEAN = False
        log("Different numbers of data objects in %s and %s" % tuple(filenames[:2]))
    elif sorted(aodict1.keys()) != sorted(aodict2.keys()):
        CLEAN = False
        log("Different data object paths in %s and %s" % tuple(filenames[:2]))

## A slightly tolerant numerical comparison function
def eq(a, b, tol=None):
    """Return True if a and b are equal within a relative tolerance.

    Parameters:
      a, b -- numbers, or iterables of (nested) numbers, to compare.
      tol  -- relative tolerance; when None (the default) the command-line
              value args.TOL is used.

    Rules, in order:
      * exactly equal values are equal;
      * two float NaNs are treated as equal;
      * values of different types are unequal, except int vs. float;
      * unequal strings are unequal (no character-wise fuzzy matching);
      * iterables are equal if they have the same length and all their
        element pairs are equal under these same rules;
      * values of equal magnitude and opposite sign are unequal;
      * otherwise |a - b| / (|a| + |b|) must be below the tolerance.
    """
    if a == b:
        return True
    from math import isnan
    if type(a) is type(b) is float and isnan(a) and isnan(b):
        return True
    ## Type-check: be a bit careful re. int vs. float
    if type(a) is not type(b) and not all(type(x) in (int, float) for x in (a, b)):
        return False
    ## Strings are iterable on Python 3, but iterating a one-character string
    ## yields itself, so recursing into them would never terminate
    if isinstance(a, (str, bytes)):
        return False
    ## Recursively call on pairs of components if a and b are iterables
    if hasattr(a, "__iter__"):
        items_a, items_b = list(a), list(b)
        ## zip() would silently drop the surplus elements of the longer one
        if len(items_a) != len(items_b):
            return False
        return all(eq(x, y, tol) for x, y in zip(items_a, items_b))
    ## Check if a and b have equal magnitude but opposite sign
    if a == -b:
        return False
    ## Finally apply a tolerant numerical comparison on numeric types
    if tol is None:
        tol = args.TOL
    fa, fb = float(a), float(b)
    ## Normalise by the sum of magnitudes: a signed sum is negative for two
    ## negative inputs, which would make the ratio pass any positive tolerance.
    ## The denominator cannot be zero here, since that needs a == b == 0.
    return abs(fa - fb) / (abs(fa) + abs(fb)) < tol

def ptstr(pt):
    """Render a scatter point as a human-readable string.

    Each of the first pt.dim() axes (at most x, y, z) is shown either as
    "value +- err", when eq() judges the two entries of its error pair equal,
    or otherwise as "value + errs[1] - errs[0]". The per-axis strings are
    comma-joined and wrapped in parentheses.
    """
    asym_fmt = "{x:.5g} + {ex[1]:.5g} - {ex[0]:.5g}"
    sym_fmt = "{x:.5g} +- {ex[0]:.5g}"
    ndim = pt.dim()
    parts = []
    for rank, axis in enumerate("xyz", 1):
        ## Only touch the accessors of axes this point actually has
        if ndim < rank:
            break
        errs = getattr(pt, axis + "Errs")()
        fmt = sym_fmt if eq(*errs) else asym_fmt
        parts.append(fmt.format(x=getattr(pt, axis)(), ex=errs))
    return "(" + ", ".join(parts) + ")"

## Book the histogram of per-point deviations (only when -d/--dev-histo is given)
if args.DEV_HISTO_FILE:
    devhist = yoda.Histo1D(40,-10.0,10.0,"/ANALYSIS/deviations")

## Compare each object pair, walking the union of the paths found in both files
for path in sorted(set(aodict1.keys()) | set(aodict2.keys())):

    THISCLEAN = True
    while True: #< hack to allow early exits while still reporting path-specific failures

        ## Get the object in file #1
        ao1 = aodict1.get(path, None)
        if ao1 is None:
            if not args.IGNORE_NEW:
                THISCLEAN = False
                log("Data object '%s' not found in %s" % (path, filenames[0]))
            break
        ## Get the object in file #2
        ao2 = aodict2.get(path, None)
        if ao2 is None:
            if not args.IGNORE_MISSING:
                THISCLEAN = False
                log("Data object '%s' not found in %s" % (path, filenames[1]))
            break

        ## Compare the file #1 vs. #2 object types
        if type(ao1) is not type(ao2):
            ## With --ignore-types the comparison carries on with mismatching types
            if not args.IGNORE_TYPES:
                THISCLEAN = False
                log("Data objects with path '%s' have different types (%s and %s) in %s and %s" % \
                    (path, str(type(ao1)), str(type(ao2)), filenames[0], filenames[1]))
                break
        elif 'Estimate' in str(type(ao1)) and args.BREAKDOWN:
            ## Compare the error breakdown (not done by default)
            errs1 = ao1.sources()
            errs2 = ao2.sources()
            if len(errs1) != len(errs2):
                THISCLEAN = False
                log("Data objects with path '%s' have different uncertainty sources (%d and %d) in %s and %s" % \
                    (path, len(errs1), len(errs2), filenames[0], filenames[1]))
                break
            ## NOTE: the breaks below only leave this source loop, so just the
            ## first missing source is reported and the point comparison
            ## further down still runs for this path
            for src in set(errs1 + errs2):
                if not ao1.hasSource(src):
                    THISCLEAN = False
                    log("Data object with path '%s' is missing uncertainty source '%s' in %s" % \
                        (path, src, filenames[0]))
                    break
                if not ao2.hasSource(src):
                    THISCLEAN = False
                    log("Data object with path '%s' is missing uncertainty source '%s' in %s" % \
                        (path, src, filenames[1]))
                    break

        ## Normalize only if *both* objects support it: with --ignore-types the
        ## two objects can be of different types at this point
        if args.NORMALIZE and hasattr(ao1, "normalize") and hasattr(ao2, "normalize"):
            ao1.normalize()
            ao2.normalize()
        ## Convert to scatter representations
        try:
            s1 = ao1.mkScatter()
            s2 = ao2.mkScatter()
        except Exception as e:
            ## Objects without a scatter form are skipped and not counted as a mismatch
            if not args.QUIET:
                print("WARNING! Could not create a '%s' scatter for comparison (%s)" % (path, type(e).__name__))
            break

        ## Check for compatible dimensionalities (should already be ok, but just making sure)
        if s1.dim() != s2.dim():
            THISCLEAN = False
            log("Data objects with path '%s' have different scatter dimensions (%d and %d) in %s and %s" % \
                (path, s1.dim(), s2.dim(), filenames[0], filenames[1]))
            break

        ## Compare the numbers of points/bins
        if s1.numPoints() != s2.numPoints():
            THISCLEAN = False
            log("Data objects with path '%s' have different numbers of points (%d and %d) in %s and %s" % \
                (path, s1.numPoints(), s2.numPoints(), filenames[0], filenames[1]))
            break

        ## Compare the numeric values of each point
        premsg = "Data points differ for data objects with path '%s' in %s and %s:\n" % (path, filenames[0], filenames[1])
        msgs = []
        for i, (p1, p2) in enumerate(zip(s1.points(), s2.points())):
            # TODO: do this more nicely when point.val(int) and point.err(int) are mapped into Python
            ## Collect the axes on which either the value or the error pair differs
            diffaxis = []
            if p1.dim() >= 1 and not (eq(p1.x(), p2.x()) and eq(p1.xErrs(), p2.xErrs())):
                diffaxis.append('x')
            if p1.dim() >= 2 and not (eq(p1.y(), p2.y()) and eq(p1.yErrs(), p2.yErrs())):
                diffaxis.append('y')
            if p1.dim() >= 3 and not (eq(p1.z(), p2.z()) and eq(p1.zErrs(), p2.zErrs())):
                diffaxis.append('z')
            if diffaxis:
                msgs.append("  Point #%d (different %s): %s vs. %s" % (i, ", ".join(diffaxis), ptstr(p1), ptstr(p2)))

            if args.DEV_HISTO_FILE:
                ## Fill deviation histogram with (v1 - v2)/sqrt(e1^2 + e2^2), taken
                ## along the highest axis of the point, if both errors are positive
                ## TODO: treat asymmetric uncs correctly
                try:
                    if p1.dim() == 1 and p1.xErrAvg() > 0 and p2.xErrAvg() > 0:
                        devhist.fill((p1.x()-p2.x())/sqrt(p1.xErrAvg()**2 + p2.xErrAvg()**2))
                    if p1.dim() == 2 and p1.yErrAvg() > 0 and p2.yErrAvg() > 0:
                        devhist.fill((p1.y()-p2.y())/sqrt(p1.yErrAvg()**2 + p2.yErrAvg()**2))
                    if p1.dim() == 3 and p1.zErrAvg() > 0 and p2.zErrAvg() > 0:
                        devhist.fill((p1.z()-p2.z())/sqrt(p1.zErrAvg()**2 + p2.zErrAvg()**2))
                except Exception:
                    ## Best effort only; "except Exception" lets Ctrl-C and
                    ## sys.exit() propagate, unlike a bare "except:"
                    if not args.QUIET:
                        print("deviation calculation failed in ", path, p1.x())

        if msgs:
            THISCLEAN = False
            log(premsg + "\n".join(msgs))
            break

        ## Compare the annotations (not done by default)
        if args.ANNOTATIONS:
            for annotation in set(ao1.annotations() + ao2.annotations()):
                if ao1.annotation(annotation) != ao2.annotation(annotation):
                    THISCLEAN = False
                    log("Data objects with path '%s' have different '%s' annotations ('%s' and '%s') in %s and %s" %
                        (path, annotation, ao1.annotation(annotation), ao2.annotation(annotation), filenames[0], filenames[1]))

        ## Make sure to exit the hack while-loop if clean
        break

    ## List failing path in list mode
    if not THISCLEAN and args.LIST and not args.QUIET:
        print(path)

    ## Aggregate total cleanliness flag
    CLEAN &= THISCLEAN

## Write out the deviation histogram if one was requested
if args.DEV_HISTO_FILE:
    yoda.write([devhist], args.DEV_HISTO_FILE)

## A non-zero exit status signals that at least one mismatch was recorded
if not CLEAN:
    sys.exit(1)

# sys.exit(0)
