Adds cross-chi2 comparison script

This commit is contained in:
LSaldyt
2017-11-19 11:25:11 -07:00
parent ee20de8297
commit 402e66409a
2 changed files with 41 additions and 9 deletions

View File

@ -29,29 +29,38 @@ def chi_squared(actual, expected):
print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
else:
chiSquared += (O - E) ** 2 / E
return chiSquared
return degreesFreedom, chiSquared
def chi_squared_test(actual, expected):
    """Report whether two answer distributions differ significantly.

    Calls ``chi_squared(actual, expected)`` to obtain the degrees of freedom
    and the chi-squared statistic, then compares the statistic with the
    ``_chiSquared_table`` entry for those degrees of freedom.
    NOTE(review): the table presumably holds critical values at a fixed
    significance level -- confirm against its definition.

    Returns:
        ``False`` (after printing a diagnostic) when the statistic is greater
        than or equal to the table entry, ``True`` otherwise.
    """
    df, chiSquared = chi_squared(actual, expected)
    threshold = _chiSquared_table[df]
    significant = chiSquared >= threshold
    if not significant:
        return True
    report = ('Significant difference between expected and actual answer distributions: \n'
              'Chi2 value: {} with {} degrees of freedom').format(chiSquared, df)
    print(report)
    return False
def cross_formula_chi_squared(actualDict, expectedDict):
    """Run a chi-squared test for every pairing of entries from two dicts.

    Each value of ``actualDict`` is tested against each value of
    ``expectedDict``; the pair of keys being compared is printed first.
    The significance check and its report are delegated entirely to
    ``chi_squared_test`` rather than being repeated inline, because
    ``chi_squared`` returns a ``(degreesFreedom, chiSquared)`` pair and no
    ``degreesFreedom`` name exists in this scope.

    Args:
        actualDict: mapping of key -> distribution, passed as ``actual``.
        expectedDict: mapping of key -> distribution, passed as ``expected``.

    Returns:
        None. Results are reported on stdout only.
    """
    for ka, actual in actualDict.items():
        for ke, expected in expectedDict.items():
            print('Comparing {} with {}'.format(ka, ke))
            chi_squared_test(actual, expected)
def _is_same_problem(problemA, problemB):
    """Return True when both problems share initial, modified and target."""
    return (problemA.initial == problemB.initial and
            problemA.modified == problemB.modified and
            problemA.target == problemB.target)


def cross_chi_squared(problemSets):
    """Compare answer distributions of identical problems across problem sets.

    Every unordered pair of distinct problem sets is visited once (set ``i``
    against each later set). Within a pair, only problems whose ``initial``,
    ``modified`` and ``target`` attributes all match are compared; their
    ``distributions`` are handed to ``cross_formula_chi_squared``. Problems
    that do not match are skipped, and each matching pair is compared exactly
    once.

    Args:
        problemSets: sequence (it is sliced) of iterables of problem objects
            exposing ``initial``, ``modified``, ``target`` and
            ``distributions``.

    Returns:
        None. Results are reported on stdout by the callee.
    """
    for i, problemSetA in enumerate(problemSets):
        for problemSetB in problemSets[i + 1:]:
            for problemA in problemSetA:
                for problemB in problemSetB:
                    if not _is_same_problem(problemA, problemB):
                        continue
                    cross_formula_chi_squared(problemA.distributions,
                                              problemB.distributions)
def iso_chi_squared(actualDict, expectedDict):
    """Run a chi-squared test for each key, pairing entries with the same key.

    For every key of ``expectedDict`` the matching entry of ``actualDict`` is
    tested against it with ``chi_squared_test``. An assertion fails if
    ``actualDict`` lacks one of the expected keys.

    Returns:
        None. Results are reported on stdout by ``chi_squared_test``.
    """
    for key, expected in expectedDict.items():
        assert key in actualDict, 'The key {} was not tested'.format(key)
        chi_squared_test(actualDict[key], expected)

23
cross_compare.py Executable file
View File

@ -0,0 +1,23 @@
#!/usr/bin/env python3
import sys
import pickle
from copycat import Problem
from copycat.statistics import cross_chi_squared
def compare_sets():
    """Placeholder that does nothing and returns None."""
def main(args):
    """Load pickled problem sets and cross-compare them.

    Args:
        args: file names, each containing one pickled problem set.

    Returns:
        0, used by the caller as the process exit status.
    """
    problemSets = []
    for filename in args:
        with open(filename, 'rb') as infile:
            # NOTE(security): unpickling can execute arbitrary code; only
            # pass files that were produced by a trusted run.
            problemSets.append(pickle.load(infile))
    cross_chi_squared(problemSets)
    return 0
# Script entry point: run main() on the command-line arguments (program name
# excluded) and exit with the status it returns.
if __name__ == '__main__':
    raise SystemExit(main(sys.argv[1:]))