#!/usr/bin/python3
##
# CU - C unit testing framework
# ---------------------------------
# Copyright (c)2007,2008,2012 Daniel Fiser <danfis@danfis.cz>
#
#  This file is part of CU.
#
#  Distributed under the OSI-approved BSD License (the "License");
#  see accompanying file BDS-LICENSE for details or see
#  <http://www.opensource.org/licenses/bsd-license.php>.
#
#  This software is distributed WITHOUT ANY WARRANTY; without even the
#  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#  See the License for more information.
#

from subprocess import Popen, PIPE
import os
import re
import sys
import math
from getopt import gnu_getopt, GetoptError

# Absolute tolerance used when two floating point numbers found in diff
# lines are compared (can be changed with --eps)
EPS = 6E-4
# Directory the script was started from; overwritten with os.getcwd()
# during initialization and used to return there before each directory
# given on the command line is processed
BASE_DIR = "."
# Diffs having more lines than this are printed only partially
# (can be changed with --max-diff-lines)
MAX_DIFF_LINES = 20
# If False, hunks which differ only in precision of floating point numbers
# are removed from the diff before it is reported (--exact / --not-exact)
EXACT = False

# If True, a running counter of read lines is printed while diff output is
# being parsed (--progress / --no-progress)
PROGRESS_ON = True
# Text of the status line of the currently compared pair of files; set by
# main() and read by Parser.parse() when the progress counter is printed
MSG_BASE = ""

def pr(s, flush = False):
    """ Writes string s to standard output without appending a newline.

        If flush is true, standard output is flushed right after the
        write.
    """
    out = sys.stdout
    out.write(s)
    if not flush:
        return
    out.flush()

def prnl(s, flush = False):
    """ Writes string s followed by a newline to standard output.

        The flush argument is passed on to pr().
    """
    line = s + "\n"
    pr(line, flush)

class Hunk:
    """ This class represents one hunk from diff.

        A hunk keeps three lists: all lines that were stored using
        addLine(), lines stored using addLineAdded() and lines stored using
        addLineDeleted().
    """

    # to identify lines with floating point numbers
    re_is_num = re.compile("^.*[0-9].*$")

    # pattern to match floating point number; the exponent may carry an
    # explicit sign - '+' as well as '-' (C's printf("%e") always prints
    # one, e.g. 1.000000e+05)
    num_pattern = r"-?(?:(?:[0-9]+(?:\.[0-9]*)?)|(?:\.[0-9]+))(?:[eE][-+]?[0-9]+)?"
    re_num = re.compile(num_pattern)

    def __init__(self):
        self.added = []
        self.deleted = []
        self.lines = []

    def numLines(self):
        """ Returns number of all lines stored in this hunk. """
        return len(self.lines)
    def numLinesAdded(self):
        """ Returns number of added lines. """
        return len(self.added)
    def numLinesDeleted(self):
        """ Returns number of deleted lines. """
        return len(self.deleted)

    def addLineAdded(self, line):
        """ Stores text of an added line. """
        self.added.append(line)
    def addLineDeleted(self, line):
        """ Stores text of a deleted line. """
        self.deleted.append(line)
    def addLine(self, line):
        """ Stores a line belonging to this hunk. """
        self.lines.append(line)

    def getLines(self):
        """ Returns list of all lines stored in this hunk. """
        return self.lines
    def getLinesAdded(self):
        """ Returns list of added lines. """
        return self.added
    def getLinesDeleted(self):
        """ Returns list of deleted lines. """
        return self.deleted

    def __eq(self, num1, num2):
        """ Returns True if num1 equals to num2 with respect to EPS
            (defined above) """
        return math.fabs(num1 - num2) < EPS

    def checkFloats(self):
        """ This method tries to check if the only difference between added
            and deleted lines of this hunk is different precision of
            floating point numbers.

            Returns True if i-th added and i-th deleted line contain the
            same count of numbers, corresponding numbers differ by less
            than EPS and the lines are identical once the numbers are
            removed.  Returns False otherwise.
        """

        # If number of added and deleted lines differs, then there are more
        # differences than precision of floating point numbers
        if self.numLinesAdded() != self.numLinesDeleted():
            return False

        for line1, line2 in zip(self.added, self.deleted):
            # if any line does not contain number - return False because
            # there must be more differences than in numbers
            if not self.re_is_num.match(line1) \
               or not self.re_is_num.match(line2):
                return False

            # Extract all floating point numbers from each line (the
            # pattern requires at least one digit, so no match is empty)
            nums1 = self.re_num.findall(line1)
            nums2 = self.re_num.findall(line2)

            # different count of numbers means different lines
            if len(nums1) != len(nums2):
                return False

            # if any pair of numbers does not equal return False
            for num1, num2 in zip(nums1, nums2):
                if not self.__eq(float(num1), float(num2)):
                    return False

            # compare the rest of lines
            if self.re_num.sub("", line1) != self.re_num.sub("", line2):
                return False

        # If it does not fail anywhere, added and deleted lines must be
        # same
        return True


