Adds probability difference to copycat/statistics

This commit is contained in:
LSaldyt
2018-01-12 15:32:53 -07:00
parent 1506e673a2
commit 84cc3bf9c1
2 changed files with 39 additions and 13 deletions

View File

@ -24,17 +24,18 @@ _ptable = {
16:26.3 16:26.3
} }
_get_count = lambda k, d : d[k]['count'] if k in d else 0
def g_value(actual, expected): def g_value(actual, expected):
# G = 2 * sum(Oi * ln(Oi/Ei)) # G = 2 * sum(Oi * ln(Oi/Ei))
answerKeys = set(list(actual.keys()) + list(expected.keys())) answerKeys = set(list(actual.keys()) + list(expected.keys()))
degreesFreedom = len(answerKeys) degreesFreedom = len(answerKeys)
G = 0 G = 0
get_count = lambda k, d : d[k]['count'] if k in d else 0
for k in answerKeys: for k in answerKeys:
E = get_count(k, expected) E = _get_count(k, expected)
O = get_count(k, actual) O = _get_count(k, actual)
if E == 0: if E == 0:
print(' Warning! Expected 0 counts of {}, but got {}'.format(k, O)) print(' Warning! Expected 0 counts of {}, but got {}'.format(k, O))
elif O == 0: elif O == 0:
@ -49,17 +50,39 @@ def chi_value(actual, expected):
degreesFreedom = len(answerKeys) degreesFreedom = len(answerKeys)
chiSquared = 0 chiSquared = 0
get_count = lambda k, d : d[k]['count'] if k in d else 0
for k in answerKeys: for k in answerKeys:
E = get_count(k, expected) E = _get_count(k, expected)
O = get_count(k, actual) O = _get_count(k, actual)
if E == 0: if E == 0:
print(' Warning! Expected 0 counts of {}, but got {}'.format(k, O)) print(' Warning! Expected 0 counts of {}, but got {}'.format(k, O))
else: else:
chiSquared += (O - E) ** 2 / E chiSquared += (O - E) ** 2 / E
return degreesFreedom, chiSquared return degreesFreedom, chiSquared
def probability_difference(actual, expected):
actualC = 0
expectedC = 0
for k in set(list(actual.keys()) + list(expected.keys())):
expectedC += _get_count(k, expected)
actualC += _get_count(k, actual)
p = 0
Et = 0
Ot = 0
for k in set(list(actual.keys()) + list(expected.keys())):
E = _get_count(k, expected)
O = _get_count(k, actual)
Ep = E / expectedC
Op = O / actualC
p += abs(Ep - Op)
p /= 2 # P is between 0 and 2 -> P is between 0 and 1
return p
def dist_test(actual, expected, calculation): def dist_test(actual, expected, calculation):
df, p = calculation(actual, expected) df, p = calculation(actual, expected)
if df not in _ptable: if df not in _ptable:
@ -67,14 +90,17 @@ def dist_test(actual, expected, calculation):
' Please look up the value and add it to the table in copycat/statistics.py'.format(df)) ' Please look up the value and add it to the table in copycat/statistics.py'.format(df))
return (p < _ptable[df]) return (p < _ptable[df])
def cross_formula_table(actualDict, expectedDict, calculation): def cross_formula_table(actualDict, expectedDict, calculation, probs=False):
data = dict() data = dict()
for ka, actual in actualDict.items(): for ka, actual in actualDict.items():
for ke, expected in expectedDict.items(): for ke, expected in expectedDict.items():
if probs:
data[(ka, ke)] = probability_difference(actual, expected)
else:
data[(ka, ke)] = dist_test(actual, expected, calculation) data[(ka, ke)] = dist_test(actual, expected, calculation)
return data return data
def cross_table(problemSets, calculation=g_value): def cross_table(problemSets, calculation=g_value, probs=False):
table = defaultdict(dict) table = defaultdict(dict)
for i, (a, problemSetA) in enumerate(problemSets): for i, (a, problemSetA) in enumerate(problemSets):
for b, problemSetB in problemSets[i + 1:]: for b, problemSetB in problemSets[i + 1:]:
@ -89,7 +115,7 @@ def cross_table(problemSets, calculation=g_value):
problemA.modified, problemA.modified,
problemA.target)][(a, b)] = ( problemA.target)][(a, b)] = (
cross_formula_table( cross_formula_table(
answersA, answersB, calculation)) answersA, answersB, calculation, probs))
return table return table
def iso_chi_squared(actualDict, expectedDict): def iso_chi_squared(actualDict, expectedDict):

View File

@ -19,7 +19,7 @@ def main(args):
pSet = pickle.load(infile) pSet = pickle.load(infile)
branchProblemSets[filename] = pSet branchProblemSets[filename] = pSet
problemSets.append((filename, pSet)) problemSets.append((filename, pSet))
crossTable = cross_table(problemSets) crossTable = cross_table(problemSets, probs=True)
key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0]) key_sorted_items = lambda d : sorted(d.items(), key=lambda t:t[0])
tableItems = key_sorted_items(crossTable) tableItems = key_sorted_items(crossTable)