diff --git a/.distributions b/.distributions
index 9931287..b5ea0d0 100644
Binary files a/.distributions and b/.distributions differ
diff --git a/copycat/__init__.py b/copycat/__init__.py
index 4e8bc55..732e307 100644
--- a/copycat/__init__.py
+++ b/copycat/__init__.py
@@ -1,3 +1,4 @@
 from .copycat import Copycat, Reporter  # noqa
+from .problem import Problem
 from .plot import plot_answers
 from .io import save_answers
diff --git a/copycat/copycat.py b/copycat/copycat.py
index 501a3d6..37a3dd7 100644
--- a/copycat/copycat.py
+++ b/copycat/copycat.py
@@ -89,8 +89,6 @@ class Copycat(object):
             d['avgtime'] = d.pop('sumtime') / d['count']
         print('The formula {} provided:'.format(formula))
         print('Average difference: {}'.format(self.temperature.getAverageDifference()))
-        pprint(answers)
-
         return answers
 
     def run_forever(self, initial, modified, target):
diff --git a/copycat/problem.py b/copycat/problem.py
new file mode 100644
index 0000000..5a18617
--- /dev/null
+++ b/copycat/problem.py
@@ -0,0 +1,71 @@
+"""A copycat analogy problem bundled with its expected answers."""
+
+from .copycat import Copycat
+
+from pprint import pprint
+
+class Problem:
+    """One analogy problem: ``initial : modified :: target : _``.
+
+    ``distributions`` holds the expected answers in the same shape that
+    ``solve`` returns: a dict keyed by temperature adjustment formula.
+    """
+
+    def __init__(self, initial, modified, target, iterations, distributions, formulas=None):
+        """Store the problem strings, the run count and the expected answers.
+
+        ``formulas`` optionally lists the adjustment formulas to solve with;
+        when it is ``None`` a single run is made with the current formula.
+        """
+        self.initial = initial
+        self.modified = modified
+        self.target = target
+
+        self.iterations = iterations
+        self.distributions = distributions
+        self.formulas = formulas
+        if formulas is not None:
+            # Choosing formulas needs a workspace that exposes `temperature`.
+            assert hasattr(Copycat().workspace, 'temperature')
+
+    def test(self, comparison):
+        """Solve the problem and pass (actual, expected) to ``comparison``."""
+        print('-' * 120)
+        print('Testing copycat problem: {} : {} :: {} : _'.format(self.initial,
+                                                                   self.modified,
+                                                                   self.target))
+        print('expected:')
+        pprint(self.distributions)
+
+        actual = self.solve()
+        print('actual:')
+        pprint(actual)
+        comparison(actual, self.distributions)
+        print('-' * 120)
+
+    def solve(self):
+        """Run copycat and return its answers keyed by adjustment formula."""
+        copycat = Copycat()
+        answers = dict()
+        if self.formulas is None:
+            if hasattr(copycat.workspace, 'temperature'):
+                formula = copycat.workspace.temperature.getAdj()
+            else:
+                formula = None
+            answers[formula] = copycat.run(self.initial,
+                    self.modified,
+                    self.target,
+                    self.iterations)
+        else:
+            for formula in self.formulas:
+                copycat.temperature.useAdj(formula)
+                # One entry per formula, keyed by the formula just applied.
+                answers[formula] = copycat.run(self.initial,
+                        self.modified,
+                        self.target,
+                        self.iterations)
+        return answers
+
+    def generate(self):
+        """Record the current ``solve`` output as the expected distributions."""
+        self.distributions = self.solve()
diff --git a/copycat/temperature.py b/copycat/temperature.py
index ce78d96..6551571 100644
--- a/copycat/temperature.py
+++ b/copycat/temperature.py
@@ -147,5 +147,9 @@ class Temperature(object):
         print('Changing to adjustment formula {}'.format(adj))
         self.adjustmentType = adj
 
+    def getAdj(self):
+        """Return the currently stored adjustment type."""
+        return self.adjustmentType
+
     def adj_formulas(self):
         return self._adjustmentFormulas.keys()
diff --git a/tests.py b/tests.py
index 81ec806..3f633f4 100644
--- a/tests.py
+++ b/tests.py
@@ -5,7 +5,7 @@ import argparse
 import sys
 
 from pprint import pprint
-from copycat import Copycat
+from copycat import Problem
 
 # TODO: update test cases to use entropy
 
