Adds distribution saving and cl args
This commit is contained in:
BIN
.distributions
Normal file
BIN
.distributions
Normal file
Binary file not shown.
120
tests.py
120
tests.py
@ -1,6 +1,10 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from pprint import pprint
|
import os.path
|
||||||
|
import pickle
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
from copycat import Copycat
|
from copycat import Copycat
|
||||||
|
|
||||||
# TODO: update test cases to use entropy
|
# TODO: update test cases to use entropy
|
||||||
@ -19,38 +23,94 @@ _chiSquared_table = {
|
|||||||
10:18.307
|
10:18.307
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ChiSquaredException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def chi_squared(actual, expected):
|
||||||
|
answerKeys = set(list(actual.keys()) + list(expected.keys()))
|
||||||
|
degreesFreedom = len(answerKeys)
|
||||||
|
chiSquared = 0
|
||||||
|
|
||||||
|
get_count = lambda k, d : d[k]['count'] if k in d else 0
|
||||||
|
|
||||||
|
for k in answerKeys:
|
||||||
|
E = get_count(k, expected)
|
||||||
|
O = get_count(k, actual)
|
||||||
|
if E == 0:
|
||||||
|
print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
|
||||||
|
else:
|
||||||
|
chiSquared += (O - E) ** 2 / E
|
||||||
|
|
||||||
|
if chiSquared >= _chiSquared_table[degreesFreedom]:
|
||||||
|
raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' +
|
||||||
|
'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
|
||||||
|
|
||||||
|
class AnswerDistribution:
|
||||||
|
def __init__(self, initial, modified, target, iterations, distribution):
|
||||||
|
self.initial = initial
|
||||||
|
self.modified = modified
|
||||||
|
self.target = target
|
||||||
|
self.iterations = iterations
|
||||||
|
self.distribution = distribution
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
print('expected:')
|
||||||
|
pprint(self.distribution)
|
||||||
|
actual = Copycat().run(self.initial,
|
||||||
|
self.modified,
|
||||||
|
self.target,
|
||||||
|
self.iterations)
|
||||||
|
print('actual:')
|
||||||
|
pprint(actual)
|
||||||
|
chi_squared(actual, self.distribution)
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
self.distribution = Copycat().run(self.initial,
|
||||||
|
self.modified,
|
||||||
|
self.target,
|
||||||
|
self.iterations)
|
||||||
|
|
||||||
|
|
||||||
class TestCopycat(unittest.TestCase):
|
class TestCopycat(unittest.TestCase):
|
||||||
|
|
||||||
|
Filename = '.distributions'
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.longMessage = True # new in Python 2.7
|
self.longMessage = True # new in Python 2.7
|
||||||
|
|
||||||
def assertProbabilitiesLookRoughlyLike(self, actual, expected, iterations):
|
def generate(self):
|
||||||
answerKeys = set(list(actual.keys()) + list(expected.keys()))
|
print('Generating distributions for new file')
|
||||||
degreesFreedom = len(answerKeys)
|
iterations = 30
|
||||||
chiSquared = 0
|
distributions = [
|
||||||
|
AnswerDistribution('abc', 'abd', 'efg', iterations, None),
|
||||||
|
AnswerDistribution('abc', 'abd', 'ijk', iterations, None),
|
||||||
|
AnswerDistribution('abc', 'abd', 'xyz', iterations, None),
|
||||||
|
AnswerDistribution('abc', 'abd', 'ijkk', iterations, None),
|
||||||
|
AnswerDistribution('abc', 'abd', 'mrrjjj', iterations, None)]
|
||||||
|
|
||||||
get_count = lambda k, d : d[k]['count'] if k in d else 0
|
for distribution in distributions:
|
||||||
|
distribution.generate()
|
||||||
|
|
||||||
for k in answerKeys:
|
with open(TestCopycat.Filename, 'wb') as outfile:
|
||||||
E = get_count(k, expected)
|
pickle.dump(distributions, outfile)
|
||||||
O = get_count(k, actual)
|
return distributions
|
||||||
if E == 0:
|
|
||||||
print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
|
|
||||||
else:
|
|
||||||
chiSquared += (O - E) ** 2 / E
|
|
||||||
|
|
||||||
if chiSquared >= _chiSquared_table[degreesFreedom]:
|
def test(self):
|
||||||
self.fail('Significant difference between expected and actual answer distributions: \n' +
|
try:
|
||||||
'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
|
with open(TestCopycat.Filename, 'rb') as infile:
|
||||||
|
distributions = pickle.load(infile)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print('Generating due to error..')
|
||||||
|
distributions = self.generate()
|
||||||
|
|
||||||
|
for distribution in distributions:
|
||||||
|
distribution.test()
|
||||||
|
|
||||||
|
'''
|
||||||
def run_testcase(self, initial, modified, target, iterations, expected):
|
def run_testcase(self, initial, modified, target, iterations, expected):
|
||||||
print('expected:')
|
adist = AnswerDistribution(initial, modified, target, iterations, expected)
|
||||||
pprint(expected)
|
adist.test()
|
||||||
actual = Copycat().run(initial, modified, target, iterations)
|
|
||||||
print('actual:')
|
|
||||||
pprint(actual)
|
|
||||||
self.assertEqual(sum(a['count'] for a in list(actual.values())), iterations)
|
|
||||||
self.assertProbabilitiesLookRoughlyLike(actual, expected, iterations)
|
|
||||||
|
|
||||||
def test_simple_cases(self):
|
def test_simple_cases(self):
|
||||||
self.run_testcase('abc', 'abd', 'efg', 30,
|
self.run_testcase('abc', 'abd', 'efg', 30,
|
||||||
@ -102,8 +162,7 @@ class TestCopycat(unittest.TestCase):
|
|||||||
'count': 11},
|
'count': 11},
|
||||||
'mrrkkk': {'avgtemp': 43.709349775080746, 'avgtime': 1376.2, 'count': 10}})
|
'mrrkkk': {'avgtemp': 43.709349775080746, 'avgtime': 1376.2, 'count': 10}})
|
||||||
|
|
||||||
'''
|
# Below are examples of improvements that could be made to copycat.
|
||||||
Below are examples of improvements that could be made to copycat.
|
|
||||||
|
|
||||||
def test_elongation(self):
|
def test_elongation(self):
|
||||||
# This isn't remotely what a human would say.
|
# This isn't remotely what a human would say.
|
||||||
@ -137,6 +196,15 @@ class TestCopycat(unittest.TestCase):
|
|||||||
})
|
})
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--generate', action='store_true')
|
||||||
|
parser.add_argument('filename', default='.distributions', nargs='?')
|
||||||
|
parser.add_argument('unittest_args', default=[], nargs='?')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
# TODO: Go do something with args.input and args.filename
|
||||||
|
|
||||||
|
# Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
|
||||||
|
sys.argv[1:] = args.unittest_args
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user