diff --git a/copycat/__init__.py b/copycat/__init__.py index 67e5cc9..292c4b6 100644 --- a/copycat/__init__.py +++ b/copycat/__init__.py @@ -1 +1,2 @@ from .copycat import Copycat, Reporter # noqa +from .problem import Problem diff --git a/copycat/problem.py b/copycat/problem.py new file mode 100644 index 0000000..5805eca --- /dev/null +++ b/copycat/problem.py @@ -0,0 +1,58 @@ +from .copycat import Copycat + +from pprint import pprint + +class Problem: + def __init__(self, initial, modified, target, iterations, distributions=None, formulas=None): + self.formulas = formulas + self.initial = initial + self.modified = modified + self.target = target + + self.iterations = iterations + if distributions is None: + self.distributions = self.solve() + else: + self.distributions = distributions + if formulas is not None: + assert hasattr(Copycat().workspace, 'temperature') + + def test(self, comparison, expected=None): + print('-' * 120) + print('Testing copycat problem: {} : {} :: {} : _'.format(self.initial, + self.modified, + self.target)) + print('expected:') + if expected is None: + expected = self.distributions + pprint(expected) + + actual = self.solve() + print('actual:') + pprint(actual) + comparison(actual, expected) + print('-' * 120) + + def solve(self): + copycat = Copycat() + answers = dict() + if self.formulas == None: + if hasattr(copycat.workspace, 'temperature'): + formula = copycat.workspace.temperature.getAdj() + else: + formula = None + answers[formula] = copycat.run(self.initial, + self.modified, + self.target, + self.iterations) + else: + for formula in self.formulas: + copycat.temperature.useAdj(formula) + answers[formulas] = copycat.run(self.initial, + self.modified, + self.target, + self.iterations) + return answers + + def generate(self): + self.distributions = self.solve() diff --git a/copycat/statistics.py b/copycat/statistics.py new file mode 100644 index 0000000..4f1ffe3 --- /dev/null +++ b/copycat/statistics.py @@ -0,0 +1,57 @@ +# CHI2 values for n degrees freedom +_chiSquared_table = { + 1:3.841, + 2:5.991, + 3:7.815, + 4:9.488, + 5:11.071, + 6:12.592, + 7:14.067, + 8:15.507, + 9:16.919, + 10:18.307 + } + +class ChiSquaredException(Exception): + pass + +def chi_squared(actual, expected): + answerKeys = set(list(actual.keys()) + list(expected.keys())) + degreesFreedom = len(answerKeys) + chiSquared = 0 + + get_count = lambda k, d : d[k]['count'] if k in d else 0 + + for k in answerKeys: + E = get_count(k, expected) + O = get_count(k, actual) + if E == 0: + print('Warning! Expected 0 counts of {}, but got {}'.format(k, O)) + else: + chiSquared += (O - E) ** 2 / E + return chiSquared + +def cross_formula_chi_squared(actualDict, expectedDict): + for ka, actual in actualDict.items(): + for ke, expected in expectedDict.items(): + print('Comparing {} with {}'.format(ka, ke)) + chiSquared = chi_squared(actual, expected) + + if chiSquared >= _chiSquared_table[degreesFreedom]: + print('Significant difference between expected and actual answer distributions: \n' + + 'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom)) + +def cross_chi_squared(problemSets): + for i, problemSetA in enumerate(problemSets): + for problemSetB in problemSets[i + 1:]: + for problemA in problemSetA: + for problemB in problemSetB: + answersA = problemA.distributions + answersB = problemB.distributions + cross_formula_chi_squared(answersA, answersB) + +def iso_chi_squared(actualDict, expectedDict): + for key in expectedDict.keys(): + assert key in actualDict, 'The key {} was not tested'.format(key) + actual = actualDict[key] + expected = expectedDict[key] diff --git a/tests.py b/tests.py new file mode 100644 index 0000000..7842590 --- /dev/null +++ b/tests.py @@ -0,0 +1,62 @@ +import unittest +import os.path +import pickle +import argparse +import sys + +from pprint import pprint +from copycat import Problem +from copycat.statistics import iso_chi_squared + +# TODO: update test cases to use entropy + +def generate(): + print('Generating distributions for new file') + iterations = 30 + problems = [ + Problem('abc', 'abd', 'efg', iterations), + Problem('abc', 'abd', 'ijk', iterations), + Problem('abc', 'abd', 'xyz', iterations), + Problem('abc', 'abd', 'ijkk', iterations), + Problem('abc', 'abd', 'mrrjjj', iterations)] + + with open(TestCopycat.Filename, 'wb') as outfile: + pickle.dump(problems, outfile) + return problems + +class TestCopycat(unittest.TestCase): + Filename = None + + def setUp(self): + self.longMessage = True # new in Python 2.7 + + def test(self): + print('Testing copycat with input file: {}'.format(TestCopycat.Filename)) + try: + with open(TestCopycat.Filename, 'rb') as infile: + problems = pickle.load(infile) + except Exception as e: + print('Generating due to error:') + print(e) + problems = generate() + + for problem in problems: + problem.test(iso_chi_squared) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--generate', action='store_true') + parser.add_argument('filename', default='.distributions', nargs='?') + parser.add_argument('unittest_args', default=[], nargs='?') + + args = parser.parse_args() + # TODO: Go do something with args.input and args.filename + + TestCopycat.Filename = args.filename + + if args.generate: + generate() + + # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone) + sys.argv[1:] = args.unittest_args + unittest.main()