@@ -26,59 +26,45 @@ _chiSquared_table = {
 class ChiSquaredException(Exception):
     pass
 
-def chi_squared(actual, expected):
-    answerKeys = set(list(actual.keys()) + list(expected.keys()))
-    degreesFreedom = len(answerKeys)
-    chiSquared = 0
+def chi_squared(actualDict, expectedDict):
+    """Chi-squared check of actual answer counts against expected counts.
+
+    Both arguments map a key to an {answer: {'count': n}} dict; every key of
+    expectedDict must be present in actualDict.  Raises ChiSquaredException
+    when a statistic reaches the critical value in _chiSquared_table.
+    """
+    for key in expectedDict.keys():
+        assert key in actualDict, 'The key {} was not tested'.format(key)
+        actual = actualDict[key]
+        expected = expectedDict[key]
 
-    get_count = lambda k, d : d[k]['count'] if k in d else 0
+        answerKeys = set(list(actual.keys()) + list(expected.keys()))
+        degreesFreedom = len(answerKeys)
+        chiSquared = 0
 
-    for k in answerKeys:
-        E = get_count(k, expected)
-        O = get_count(k, actual)
-        if E == 0:
-            print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
-        else:
-            chiSquared += (O - E) ** 2 / E
+        get_count = lambda k, d : d[k]['count'] if k in d else 0
 
-    if chiSquared >= _chiSquared_table[degreesFreedom]:
-        raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' +
-                'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
+        for k in answerKeys:
+            E = get_count(k, expected)
+            O = get_count(k, actual)
+            if E == 0:
+                print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
+            else:
+                chiSquared += (O - E) ** 2 / E
 
-class AnswerDistribution:
-    def __init__(self, initial, modified, target, iterations, distribution):
-        self.initial = initial
-        self.modified = modified
-        self.target = target
-        self.iterations = iterations
-        self.distribution = distribution
+        if chiSquared >= _chiSquared_table[degreesFreedom]:
+            raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' +
+                    'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
 
-    def test(self):
-        print('expected:')
-        pprint(self.distribution)
-        actual = Copycat().run(self.initial,
-                self.modified,
-                self.target,
-                self.iterations)
-        print('actual:')
-        pprint(actual)
-        chi_squared(actual, self.distribution)
-
-    def generate(self):
-        self.distribution = Copycat().run(self.initial,
-                self.modified,
-                self.target,
-                self.iterations)
-
-def generate(self):
+def generate():
     print('Generating distributions for new file')
     iterations = 30
     distributions = [
-            AnswerDistribution('abc', 'abd', 'efg', iterations, None),
-            AnswerDistribution('abc', 'abd', 'ijk', iterations, None),
-            AnswerDistribution('abc', 'abd', 'xyz', iterations, None),
-            AnswerDistribution('abc', 'abd', 'ijkk', iterations, None),
-            AnswerDistribution('abc', 'abd', 'mrrjjj', iterations, None)]
+            Problem('abc', 'abd', 'efg', iterations, None),
+            Problem('abc', 'abd', 'ijk', iterations, None),
+            Problem('abc', 'abd', 'xyz', iterations, None),
+            Problem('abc', 'abd', 'ijkk', iterations, None),
+            Problem('abc', 'abd', 'mrrjjj', iterations, None)]
 
     for distribution in distributions:
         distribution.generate()
@@ -95,7 +81,7 @@ class TestCopycat(unittest.TestCase):
         self.longMessage = True  # new in Python 2.7
 
     def test(self):
-        print(TestCopycat.Filename)
+        print('Testing copycat with input file: {}'.format(TestCopycat.Filename))
         try:
             with open(TestCopycat.Filename, 'rb') as infile:
                 distributions = pickle.load(infile)
@@ -105,7 +91,7 @@ class TestCopycat(unittest.TestCase):
             distributions = generate()
 
         for distribution in distributions:
-            distribution.test()
+            distribution.test(chi_squared)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -116,11 +102,11 @@ if __name__ == '__main__':
     args = parser.parse_args()
     # TODO: Go do something with args.input and args.filename
 
+    TestCopycat.Filename = args.filename
+
     if args.generate:
         generate()
 
-    TestCopycat.Filename = args.filename
-
     # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
     sys.argv[1:] = args.unittest_args
     unittest.main()