class Diff:
    """ Holds all hunks of one diff together with the count of lines they
        contain and the count of lines that were omitted by checkFloats().
    """

    def __init__(self):
        self.hunks = []
        self.lines = 0
        self.omitted_lines = 0

    def addHunk(self, hunk):
        """ Appends hunk and adds its lines to the total. """
        self.hunks.append(hunk)
        self.lines = self.lines + hunk.numLines()

    def numLines(self):
        """ Returns number of lines of all currently held hunks. """
        return self.lines

    def numOmittedLines(self):
        """ Returns number of lines dropped by checkFloats() so far. """
        return self.omitted_lines

    def getHunks(self):
        """ Returns list of currently held hunks. """
        return self.hunks

    def numHunks(self):
        """ Returns number of currently held hunks. """
        return len(self.hunks)

    def checkFloats(self):
        """ Calls checkFloats() on each hunk.  Hunks for which it returns
            true are dropped and their lines are counted as omitted, the
            others are kept. """
        candidates, self.hunks, self.lines = self.hunks, [], 0
        for hunk in candidates:
            if hunk.checkFloats():
                self.omitted_lines += hunk.numLines()
                continue
            self.hunks.append(hunk)
            self.lines += hunk.numLines()
            


class Parser:
    """ Reads output of diff from a binary stream and splits it into
        hunks which are collected in a Diff object.

        Only hunk headers and lines starting with "> " (added) or "< "
        (deleted) are recognized; all other lines are skipped.
    """

    def __init__(self, fin):
        """ fin is an object whose readline() returns bytes (empty bytes
            at end of input), e.g. stdout pipe of a subprocess. """
        self.fin = fin
        self.line = ""
        self.diff = Diff()
        self.cur_hunk = None

        # to recognize beginning of hunk:
        self.re_hunk = re.compile(r"^[0-9]*(,[0-9]*){0,1}[a-zA-Z]?[0-9]*(,[0-9]*){0,1}$")

        self.re_added = re.compile(r"^> (.*)$")
        self.re_deleted = re.compile(r"^< (.*)$")

    def __readNextLine(self):
        """ Reads next line into self.line.  Returns False at end of
            input, True otherwise. """
        raw = self.fin.readline()
        # Compared files are not guaranteed to be valid UTF-8.  Invalid
        # bytes are turned into \xNN escapes instead of raising
        # UnicodeDecodeError, so different bytes still yield different
        # text.
        self.line = raw.decode('utf-8', errors = 'backslashreplace')
        return len(self.line) != 0

    def parse(self):
        """ Reads the whole input and fills the diff returned by
            getDiff().  If PROGRESS_ON is set, a counter of read lines is
            printed after every 50 lines. """
        num_lines = 0
        while self.__readNextLine():
            if self.re_hunk.match(self.line):
                # beginning of hunk - store the previous one first
                if self.cur_hunk is not None:
                    self.diff.addHunk(self.cur_hunk)
                self.cur_hunk = Hunk()
                self.cur_hunk.addLine(self.line)

            elif self.cur_hunk is not None:
                # added/deleted lines make sense only inside of a hunk;
                # such a line seen before any hunk header is skipped
                # instead of failing on the missing hunk
                added = self.re_added.match(self.line)
                deleted = self.re_deleted.match(self.line)
                if added is not None:
                    self.cur_hunk.addLine(self.line)
                    self.cur_hunk.addLineAdded(added.group(1))
                elif deleted is not None:
                    self.cur_hunk.addLine(self.line)
                    self.cur_hunk.addLineDeleted(deleted.group(1))

            num_lines += 1

            if PROGRESS_ON and num_lines % 50 == 0:
                pr(MSG_BASE + " [ %08d ]" % num_lines + "\r", flush = True)

        # last push to list of hunks
        if self.cur_hunk is not None:
            self.diff.addHunk(self.cur_hunk)

        if PROGRESS_ON:
            # overwrite the progress counter with spaces
            pr(MSG_BASE + "            \r", flush = True)

    def getDiff(self):
        """ Returns the Diff object filled by parse(). """
        return self.diff


