Generalizes X^2 test to use G value
This commit is contained in:
@ -1,7 +1,11 @@
|
||||
from collections import defaultdict
|
||||
from pprint import pprint
|
||||
# CHI2 values for n degrees freedom
|
||||
_chiSquared_table = {
|
||||
from pprint import pprint
|
||||
from math import log
|
||||
|
||||
# comparison values for n degrees freedom
|
||||
# These values are useable for both the chi^2 and G tests
|
||||
|
||||
_ptable = {
|
||||
1:3.841,
|
||||
2:5.991,
|
||||
3:7.815,
|
||||
@ -11,13 +15,36 @@ _chiSquared_table = {
|
||||
7:14.067,
|
||||
8:15.507,
|
||||
9:16.919,
|
||||
10:18.307
|
||||
10:18.307,
|
||||
11:19.7,
|
||||
12:21,
|
||||
13:22.4,
|
||||
14:23.7,
|
||||
15:25,
|
||||
16:26.3
|
||||
}
|
||||
|
||||
# Module-specific exception type: a plain Exception subclass with no extra
# behavior. NOTE(review): nothing in the visible code raises it yet (the
# visible raise sites use bare Exception) -- confirm whether it is still needed.
class ChiSquaredException(Exception):
|
||||
pass
|
||||
def g_value(actual, expected):
    """Compute the G statistic between an observed and an expected distribution.

    G = 2 * sum(O_i * ln(O_i / E_i)), summed over every key that appears in
    either mapping.  O_i and E_i are the ``'count'`` entries stored under
    that key in ``actual`` and ``expected``; a key missing from a mapping
    counts as 0.

    Keys whose expected or observed count is 0 cannot be fed to the
    logarithm, so they only print a warning and add nothing to the sum.

    Args:
        actual: mapping of answer key -> mapping holding a ``'count'`` entry
            (the observed counts).
        expected: mapping of the same shape holding the expected counts.

    Returns:
        A ``(degreesFreedom, G)`` tuple: the number of distinct keys across
        both mappings, and the G statistic.
    """
    def get_count(key, dist):
        # Absent keys are treated as a zero count rather than an error.
        return dist[key]['count'] if key in dist else 0

    answerKeys = set(actual) | set(expected)
    # NOTE(review): the number of categories is reported as the degrees of
    # freedom; a classical goodness-of-fit test would use one fewer --
    # confirm this is intended before changing it, callers index the
    # comparison table with this value.
    degreesFreedom = len(answerKeys)

    G = 0
    for k in answerKeys:
        E = get_count(k, expected)
        O = get_count(k, actual)
        if E == 0:
            print(' Warning! Expected 0 counts of {}, but got {}'.format(k, O))
        elif O == 0:
            print(' Warning! O = {}'.format(O))
        else:
            G += O * log(O / E)
    return degreesFreedom, 2 * G
|
||||
|
||||
def chi_value(actual, expected):
|
||||
answerKeys = set(list(actual.keys()) + list(expected.keys()))
|
||||
degreesFreedom = len(answerKeys)
|
||||
chiSquared = 0
|
||||
@ -33,67 +60,21 @@ def chi_squared(actual, expected):
|
||||
chiSquared += (O - E) ** 2 / E
|
||||
return degreesFreedom, chiSquared
|
||||
|
||||
def chi_squared_diff(actual, expected):
|
||||
df, chiSquared = chi_squared(actual, expected)
|
||||
return (chiSquared < _chiSquared_table[df])
|
||||
def dist_test(actual, expected, calculation):
    """Check whether two distributions agree according to ``calculation``.

    ``calculation`` is called as ``calculation(actual, expected)`` and must
    return a ``(degrees_of_freedom, statistic)`` pair (as ``g_value`` and
    ``chi_value`` do).  The statistic is compared against the ``_ptable``
    entry for that number of degrees of freedom.

    Args:
        actual: observed distribution, passed straight to ``calculation``.
        expected: expected distribution, passed straight to ``calculation``.
        calculation: callable producing ``(degrees_of_freedom, statistic)``.

    Returns:
        True when the statistic is strictly below the table value for its
        degrees of freedom, False otherwise.

    Raises:
        Exception: if ``_ptable`` has no entry for the degrees of freedom.
    """
    df, statistic = calculation(actual, expected)
    if df not in _ptable:
        # The two literals are joined by implicit concatenation so that
        # .format() applies to the whole message; with `+` it bound only to
        # the second literal and the '{}' placeholder was never filled in.
        raise Exception(
            '{} degrees of freedom does not have a corresponding chi squared value.'
            ' Please look up the value and add it to the table in copycat/statistics.py'
            .format(df))
    return statistic < _ptable[df]
|
||||
|
||||
def chi_squared_test(actual, expected, show=True):
|
||||
df, chiSquared = chi_squared(actual, expected)
|
||||
|
||||
if chiSquared >= _chiSquared_table[df]:
|
||||
if show:
|
||||
print('Significant difference between expected and actual answer distributions: \n' +
|
||||
'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, df))
|
||||
return False
|
||||
return True
|
||||
|
||||
def cross_formula_chi_squared(actualDict, expectedDict):
|
||||
failures = []
|
||||
for ka, actual in actualDict.items():
|
||||
for ke, expected in expectedDict.items():
|
||||
print(' Comparing {} with {}: '.format(ka, ke), end='')
|
||||
if not chi_squared_test(actual, expected, show=False):
|
||||
failures.append('{}:{}'.format(ka, ke))
|
||||
print(' Failed.')
|
||||
else:
|
||||
print(' Succeeded.')
|
||||
return failures
|
||||
|
||||
def cross_formula_chi_squared_table(actualDict, expectedDict):
|
||||
def cross_formula_table(actualDict, expectedDict, calculation):
    """Run ``dist_test`` on every pairing of an actual and an expected entry.

    Args:
        actualDict: mapping of key -> observed distribution.
        expectedDict: mapping of key -> expected distribution.
        calculation: statistic function forwarded to ``dist_test``.

    Returns:
        A dict keyed by ``(actual_key, expected_key)`` whose values are the
        ``dist_test`` results for that pair.
    """
    return {
        (ka, ke): dist_test(actual, expected, calculation)
        for ka, actual in actualDict.items()
        for ke, expected in expectedDict.items()
    }
|
||||
|
||||
def cross_chi_squared(problemSets):
|
||||
failures = defaultdict(list)
|
||||
for i, (a, problemSetA) in enumerate(problemSets):
|
||||
for b, problemSetB in problemSets[i + 1:]:
|
||||
for problemA in problemSetA:
|
||||
for problemB in problemSetB:
|
||||
if (problemA.initial == problemB.initial and
|
||||
problemA.modified == problemB.modified and
|
||||
problemA.target == problemB.target):
|
||||
answersA = problemA.distributions
|
||||
answersB = problemB.distributions
|
||||
print('-' * 80)
|
||||
print('\n')
|
||||
print('{} x {}'.format(a, b))
|
||||
problemString = '{} x {} for: {}:{}::{}:_\n'.format(a,
|
||||
b,
|
||||
problemA.initial,
|
||||
problemA.modified,
|
||||
problemA.target)
|
||||
failures[problemString].append(cross_formula_chi_squared(answersA, answersB))
|
||||
pprint(answersA)
|
||||
pprint(answersB)
|
||||
print('\n')
|
||||
return failures
|
||||
|
||||
def cross_chi_squared_table(problemSets):
|
||||
def cross_table(problemSets, calculation=g_value):
|
||||
table = defaultdict(dict)
|
||||
for i, (a, problemSetA) in enumerate(problemSets):
|
||||
for b, problemSetB in problemSets[i + 1:]:
|
||||
@ -104,7 +85,11 @@ def cross_chi_squared_table(problemSets):
|
||||
problemA.target == problemB.target):
|
||||
answersA = problemA.distributions
|
||||
answersB = problemB.distributions
|
||||
table[(problemA.initial, problemA.modified, problemA.target)][(a, b)] = cross_formula_chi_squared_table(answersA, answersB)
|
||||
table[(problemA.initial,
|
||||
problemA.modified,
|
||||
problemA.target)][(a, b)] = (
|
||||
cross_formula_table(
|
||||
answersA, answersB, calculation))
|
||||
return table
|
||||
|
||||
def iso_chi_squared(actualDict, expectedDict):
|
||||
@ -112,4 +97,6 @@ def iso_chi_squared(actualDict, expectedDict):
|
||||
assert key in actualDict, 'The key {} was not tested'.format(key)
|
||||
actual = actualDict[key]
|
||||
expected = expectedDict[key]
|
||||
chi_squared_test(actual, expected)
|
||||
if not dist_test(actual, expected, g_value):
|
||||
raise Exception('Value of G higher than expected')
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import pickle
|
||||
from pprint import pprint
|
||||
|
||||
from copycat import Problem
|
||||
from copycat.statistics import cross_chi_squared, cross_chi_squared_table
|
||||
from copycat.statistics import cross_table
|
||||
|
||||
# Placeholder: takes no arguments and does nothing (body is a bare `pass`).
def compare_sets():
|
||||
pass
|
||||
@ -18,9 +18,7 @@ def main(args):
|
||||
pSet = pickle.load(infile)
|
||||
branchProblemSets[filename] = pSet
|
||||
problemSets.append((filename, pSet))
|
||||
pprint(problemSets)
|
||||
pprint(cross_chi_squared(problemSets))
|
||||
crossTable = cross_chi_squared_table(problemSets)
|
||||
crossTable = cross_table(problemSets)
|
||||
key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0])
|
||||
|
||||
tableItems = key_sorted_items(crossTable)
|
||||
|
||||
Reference in New Issue
Block a user