Updates chi2 code for final comparison
This commit is contained in:
@ -33,6 +33,10 @@ def chi_squared(actual, expected):
|
|||||||
chiSquared += (O - E) ** 2 / E
|
chiSquared += (O - E) ** 2 / E
|
||||||
return degreesFreedom, chiSquared
|
return degreesFreedom, chiSquared
|
||||||
|
|
||||||
|
def chi_squared_diff(actual, expected):
|
||||||
|
df, chiSquared = chi_squared(actual, expected)
|
||||||
|
return (chiSquared < _chiSquared_table[df])
|
||||||
|
|
||||||
def chi_squared_test(actual, expected, show=True):
|
def chi_squared_test(actual, expected, show=True):
|
||||||
df, chiSquared = chi_squared(actual, expected)
|
df, chiSquared = chi_squared(actual, expected)
|
||||||
|
|
||||||
@ -55,13 +59,22 @@ def cross_formula_chi_squared(actualDict, expectedDict):
|
|||||||
print(' Succeeded.')
|
print(' Succeeded.')
|
||||||
return failures
|
return failures
|
||||||
|
|
||||||
|
def cross_formula_chi_squared_table(actualDict, expectedDict):
|
||||||
|
data = dict()
|
||||||
|
for ka, actual in actualDict.items():
|
||||||
|
for ke, expected in expectedDict.items():
|
||||||
|
#df, chiSquared = chi_squared(actual, expected)
|
||||||
|
data[(ka, ke)] = chi_squared_diff(actual, expected)
|
||||||
|
#data[(ka, ke)] = (df, chiSquared)
|
||||||
|
return data
|
||||||
|
|
||||||
def cross_chi_squared(problemSets):
|
def cross_chi_squared(problemSets):
|
||||||
failures = defaultdict(list)
|
failures = defaultdict(list)
|
||||||
for i, (a, problemSetA) in enumerate(problemSets):
|
for i, (a, problemSetA) in enumerate(problemSets):
|
||||||
for b, problemSetB in problemSets[i + 1:]:
|
for b, problemSetB in problemSets[i + 1:]:
|
||||||
for problemA in problemSetA:
|
for problemA in problemSetA:
|
||||||
for problemB in problemSetB:
|
for problemB in problemSetB:
|
||||||
if (problemA.initial == problemB.initial and
|
if (problemA.initial == problemB.initial and
|
||||||
problemA.modified == problemB.modified and
|
problemA.modified == problemB.modified and
|
||||||
problemA.target == problemB.target):
|
problemA.target == problemB.target):
|
||||||
answersA = problemA.distributions
|
answersA = problemA.distributions
|
||||||
@ -80,6 +93,20 @@ def cross_chi_squared(problemSets):
|
|||||||
print('\n')
|
print('\n')
|
||||||
return failures
|
return failures
|
||||||
|
|
||||||
|
def cross_chi_squared_table(problemSets):
|
||||||
|
table = defaultdict(dict)
|
||||||
|
for i, (a, problemSetA) in enumerate(problemSets):
|
||||||
|
for b, problemSetB in problemSets[i + 1:]:
|
||||||
|
for problemA in problemSetA:
|
||||||
|
for problemB in problemSetB:
|
||||||
|
if (problemA.initial == problemB.initial and
|
||||||
|
problemA.modified == problemB.modified and
|
||||||
|
problemA.target == problemB.target):
|
||||||
|
answersA = problemA.distributions
|
||||||
|
answersB = problemB.distributions
|
||||||
|
table[(problemA.initial, problemA.modified, problemA.target)][(a, b)] = cross_formula_chi_squared_table(answersA, answersB)
|
||||||
|
return table
|
||||||
|
|
||||||
def iso_chi_squared(actualDict, expectedDict):
|
def iso_chi_squared(actualDict, expectedDict):
|
||||||
for key in expectedDict.keys():
|
for key in expectedDict.keys():
|
||||||
assert key in actualDict, 'The key {} was not tested'.format(key)
|
assert key in actualDict, 'The key {} was not tested'.format(key)
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import pickle
|
|||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
from copycat import Problem
|
from copycat import Problem
|
||||||
from copycat.statistics import cross_chi_squared
|
from copycat.statistics import cross_chi_squared, cross_chi_squared_table
|
||||||
|
|
||||||
def compare_sets():
|
def compare_sets():
|
||||||
pass
|
pass
|
||||||
@ -18,7 +18,30 @@ def main(args):
|
|||||||
pSet = pickle.load(infile)
|
pSet = pickle.load(infile)
|
||||||
branchProblemSets[filename] = pSet
|
branchProblemSets[filename] = pSet
|
||||||
problemSets.append((filename, pSet))
|
problemSets.append((filename, pSet))
|
||||||
pprint(cross_chi_squared(problemSets))
|
#pprint(problemSets)
|
||||||
|
#pprint(cross_chi_squared(problemSets))
|
||||||
|
crossTable = cross_chi_squared_table(problemSets)
|
||||||
|
key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0])
|
||||||
|
|
||||||
|
tableItems = key_sorted_items(crossTable)
|
||||||
|
assert len(tableItems) > 0, 'Empty table'
|
||||||
|
|
||||||
|
with open('output/cross_compare.csv', 'w') as outfile:
|
||||||
|
headKey, headSubDict = tableItems[0]
|
||||||
|
cells = ['problems x variants']
|
||||||
|
for subkey, subsubdict in key_sorted_items(headSubDict):
|
||||||
|
for subsubkey, result in key_sorted_items(subsubdict):
|
||||||
|
cells.append('{} x {} for {} x {}'.format(*subkey, *subsubkey))
|
||||||
|
outfile.write(','.join(cells) + '\n')
|
||||||
|
|
||||||
|
for key, subdict in tableItems:
|
||||||
|
cells = []
|
||||||
|
problem = '{}:{}::{}:_'.format(*key)
|
||||||
|
cells.append(problem)
|
||||||
|
for subkey, subsubdict in key_sorted_items(subdict):
|
||||||
|
for subsubkey, result in key_sorted_items(subsubdict):
|
||||||
|
cells.append(str(result))
|
||||||
|
outfile.write(','.join(cells) + '\n')
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user