Generalizes X^2 test to use G value

This commit is contained in:
LSaldyt
2018-01-12 14:19:11 -07:00
parent aed399ef34
commit 0a0369f5e1
2 changed files with 53 additions and 68 deletions

View File

@ -1,7 +1,11 @@
from collections import defaultdict from collections import defaultdict
from pprint import pprint from pprint import pprint
# CHI2 values for n degrees freedom from math import log
_chiSquared_table = {
# Critical values (p = 0.05) for n degrees of freedom.
# These values are usable for both the chi^2 and G tests.
_ptable = {
1:3.841, 1:3.841,
2:5.991, 2:5.991,
3:7.815, 3:7.815,
@ -11,13 +15,36 @@ _chiSquared_table = {
7:14.067, 7:14.067,
8:15.507, 8:15.507,
9:16.919, 9:16.919,
10:18.307 10:18.307,
11:19.7,
12:21,
13:22.4,
14:23.7,
15:25,
16:26.3
} }
def g_value(actual, expected):
    """Compute the G statistic between two answer distributions.

    G = 2 * sum(O_i * ln(O_i / E_i)), where O_i is the observed count and
    E_i the expected count of answer i.

    Args:
        actual: mapping of answer -> record holding a 'count' entry
            (observed counts).
        expected: mapping of answer -> record holding a 'count' entry
            (expected counts).

    Returns:
        A ``(degreesFreedom, G)`` tuple. ``degreesFreedom`` is the number of
        distinct answers appearing in either distribution.

    Answers whose expected or observed count is zero contribute nothing to
    the sum; a warning is printed for each of them.
    """
    from math import fsum

    def count_of(key, distribution):
        # An answer missing from a distribution counts as zero occurrences.
        return distribution[key]['count'] if key in distribution else 0

    answerKeys = set(actual) | set(expected)
    terms = []
    for key in answerKeys:
        E = count_of(key, expected)
        O = count_of(key, actual)
        if E == 0:
            # ln(O/E) is undefined for E == 0, so the term is left out.
            print(' Warning! Expected 0 counts of {}, but got {}'.format(key, O))
        elif O == 0:
            # ln(0) is undefined, so the term is left out.
            print(' Warning! O = {}'.format(O))
        else:
            terms.append(O * log(O / E))
    # fsum keeps the result independent of the (arbitrary) set iteration
    # order, which naive += accumulation of floats is not.
    return len(answerKeys), 2 * fsum(terms)
def chi_value(actual, expected):
answerKeys = set(list(actual.keys()) + list(expected.keys())) answerKeys = set(list(actual.keys()) + list(expected.keys()))
degreesFreedom = len(answerKeys) degreesFreedom = len(answerKeys)
chiSquared = 0 chiSquared = 0
@ -33,67 +60,21 @@ def chi_squared(actual, expected):
chiSquared += (O - E) ** 2 / E chiSquared += (O - E) ** 2 / E
return degreesFreedom, chiSquared return degreesFreedom, chiSquared
def dist_test(actual, expected, calculation, table=None):
    """Check whether two distributions are statistically indistinguishable.

    Args:
        actual: observed distribution, passed through to ``calculation``.
        expected: expected distribution, passed through to ``calculation``.
        calculation: callable ``(actual, expected) -> (df, statistic)``.
        table: optional mapping of degrees of freedom -> comparison value.
            Defaults to the module-level ``_ptable``.

    Returns:
        True when the statistic is below the comparison value for its
        degrees of freedom, False otherwise.

    Raises:
        Exception: if the table has no entry for the degrees of freedom.
    """
    if table is None:
        table = _ptable
    df, statistic = calculation(actual, expected)
    if df not in table:
        # Join the literals first, then format: calling .format() on only the
        # trailing literal would leave the '{}' placeholder unfilled.
        message = (
            '{} degrees of freedom does not have a corresponding chi squared value.'
            ' Please look up the value and add it to the table in copycat/statistics.py'
        ).format(df)
        raise Exception(message)
    return statistic < table[df]
def cross_formula_table(actualDict, expectedDict, calculation):
    """Run ``dist_test`` on every pairing of the two groups of distributions.

    Args:
        actualDict: mapping of name -> distribution for the first group.
        expectedDict: mapping of name -> distribution for the second group.
        calculation: statistic function forwarded to ``dist_test``.

    Returns:
        A dict keyed by ``(actualName, expectedName)`` holding the
        ``dist_test`` result for that pair.
    """
    return {
        (actualName, expectedName): dist_test(actualDist, expectedDist, calculation)
        for actualName, actualDist in actualDict.items()
        for expectedName, expectedDist in expectedDict.items()
    }
def cross_chi_squared(problemSets): def cross_table(problemSets, calculation=g_value):
failures = defaultdict(list)
for i, (a, problemSetA) in enumerate(problemSets):
for b, problemSetB in problemSets[i + 1:]:
for problemA in problemSetA:
for problemB in problemSetB:
if (problemA.initial == problemB.initial and
problemA.modified == problemB.modified and
problemA.target == problemB.target):
answersA = problemA.distributions
answersB = problemB.distributions
print('-' * 80)
print('\n')
print('{} x {}'.format(a, b))
problemString = '{} x {} for: {}:{}::{}:_\n'.format(a,
b,
problemA.initial,
problemA.modified,
problemA.target)
failures[problemString].append(cross_formula_chi_squared(answersA, answersB))
pprint(answersA)
pprint(answersB)
print('\n')
return failures
def cross_chi_squared_table(problemSets):
table = defaultdict(dict) table = defaultdict(dict)
for i, (a, problemSetA) in enumerate(problemSets): for i, (a, problemSetA) in enumerate(problemSets):
for b, problemSetB in problemSets[i + 1:]: for b, problemSetB in problemSets[i + 1:]:
@ -104,7 +85,11 @@ def cross_chi_squared_table(problemSets):
problemA.target == problemB.target): problemA.target == problemB.target):
answersA = problemA.distributions answersA = problemA.distributions
answersB = problemB.distributions answersB = problemB.distributions
table[(problemA.initial, problemA.modified, problemA.target)][(a, b)] = cross_formula_chi_squared_table(answersA, answersB) table[(problemA.initial,
problemA.modified,
problemA.target)][(a, b)] = (
cross_formula_table(
answersA, answersB, calculation))
return table return table
def iso_chi_squared(actualDict, expectedDict): def iso_chi_squared(actualDict, expectedDict):
@ -112,4 +97,6 @@ def iso_chi_squared(actualDict, expectedDict):
assert key in actualDict, 'The key {} was not tested'.format(key) assert key in actualDict, 'The key {} was not tested'.format(key)
actual = actualDict[key] actual = actualDict[key]
expected = expectedDict[key] expected = expectedDict[key]
chi_squared_test(actual, expected) if not dist_test(actual, expected, g_value):
raise Exception('Value of G higher than expected')

View File

@ -5,7 +5,7 @@ import pickle
from pprint import pprint from pprint import pprint
from copycat import Problem from copycat import Problem
from copycat.statistics import cross_chi_squared, cross_chi_squared_table from copycat.statistics import cross_table
def compare_sets():
    """Placeholder with no behavior yet; takes no arguments and returns None."""
@ -18,9 +18,7 @@ def main(args):
pSet = pickle.load(infile) pSet = pickle.load(infile)
branchProblemSets[filename] = pSet branchProblemSets[filename] = pSet
problemSets.append((filename, pSet)) problemSets.append((filename, pSet))
pprint(problemSets) crossTable = cross_table(problemSets)
pprint(cross_chi_squared(problemSets))
crossTable = cross_chi_squared_table(problemSets)
key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0]) key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0])
tableItems = key_sorted_items(crossTable) tableItems = key_sorted_items(crossTable)