Updates chi2 code for final comparison

This commit is contained in:
LSaldyt
2017-12-09 15:15:06 -07:00
parent a8b9675d2f
commit 28e1ddae30
2 changed files with 53 additions and 3 deletions

View File

@ -33,6 +33,10 @@ def chi_squared(actual, expected):
chiSquared += (O - E) ** 2 / E chiSquared += (O - E) ** 2 / E
return degreesFreedom, chiSquared return degreesFreedom, chiSquared
def chi_squared_diff(actual, expected):
    """Compare a chi-squared statistic against its lookup-table entry.

    Calls ``chi_squared(actual, expected)`` to obtain the degrees of freedom
    and the statistic, then indexes ``_chiSquared_table`` by the degrees of
    freedom.

    Returns:
        True when the statistic is strictly less than the table entry,
        False otherwise.
    """
    # NOTE(review): the table presumably holds critical values, in which case
    # True means "no significant difference" despite the name -- confirm.
    degrees, statistic = chi_squared(actual, expected)
    threshold = _chiSquared_table[degrees]
    return statistic < threshold
def chi_squared_test(actual, expected, show=True): def chi_squared_test(actual, expected, show=True):
df, chiSquared = chi_squared(actual, expected) df, chiSquared = chi_squared(actual, expected)
@ -55,13 +59,22 @@ def cross_formula_chi_squared(actualDict, expectedDict):
print(' Succeeded.') print(' Succeeded.')
return failures return failures
def cross_formula_chi_squared_table(actualDict, expectedDict):
    """Tabulate ``chi_squared_diff`` for every pairing of the two mappings.

    Args:
        actualDict: mapping whose values are passed as the ``actual`` argument.
        expectedDict: mapping whose values are passed as the ``expected``
            argument.

    Returns:
        A dict keyed by ``(actual_key, expected_key)`` whose value is the
        result of ``chi_squared_diff`` for that pair of values. Empty when
        either mapping is empty.
    """
    return {
        (actual_key, expected_key): chi_squared_diff(observed, reference)
        for actual_key, observed in actualDict.items()
        for expected_key, reference in expectedDict.items()
    }
def cross_chi_squared(problemSets): def cross_chi_squared(problemSets):
failures = defaultdict(list) failures = defaultdict(list)
for i, (a, problemSetA) in enumerate(problemSets): for i, (a, problemSetA) in enumerate(problemSets):
for b, problemSetB in problemSets[i + 1:]: for b, problemSetB in problemSets[i + 1:]:
for problemA in problemSetA: for problemA in problemSetA:
for problemB in problemSetB: for problemB in problemSetB:
if (problemA.initial == problemB.initial and if (problemA.initial == problemB.initial and
problemA.modified == problemB.modified and problemA.modified == problemB.modified and
problemA.target == problemB.target): problemA.target == problemB.target):
answersA = problemA.distributions answersA = problemA.distributions
@ -80,6 +93,20 @@ def cross_chi_squared(problemSets):
print('\n') print('\n')
return failures return failures
def cross_chi_squared_table(problemSets):
    """Collect pairwise comparison tables for problems shared between sets.

    Every unordered pair of entries in ``problemSets`` is visited once
    (each entry against the entries that follow it). Within a pair, two
    problems are matched when their ``initial``, ``modified`` and ``target``
    attributes are all equal; their ``distributions`` are then handed to
    ``cross_formula_chi_squared_table``.

    Args:
        problemSets: sequence of ``(label, problems)`` pairs. It must support
            slicing, and each problem must expose ``initial``, ``modified``,
            ``target`` and ``distributions``.

    Returns:
        A ``defaultdict(dict)`` mapping ``(initial, modified, target)`` to a
        dict keyed by ``(labelA, labelB)`` that holds the table returned for
        that pair of sets.
    """
    results = defaultdict(dict)
    for index, (labelA, setA) in enumerate(problemSets):
        # Only compare against later entries so each pair is seen once.
        for labelB, setB in problemSets[index + 1:]:
            for first in setA:
                for second in setB:
                    same_problem = (first.initial == second.initial
                                    and first.modified == second.modified
                                    and first.target == second.target)
                    if not same_problem:
                        continue
                    pair_table = cross_formula_chi_squared_table(
                        first.distributions, second.distributions)
                    problem_key = (first.initial, first.modified, first.target)
                    results[problem_key][(labelA, labelB)] = pair_table
    return results
def iso_chi_squared(actualDict, expectedDict): def iso_chi_squared(actualDict, expectedDict):
for key in expectedDict.keys(): for key in expectedDict.keys():
assert key in actualDict, 'The key {} was not tested'.format(key) assert key in actualDict, 'The key {} was not tested'.format(key)

View File

@ -5,7 +5,7 @@ import pickle
from pprint import pprint from pprint import pprint
from copycat import Problem from copycat import Problem
from copycat.statistics import cross_chi_squared from copycat.statistics import cross_chi_squared, cross_chi_squared_table
def compare_sets(): def compare_sets():
pass pass
@ -18,7 +18,30 @@ def main(args):
pSet = pickle.load(infile) pSet = pickle.load(infile)
branchProblemSets[filename] = pSet branchProblemSets[filename] = pSet
problemSets.append((filename, pSet)) problemSets.append((filename, pSet))
pprint(cross_chi_squared(problemSets)) #pprint(problemSets)
#pprint(cross_chi_squared(problemSets))
crossTable = cross_chi_squared_table(problemSets)
key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0])
tableItems = key_sorted_items(crossTable)
assert len(tableItems) > 0, 'Empty table'
with open('output/cross_compare.csv', 'w') as outfile:
headKey, headSubDict = tableItems[0]
cells = ['problems x variants']
for subkey, subsubdict in key_sorted_items(headSubDict):
for subsubkey, result in key_sorted_items(subsubdict):
cells.append('{} x {} for {} x {}'.format(*subkey, *subsubkey))
outfile.write(','.join(cells) + '\n')
for key, subdict in tableItems:
cells = []
problem = '{}:{}::{}:_'.format(*key)
cells.append(problem)
for subkey, subsubdict in key_sorted_items(subdict):
for subsubkey, result in key_sorted_items(subsubdict):
cells.append(str(result))
outfile.write(','.join(cells) + '\n')
return 0 return 0
if __name__ == '__main__': if __name__ == '__main__':