# Review notes for this patch (copycat/statistics.py and cross_compare.py).
#
# NOTE(review): the line breaks of this diff appear to have been collapsed;
# the summary below is read from the @@ headers and the +/- markers.
#
# copycat/statistics.py
#   + chi_squared_diff(actual, expected)
#       Calls chi_squared(actual, expected) to get (df, chiSquared) and returns
#       the boolean `chiSquared < _chiSquared_table[df]`. True therefore means
#       the statistic is BELOW the table entry, despite the "_diff" name.
#       NOTE(review): _chiSquared_table is not defined in this diff --
#       presumably critical values keyed by degrees of freedom; confirm that
#       every df that can occur is a key, otherwise the lookup raises.
#   + cross_formula_chi_squared_table(actualDict, expectedDict)
#       Returns a dict mapping every (ka, ke) pair -- the full cross product of
#       the keys of actualDict and expectedDict -- to
#       chi_squared_diff(actual, expected). Two commented-out lines keep an
#       alternative that stored (df, chiSquared) instead of the boolean.
#   ~ cross_chi_squared
#       One `if (problemA.initial == ...` line is removed and re-added with the
#       same visible text (looks like a whitespace-only change -- TODO confirm).
#   + cross_chi_squared_table(problemSets)
#       Compares each (a, problemSetA) entry with every later (b, problemSetB)
#       entry of problemSets. For each problemA/problemB whose initial,
#       modified and target attributes are all equal, stores
#       cross_formula_chi_squared_table(answersA, answersB) under
#       table[(initial, modified, target)][(a, b)], where answersA/answersB are
#       the problems' `distributions`. Returns the defaultdict(dict).
#       A later match with the same problem key and the same (a, b) overwrites
#       the earlier entry.
#
# cross_compare.py
#   * The import line additionally imports cross_chi_squared_table.
#   * main(): the pprint(cross_chi_squared(problemSets)) call is replaced (and
#     left behind as a comment) by building
#     crossTable = cross_chi_squared_table(problemSets) and writing it to
#     'output/cross_compare.csv':
#       - header row: 'problems x variants' followed by one
#         '{} x {} for {} x {}' label per (subkey, subsubkey) pair of the FIRST
#         key-sorted table entry;
#       - one row per problem, labelled '{}:{}::{}:_', followed by str(result)
#         for each (subkey, subsubkey) pair in key-sorted order.
#   NOTE(review): column labels come only from the first table entry; a row
#     whose sub-dicts hold a different set of keys will not line up with the
#     header, and nothing here detects that.
#   NOTE(review): rows are assembled with ','.join(cells); a label containing a
#     comma would shift the columns. The stdlib csv module would quote it.
#   NOTE(review): the emptiness check is an `assert`, which is skipped under
#     `python -O`; tableItems[0] would then raise IndexError instead.
#   NOTE(review): no visible line creates the 'output/' directory before
#     open(..., 'w') -- confirm it is guaranteed to exist.
#   NOTE(review): headKey and the header loop's `result` are never used; the
#     two `#pprint(...)` lines are commented-out code.
#
diff --git a/copycat/statistics.py b/copycat/statistics.py index 948eb4d..5a8c660 100644 --- a/copycat/statistics.py +++ b/copycat/statistics.py @@ -33,6 +33,10 @@ def chi_squared(actual, expected): chiSquared += (O - E) ** 2 / E return degreesFreedom, chiSquared +def chi_squared_diff(actual, expected): + df, chiSquared = chi_squared(actual, expected) + return (chiSquared < _chiSquared_table[df]) + def chi_squared_test(actual, expected, show=True): df, chiSquared = chi_squared(actual, expected) @@ -55,13 +59,22 @@ def cross_formula_chi_squared(actualDict, expectedDict): print(' Succeeded.') return failures +def cross_formula_chi_squared_table(actualDict, expectedDict): + data = dict() + for ka, actual in actualDict.items(): + for ke, expected in expectedDict.items(): + #df, chiSquared = chi_squared(actual, expected) + data[(ka, ke)] = chi_squared_diff(actual, expected) + #data[(ka, ke)] = (df, chiSquared) + return data + def cross_chi_squared(problemSets): failures = defaultdict(list) for i, (a, problemSetA) in enumerate(problemSets): for b, problemSetB in problemSets[i + 1:]: for problemA in problemSetA: for problemB in problemSetB: - if (problemA.initial == problemB.initial and + if (problemA.initial == problemB.initial and problemA.modified == problemB.modified and problemA.target == problemB.target): answersA = problemA.distributions @@ -80,6 +93,20 @@ def cross_chi_squared(problemSets): print('\n') return failures +def cross_chi_squared_table(problemSets): + table = defaultdict(dict) + for i, (a, problemSetA) in enumerate(problemSets): + for b, problemSetB in problemSets[i + 1:]: + for problemA in problemSetA: + for problemB in problemSetB: + if (problemA.initial == problemB.initial and + problemA.modified == problemB.modified and + problemA.target == problemB.target): + answersA = problemA.distributions + answersB = problemB.distributions + table[(problemA.initial, problemA.modified, problemA.target)][(a, b)] = cross_formula_chi_squared_table(answersA, 
answersB) + return table + def iso_chi_squared(actualDict, expectedDict): for key in expectedDict.keys(): assert key in actualDict, 'The key {} was not tested'.format(key) diff --git a/cross_compare.py b/cross_compare.py index c35e3cd..de720d7 100755 --- a/cross_compare.py +++ b/cross_compare.py @@ -5,7 +5,7 @@ import pickle from pprint import pprint from copycat import Problem -from copycat.statistics import cross_chi_squared +from copycat.statistics import cross_chi_squared, cross_chi_squared_table def compare_sets(): pass @@ -18,7 +18,30 @@ def main(args): pSet = pickle.load(infile) branchProblemSets[filename] = pSet problemSets.append((filename, pSet)) - pprint(cross_chi_squared(problemSets)) + #pprint(problemSets) + #pprint(cross_chi_squared(problemSets)) + crossTable = cross_chi_squared_table(problemSets) + key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0]) + + tableItems = key_sorted_items(crossTable) + assert len(tableItems) > 0, 'Empty table' + + with open('output/cross_compare.csv', 'w') as outfile: + headKey, headSubDict = tableItems[0] + cells = ['problems x variants'] + for subkey, subsubdict in key_sorted_items(headSubDict): + for subsubkey, result in key_sorted_items(subsubdict): + cells.append('{} x {} for {} x {}'.format(*subkey, *subsubkey)) + outfile.write(','.join(cells) + '\n') + + for key, subdict in tableItems: + cells = [] + problem = '{}:{}::{}:_'.format(*key) + cells.append(problem) + for subkey, subsubdict in key_sorted_items(subdict): + for subsubkey, result in key_sorted_items(subsubdict): + cells.append(str(result)) + outfile.write(','.join(cells) + '\n') return 0 if __name__ == '__main__':