diff --git a/copycat/statistics.py b/copycat/statistics.py index c085943..948eb4d 100644 --- a/copycat/statistics.py +++ b/copycat/statistics.py @@ -1,3 +1,5 @@ +from collections import defaultdict +from pprint import pprint # CHI2 values for n degrees freedom _chiSquared_table = { 1:3.841, @@ -26,7 +28,7 @@ def chi_squared(actual, expected): E = get_count(k, expected) O = get_count(k, actual) if E == 0: - print('Warning! Expected 0 counts of {}, but got {}'.format(k, O)) + print(' Warning! Expected 0 counts of {}, but got {}'.format(k, O)) else: chiSquared += (O - E) ** 2 / E return degreesFreedom, chiSquared @@ -42,15 +44,19 @@ def chi_squared_test(actual, expected, show=True): return True def cross_formula_chi_squared(actualDict, expectedDict): + failures = [] for ka, actual in actualDict.items(): for ke, expected in expectedDict.items(): - print('Comparing {} with {}: '.format(ka, ke), end='') + print(' Comparing {} with {}: '.format(ka, ke), end='') if not chi_squared_test(actual, expected, show=False): - print('Failed.') + failures.append('{}:{}'.format(ka, ke)) + print(' Failed.') else: - print('Succeeded.') + print(' Succeeded.') + return failures def cross_chi_squared(problemSets): + failures = defaultdict(list) for i, (a, problemSetA) in enumerate(problemSets): for b, problemSetB in problemSets[i + 1:]: for problemA in problemSetA: @@ -63,11 +69,16 @@ def cross_chi_squared(problemSets): print('-' * 80) print('\n') print('{} x {}'.format(a, b)) - print('Problem: {}:{}::{}:_'.format(problemA.initial, + problemString = '{} x {} for: {}:{}::{}:_\n'.format(a, + b, + problemA.initial, problemA.modified, - problemA.target)) - cross_formula_chi_squared(answersA, answersB) + problemA.target) + failures[problemString].append(cross_formula_chi_squared(answersA, answersB)) + pprint(answersA) + pprint(answersB) print('\n') + return failures def iso_chi_squared(actualDict, expectedDict): for key in expectedDict.keys(): diff --git a/cross_compare.py b/cross_compare.py index 500d2ec..c35e3cd 100755 --- a/cross_compare.py +++ b/cross_compare.py @@ -2,6 +2,8 @@ import sys import pickle +from pprint import pprint + from copycat import Problem from copycat.statistics import cross_chi_squared @@ -16,7 +18,7 @@ def main(args): pSet = pickle.load(infile) branchProblemSets[filename] = pSet problemSets.append((filename, pSet)) - cross_chi_squared(problemSets) + pprint(cross_chi_squared(problemSets)) return 0 if __name__ == '__main__':