Adds problem class
This commit is contained in:
80
tests.py
80
tests.py
@ -5,7 +5,7 @@ import argparse
|
||||
import sys
|
||||
|
||||
from pprint import pprint
|
||||
from copycat import Copycat
|
||||
from copycat import Problem
|
||||
|
||||
# TODO: update test cases to use entropy
|
||||
|
||||
@ -26,59 +26,39 @@ _chiSquared_table = {
|
||||
class ChiSquaredException(Exception):
|
||||
pass
|
||||
|
||||
def chi_squared(actual, expected):
|
||||
answerKeys = set(list(actual.keys()) + list(expected.keys()))
|
||||
degreesFreedom = len(answerKeys)
|
||||
chiSquared = 0
|
||||
def chi_squared(actualDict, expectedDict):
|
||||
for key in expectedDict.keys():
|
||||
assert key in actualDict, 'The key {} was not tested'.format(key)
|
||||
actual = actualDict[key]
|
||||
expected = expectedDict[key]
|
||||
|
||||
get_count = lambda k, d : d[k]['count'] if k in d else 0
|
||||
answerKeys = set(list(actual.keys()) + list(expected.keys()))
|
||||
degreesFreedom = len(answerKeys)
|
||||
chiSquared = 0
|
||||
|
||||
for k in answerKeys:
|
||||
E = get_count(k, expected)
|
||||
O = get_count(k, actual)
|
||||
if E == 0:
|
||||
print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
|
||||
else:
|
||||
chiSquared += (O - E) ** 2 / E
|
||||
get_count = lambda k, d : d[k]['count'] if k in d else 0
|
||||
|
||||
if chiSquared >= _chiSquared_table[degreesFreedom]:
|
||||
raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' +
|
||||
'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
|
||||
for k in answerKeys:
|
||||
E = get_count(k, expected)
|
||||
O = get_count(k, actual)
|
||||
if E == 0:
|
||||
print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
|
||||
else:
|
||||
chiSquared += (O - E) ** 2 / E
|
||||
|
||||
class AnswerDistribution:
|
||||
def __init__(self, initial, modified, target, iterations, distribution):
|
||||
self.initial = initial
|
||||
self.modified = modified
|
||||
self.target = target
|
||||
self.iterations = iterations
|
||||
self.distribution = distribution
|
||||
if chiSquared >= _chiSquared_table[degreesFreedom]:
|
||||
raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' +
|
||||
'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
|
||||
|
||||
def test(self):
|
||||
print('expected:')
|
||||
pprint(self.distribution)
|
||||
actual = Copycat().run(self.initial,
|
||||
self.modified,
|
||||
self.target,
|
||||
self.iterations)
|
||||
print('actual:')
|
||||
pprint(actual)
|
||||
chi_squared(actual, self.distribution)
|
||||
|
||||
def generate(self):
|
||||
self.distribution = Copycat().run(self.initial,
|
||||
self.modified,
|
||||
self.target,
|
||||
self.iterations)
|
||||
|
||||
def generate(self):
|
||||
def generate():
|
||||
print('Generating distributions for new file')
|
||||
iterations = 30
|
||||
distributions = [
|
||||
AnswerDistribution('abc', 'abd', 'efg', iterations, None),
|
||||
AnswerDistribution('abc', 'abd', 'ijk', iterations, None),
|
||||
AnswerDistribution('abc', 'abd', 'xyz', iterations, None),
|
||||
AnswerDistribution('abc', 'abd', 'ijkk', iterations, None),
|
||||
AnswerDistribution('abc', 'abd', 'mrrjjj', iterations, None)]
|
||||
Problem('abc', 'abd', 'efg', iterations, None),
|
||||
Problem('abc', 'abd', 'ijk', iterations, None),
|
||||
Problem('abc', 'abd', 'xyz', iterations, None),
|
||||
Problem('abc', 'abd', 'ijkk', iterations, None),
|
||||
Problem('abc', 'abd', 'mrrjjj', iterations, None)]
|
||||
|
||||
for distribution in distributions:
|
||||
distribution.generate()
|
||||
@ -95,7 +75,7 @@ class TestCopycat(unittest.TestCase):
|
||||
self.longMessage = True # new in Python 2.7
|
||||
|
||||
def test(self):
|
||||
print(TestCopycat.Filename)
|
||||
print('Testing copycat with input file: {}'.format(TestCopycat.Filename))
|
||||
try:
|
||||
with open(TestCopycat.Filename, 'rb') as infile:
|
||||
distributions = pickle.load(infile)
|
||||
@ -105,7 +85,7 @@ class TestCopycat(unittest.TestCase):
|
||||
distributions = generate()
|
||||
|
||||
for distribution in distributions:
|
||||
distribution.test()
|
||||
distribution.test(chi_squared)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -116,11 +96,11 @@ if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
# TODO: Go do something with args.input and args.filename
|
||||
|
||||
TestCopycat.Filename = args.filename
|
||||
|
||||
if args.generate:
|
||||
generate()
|
||||
|
||||
TestCopycat.Filename = args.filename
|
||||
|
||||
# Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
|
||||
sys.argv[1:] = args.unittest_args
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user