Updates chi2 code for final comparison
This commit is contained in:
@ -33,6 +33,10 @@ def chi_squared(actual, expected):
|
||||
chiSquared += (O - E) ** 2 / E
|
||||
return degreesFreedom, chiSquared
|
||||
|
||||
def chi_squared_diff(actual, expected):
|
||||
df, chiSquared = chi_squared(actual, expected)
|
||||
return (chiSquared < _chiSquared_table[df])
|
||||
|
||||
def chi_squared_test(actual, expected, show=True):
|
||||
df, chiSquared = chi_squared(actual, expected)
|
||||
|
||||
@ -55,6 +59,15 @@ def cross_formula_chi_squared(actualDict, expectedDict):
|
||||
print(' Succeeded.')
|
||||
return failures
|
||||
|
||||
def cross_formula_chi_squared_table(actualDict, expectedDict):
|
||||
data = dict()
|
||||
for ka, actual in actualDict.items():
|
||||
for ke, expected in expectedDict.items():
|
||||
#df, chiSquared = chi_squared(actual, expected)
|
||||
data[(ka, ke)] = chi_squared_diff(actual, expected)
|
||||
#data[(ka, ke)] = (df, chiSquared)
|
||||
return data
|
||||
|
||||
def cross_chi_squared(problemSets):
|
||||
failures = defaultdict(list)
|
||||
for i, (a, problemSetA) in enumerate(problemSets):
|
||||
@ -80,6 +93,20 @@ def cross_chi_squared(problemSets):
|
||||
print('\n')
|
||||
return failures
|
||||
|
||||
def cross_chi_squared_table(problemSets):
|
||||
table = defaultdict(dict)
|
||||
for i, (a, problemSetA) in enumerate(problemSets):
|
||||
for b, problemSetB in problemSets[i + 1:]:
|
||||
for problemA in problemSetA:
|
||||
for problemB in problemSetB:
|
||||
if (problemA.initial == problemB.initial and
|
||||
problemA.modified == problemB.modified and
|
||||
problemA.target == problemB.target):
|
||||
answersA = problemA.distributions
|
||||
answersB = problemB.distributions
|
||||
table[(problemA.initial, problemA.modified, problemA.target)][(a, b)] = cross_formula_chi_squared_table(answersA, answersB)
|
||||
return table
|
||||
|
||||
def iso_chi_squared(actualDict, expectedDict):
|
||||
for key in expectedDict.keys():
|
||||
assert key in actualDict, 'The key {} was not tested'.format(key)
|
||||
|
||||
@ -5,7 +5,7 @@ import pickle
|
||||
from pprint import pprint
|
||||
|
||||
from copycat import Problem
|
||||
from copycat.statistics import cross_chi_squared
|
||||
from copycat.statistics import cross_chi_squared, cross_chi_squared_table
|
||||
|
||||
def compare_sets():
|
||||
pass
|
||||
@ -18,7 +18,30 @@ def main(args):
|
||||
pSet = pickle.load(infile)
|
||||
branchProblemSets[filename] = pSet
|
||||
problemSets.append((filename, pSet))
|
||||
pprint(cross_chi_squared(problemSets))
|
||||
#pprint(problemSets)
|
||||
#pprint(cross_chi_squared(problemSets))
|
||||
crossTable = cross_chi_squared_table(problemSets)
|
||||
key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0])
|
||||
|
||||
tableItems = key_sorted_items(crossTable)
|
||||
assert len(tableItems) > 0, 'Empty table'
|
||||
|
||||
with open('output/cross_compare.csv', 'w') as outfile:
|
||||
headKey, headSubDict = tableItems[0]
|
||||
cells = ['problems x variants']
|
||||
for subkey, subsubdict in key_sorted_items(headSubDict):
|
||||
for subsubkey, result in key_sorted_items(subsubdict):
|
||||
cells.append('{} x {} for {} x {}'.format(*subkey, *subsubkey))
|
||||
outfile.write(','.join(cells) + '\n')
|
||||
|
||||
for key, subdict in tableItems:
|
||||
cells = []
|
||||
problem = '{}:{}::{}:_'.format(*key)
|
||||
cells.append(problem)
|
||||
for subkey, subsubdict in key_sorted_items(subdict):
|
||||
for subsubkey, result in key_sorted_items(subsubdict):
|
||||
cells.append(str(result))
|
||||
outfile.write(','.join(cells) + '\n')
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user