From 402e66409a40f49f9adf8f53c57312ca03506628 Mon Sep 17 00:00:00 2001
From: LSaldyt
Date: Sun, 19 Nov 2017 11:25:11 -0700
Subject: [PATCH] Adds cross-chi2 comparison script

---
 copycat/statistics.py | 27 ++++++++++++++++++---------
 cross_compare.py      | 23 +++++++++++++++++++++++
 2 files changed, 41 insertions(+), 9 deletions(-)
 create mode 100755 cross_compare.py

diff --git a/copycat/statistics.py b/copycat/statistics.py
index 4f1ffe3..0f8a74c 100644
--- a/copycat/statistics.py
+++ b/copycat/statistics.py
@@ -29,29 +29,38 @@ def chi_squared(actual, expected):
                 print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
         else:
             chiSquared += (O - E) ** 2 / E
-    return chiSquared
+    return degreesFreedom, chiSquared
+
+def chi_squared_test(actual, expected):
+    df, chiSquared = chi_squared(actual, expected)
+
+    if chiSquared >= _chiSquared_table[df]:
+        print('Significant difference between expected and actual answer distributions: \n' +
+              'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, df))
+        return False
+    return True
 
 def cross_formula_chi_squared(actualDict, expectedDict):
     for ka, actual in actualDict.items():
         for ke, expected in expectedDict.items():
             print('Comparing {} with {}'.format(ka, ke))
-            chiSquared = chi_squared(actual, expected)
-
-            if chiSquared >= _chiSquared_table[degreesFreedom]:
-                print('Significant difference between expected and actual answer distributions: \n' +
-                      'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
+            chi_squared_test(actual, expected)
 
 def cross_chi_squared(problemSets):
     for i, problemSetA in enumerate(problemSets):
         for problemSetB in problemSets[i + 1:]:
             for problemA in problemSetA:
                 for problemB in problemSetB:
-                    answersA = problemA.distributions
-                    answersB = problemB.distributions
-                    cross_formula_chi_squared(answersA, answersB)
+                    if (problemA.initial == problemB.initial and
+                        problemA.modified == problemB.modified and
+                        problemA.target == problemB.target):
+                        answersA = problemA.distributions
+                        answersB = problemB.distributions
+                        cross_formula_chi_squared(answersA, answersB)
 
 def iso_chi_squared(actualDict, expectedDict):
     for key in expectedDict.keys():
         assert key in actualDict, 'The key {} was not tested'.format(key)
         actual = actualDict[key]
         expected = expectedDict[key]
+        chi_squared_test(actual, expected)
diff --git a/cross_compare.py b/cross_compare.py
new file mode 100755
index 0000000..734c86c
--- /dev/null
+++ b/cross_compare.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+import sys
+import pickle
+
+from copycat import Problem
+from copycat.statistics import cross_chi_squared
+
+def compare_sets():
+    pass
+
+def main(args):
+    branchProblemSets = dict()
+    problemSets = []
+    for filename in args:
+        with open(filename, 'rb') as infile:
+            pSet = pickle.load(infile)
+            branchProblemSets[filename] = pSet
+            problemSets.append(pSet)
+    cross_chi_squared(problemSets)
+    return 0
+
+if __name__ == '__main__':
+    sys.exit(main(sys.argv[1:]))