def regressionFilesInDir():
    """ Returns sorted list of pairs of filenames where first name in pair
        is tmp. file and second corresponding file with saved regressions.

        Files of the current working directory named tmp.<name>.out or
        tmp.<name>.err are taken as tmp. files; the corresponding file is
        <name>.out or <name>.err respectively.  If the corresponding file
        does not exist, the second name in the pair is an empty string.
        Each pair is a two-item list.
    """

    # The whole filename must match, so that e.g. tmp.foo.out.bak or
    # tmp.foo.out~ is not taken for a tmp. file of foo.out and no file is
    # listed twice.
    re_tmp_file = re.compile(r"tmp\.(.*\.(?:out|err))")

    all_files = sorted(os.listdir("."))
    existing = set(all_files)

    files = []
    for fname in all_files:
        res = re_tmp_file.fullmatch(fname)
        if res is None:
            continue
        saved = res.group(1)
        files.append([fname, saved if saved in existing else ""])

    return files


def MSG(s = "", wait = False):
    """ Prints message s.  If wait is true no newline is appended,
        otherwise the message is terminated by a newline. """
    emit = pr if wait else prnl
    emit(s)
def MSGOK(prestr = "", str = "", poststr = ""):
    """ Prints one line made of prestr, str and poststr separated by
        spaces, where str is wrapped in the escape sequences 0;32 (green)
        and 0;0 (reset). """
    parts = (prestr, " \033[0;32m", str, "\033[0;0m ", poststr)
    prnl("".join(parts))
def MSGFAIL(prestr = "", str = "", poststr = ""):
    """ Prints one line made of prestr, str and poststr separated by
        spaces, where str is wrapped in the escape sequences 0;31 (red)
        and 0;0 (reset). """
    parts = (prestr, " \033[0;31m", str, "\033[0;0m ", poststr)
    prnl("".join(parts))
def MSGINFO(prestr = "", str = "", poststr = ""):
    """ Prints one line made of prestr, str and poststr separated by
        spaces, where str is wrapped in the escape sequences 0;33 (yellow)
        and 0;0 (reset). """
    parts = (prestr, " \033[0;33m", str, "\033[0;0m ", poststr)
    prnl("".join(parts))
def dumpLines(lines, prefix = "", wait = False, max_lines = -1):
    """ Prints lines, each one prefixed with prefix and one space.

        lines     -- iterable of strings; they are printed as they are, no
                     newline is appended
        prefix    -- string printed at the beginning of each line
        wait      -- has no effect; kept so that existing callers keep
                     working
        max_lines -- at most this many lines are printed; negative value
                     (default) means no limit
    """
    for line_num, line in enumerate(lines):
        # line_num lines were already printed at this point, so stop as
        # soon as the limit is reached (and not one line later)
        if 0 <= max_lines <= line_num:
            break
        pr(prefix + " " + line)

def main(files):
    """ Compares each pair of files using diff and prints the result.

        files -- list of pairs [tmp_file, regression_file] as returned by
                 regressionFilesInDir(); an empty regression_file means
                 that the file does not exist and the pair is only
                 reported

        A pair is reported as OK if no hunks are left in the diff (after
        hunks differing only in precision of floating point numbers were
        dropped, unless EXACT is set).  Otherwise it is reported as FAILED
        and at most MAX_DIFF_LINES lines of the diff are printed.
    """
    global MSG_BASE

    # As first compute length of columns
    len1 = max((len(filenames[0]) for filenames in files), default = 0)
    len2 = max((len(filenames[1]) for filenames in files), default = 0)

    for filenames in files:
        tmp_name, saved_name = filenames[0], filenames[1]

        if len(saved_name) == 0:
            MSGFAIL("", "===", "Can't compare %s %s, bacause %s does not exist!" % \
                    (tmp_name, tmp_name[4:], tmp_name[4:]))
            continue

        cmd = ["diff", tmp_name, saved_name]
        MSG_BASE = "Comparing %s and %s" % \
                (tmp_name.ljust(len1), saved_name.ljust(len2))
        if not PROGRESS_ON:
            pr(MSG_BASE, flush = True)

        # Leaving the with block closes the pipe and waits for diff to
        # finish, so no process or file descriptor is left behind and the
        # exit status is available.
        with Popen(cmd, stdout=PIPE) as pipe:
            parser = Parser(pipe.stdout)
            parser.parse()
        diff = parser.getDiff()

        # diff exits with 0 if the files are the same, with 1 if they
        # differ and with a higher value on trouble.  A non-zero status
        # without any hunk (e.g. differing binary files, unreadable file)
        # must not be reported as success.
        unparsed_failure = pipe.returncode != 0 and diff.numHunks() == 0

        if not EXACT:
            diff.checkFloats()

        if PROGRESS_ON:
            pr(MSG_BASE)

        if unparsed_failure:
            MSGFAIL(" [", "FAILED", "]")
            MSGINFO(" -->", "diff exited with status " + str(pipe.returncode)
                            + " without any parsable output")
            continue

        failed = diff.numHunks() != 0
        if failed:
            MSGFAIL(" [", "FAILED", "]")
        else:
            MSGOK(" [", "OK", "]")

        if diff.numOmittedLines() > 0:
            MSGINFO(" -->", str(diff.numOmittedLines()) + " lines from diff omitted")

        if not failed:
            continue

        MSGINFO(" -->", "Diff has " + str(diff.numLines()) + " lines")

        if diff.numLines() <= MAX_DIFF_LINES:
            MSGINFO(" -->", "Diff:")
            for h in diff.getHunks():
                dumpLines(h.getLines(), "   |", True)
        else:
            MSGINFO(" -->", "Printing only first " + str(MAX_DIFF_LINES) + " lines:")
            # collect only as many hunks as are needed to reach the limit
            lines = []
            for h in diff.getHunks():
                lines += h.getLines()
                if len(lines) > MAX_DIFF_LINES:
                    break
            dumpLines(lines, "   |", True, MAX_DIFF_LINES)
                    
