107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
import unittest
|
|
import os.path
|
|
import pickle
|
|
import argparse
|
|
import sys
|
|
|
|
from pprint import pprint
|
|
from copycat import Problem
|
|
|
|
# TODO: update test cases to use entropy
|
|
|
|
# Critical chi-squared values keyed by degrees of freedom (1 through 10).
# The figures match the standard upper-tail table for a 0.05 significance
# level; a statistic at or above the entry is treated as a significant
# difference.
_chiSquared_table = {
    1: 3.841,
    2: 5.991,
    3: 7.815,
    4: 9.488,
    5: 11.071,
    6: 12.592,
    7: 14.067,
    8: 15.507,
    9: 16.919,
    10: 18.307,
}
|
|
|
|
class ChiSquaredException(Exception):
    """Signal that observed and expected answer distributions differ.

    Raised by ``chi_squared`` when the computed statistic reaches the
    critical value for the distribution being compared.
    """
|
|
|
|
def chi_squared(actualDict, expectedDict):
    """Compare observed answer counts with expected ones, problem by problem.

    For every key of ``expectedDict`` a Pearson chi-squared statistic is
    accumulated over the union of the answers present in either distribution
    and compared with the critical value for that many answer categories.

    Args:
        actualDict: Mapping ``key -> {answer: {'count': n, ...}}`` with the
            observed counts.
        expectedDict: Mapping with the same layout holding the expected
            counts.

    Raises:
        AssertionError: If a key of ``expectedDict`` is absent from
            ``actualDict``.
        ChiSquaredException: If the statistic for any key reaches the
            critical value.

    NOTE(review): the number of answer categories is used directly as the
    degrees of freedom; a goodness-of-fit test normally uses
    ``categories - 1``. Kept as in the original -- confirm the intent.
    """

    def count_of(answer, distribution):
        # An answer missing from a distribution counts as zero occurrences.
        return distribution[answer]['count'] if answer in distribution else 0

    for key in expectedDict:
        if key not in actualDict:
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped when Python runs with -O.
            raise AssertionError('The key {} was not tested'.format(key))
        actual = actualDict[key]
        expected = expectedDict[key]

        answerKeys = set(actual.keys()) | set(expected.keys())
        degreesFreedom = len(answerKeys)
        if degreesFreedom == 0:
            # Neither side recorded an answer: nothing to compare, and there
            # is no critical value for zero categories.
            continue

        chiSquared = 0.0
        for k in answerKeys:
            E = count_of(k, expected)
            O = count_of(k, actual)
            if E == 0:
                # The term is undefined for E == 0; report it and move on.
                print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
            else:
                # float() keeps true division on Python 2 integer counts.
                chiSquared += (O - E) ** 2 / float(E)

        critical = _chiSquared_table.get(degreesFreedom)
        if critical is None:
            # The table stops at 10; estimate instead of failing on KeyError.
            critical = _approximate_critical_value(degreesFreedom)

        if chiSquared >= critical:
            raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' +
                                      'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))


def _approximate_critical_value(degreesFreedom):
    """Estimate the 0.05 critical chi-squared value (Wilson-Hilferty)."""
    z = 1.6448536269514722  # standard normal quantile for an upper tail of 0.05
    spread = 2.0 / (9.0 * degreesFreedom)
    return degreesFreedom * (1.0 - spread + z * spread ** 0.5) ** 3
|
|
|
|
def generate():
    """Build the reference distributions and pickle them to disk.

    Creates the fixed list of ``Problem`` instances, calls ``generate()`` on
    each, and writes the whole list to ``TestCopycat.Filename`` with pickle.

    Returns:
        The list of ``Problem`` instances that was written.

    Raises:
        TypeError: If ``TestCopycat.Filename`` has not been set. Checked up
            front so the problems are not solved only to fail at ``open``.
    """
    if TestCopycat.Filename is None:
        raise TypeError('TestCopycat.Filename must be set before generating distributions')

    print('Generating distributions for new file')
    iterations = 30
    # Presumably (initial, modified, target) strings of each analogy problem
    # -- confirm against copycat.Problem.
    letter_strings = [
        ('abc', 'abd', 'efg'),
        ('abc', 'abd', 'ijk'),
        ('abc', 'abd', 'xyz'),
        ('abc', 'abd', 'ijkk'),
        ('abc', 'abd', 'mrrjjj'),
    ]
    distributions = [
        Problem(first, second, third, iterations, None)
        for first, second, third in letter_strings
    ]

    for distribution in distributions:
        distribution.generate()

    with open(TestCopycat.Filename, 'wb') as outfile:
        pickle.dump(distributions, outfile)
    return distributions
|
|
|
|
class TestCopycat(unittest.TestCase):
    """Check stored problem distributions against fresh results.

    The distributions are loaded from the pickle file named by ``Filename``;
    if loading fails for any reason they are regenerated with ``generate()``.
    """

    # Path of the pickled distributions; assigned by the script entry point.
    Filename = None

    def setUp(self):
        self.longMessage = True  # new in Python 2.7

    def test(self):
        """Run every stored problem's ``test`` with ``chi_squared``."""
        print('Testing copycat with input file: {}'.format(TestCopycat.Filename))
        try:
            # NOTE: pickle executes code from the file it loads; only point
            # Filename at files produced by generate().
            with open(TestCopycat.Filename, 'rb') as infile:
                distributions = pickle.load(infile)
        except Exception as e:
            # Deliberate best effort: a missing or unreadable file triggers
            # regeneration instead of failing the test outright.
            print('Generating due to error:')
            print(e)
            distributions = generate()

        for distribution in distributions:
            distribution.test(chi_squared)
|
|
|
|
if __name__ == '__main__':
    # Usage: [--generate] [filename] [unittest args ...]
    parser = argparse.ArgumentParser()
    parser.add_argument('--generate', action='store_true')
    parser.add_argument('filename', default='.distributions', nargs='?')
    # nargs='*' always yields a list. With nargs='?' a supplied value is a
    # plain str, and the slice assignment to sys.argv below would split it
    # into single characters.
    parser.add_argument('unittest_args', default=[], nargs='*')

    args = parser.parse_args()

    TestCopycat.Filename = args.filename

    if args.generate:
        generate()

    # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
    sys.argv[1:] = args.unittest_args
    unittest.main()
|