def usage():
    """ Prints help (including the current values of MAX_DIFF_LINES and
        EPS as defaults) and terminates the program with exit status -1.
    """
    help_lines = (
        "Usage: " + sys.argv[0] + " [ OPTIONS ] [ directory, [ directory, [ ... ] ] ]",
        "",
        " OPTIONS:",
        "     --help / -h       none   Print this help",
        "     --exact / -e      none   Switch do exact comparasion of files",
        "     --not-exact / -n  none   Switch do non exact comparasion of files (default behaviour)",
        "     --max-diff-lines  int    Maximum of lines of diff which can be printed (default " + str(MAX_DIFF_LINES) + ")",
        "     --eps             float  Precision of floating point numbers (epsilon) (default " + str(EPS) + ")",
        "     --no-progress     none   Turn off progress bar",
        "     --progress        none   Turn on progress bar (default)",
        "",
        " This program is able to compare files with regressions generated by CU testsuites.",
        " You can specify directories which are to be searched for regression files.",
        " In non exact copmarasion mode (which is default), this program tries to compare",
        " floating point numbers in files with respect to specified precision (see --eps) and",
        " those lines which differ only in precission of floating point numbers are omitted.",
        "",
    )
    for text in help_lines:
        prnl(text)
    sys.exit(-1)



# Init:

# Set up base dir
BASE_DIR = os.getcwd()

# Parse command line options:
try:
    optlist, args = gnu_getopt(sys.argv[1:],
                               "hen",
                               ["help", "max-diff-lines=", "eps=",
                                "exact", "not-exact",
                                "no-progress", "progress"])
except GetoptError as err:
    # unknown option or missing argument - print what is wrong and the
    # help instead of a traceback (usage() terminates the program)
    sys.stderr.write(str(err) + "\n")
    usage()

for name, value in optlist:
    if name in ("--help", "-h"):
        usage()
    elif name in ("--exact", "-e"):
        EXACT = True
    elif name in ("--not-exact", "-n"):
        EXACT = False
    elif name in ("--max-diff-lines", "--eps"):
        try:
            if name == "--eps":
                EPS = float(value)
            else:
                MAX_DIFF_LINES = int(value)
        except ValueError:
            # value is not a number - report it instead of a traceback
            sys.stderr.write("Invalid value '" + value + "' of option " + name + "\n")
            usage()
    elif name == "--no-progress":
        PROGRESS_ON = False
    elif name == "--progress":
        PROGRESS_ON = True

if len(args) == 0:
    # no directory given - process the current one
    files = regressionFilesInDir()
    main(files)
else:
    for dirname in args:
        # directories may be given relative to the starting directory
        os.chdir(BASE_DIR)

        MSGINFO()
        MSGINFO("", "Processing directory '" + dirname + "':")
        MSGINFO()
        try:
            os.chdir(dirname)
        except OSError:
            MSGFAIL("  -->", "Directory '" + dirname + "' does not exist.")
            # skip it, otherwise files of the starting directory would be
            # compared under the name of this directory
            continue
        files = regressionFilesInDir()
        main(files)

sys.exit(0)
