From d16e347f04dafab9c53ac79186c8e57db63ac37f Mon Sep 17 00:00:00 2001 From: LSaldyt Date: Tue, 14 Nov 2017 18:05:58 -0700 Subject: [PATCH] Adds problem class --- .distributions | Bin 833 -> 920 bytes copycat/__init__.py | 1 + copycat/copycat.py | 2 -- copycat/problem.py | 53 +++++++++++++++++++++++++++ copycat/temperature.py | 3 ++ tests.py | 80 ++++++++++++++++------------------------- 6 files changed, 87 insertions(+), 52 deletions(-) create mode 100644 copycat/problem.py diff --git a/.distributions b/.distributions index 9931287db9508cc7a6f5326910d1391a29dc05e9..b5ea0d0562e7efc76389266eb442bb81f70550d5 100644 GIT binary patch literal 920 zcmZvaOK1~O6o#iJ=`*dGG*)ZcTJ2-n#Bh03SNoLY}CzHpS%=BTo z(qd7h)FPBFT#0lei0D%2qSA#6>B3zRaVH9*xKnWBxp$IcrL&lsd-%_PzW;I-JQE2#qumsag)(JG`&Jy(CT{DrE%ML5L3BP+F%B;^Vk%}`wfuj0X@n$p0hRO6qmT2s0W z)uOCtrl3Y?#Ii<9wnbBpO$+%F)H0|WXHYMfRT?xf@XJaKx0;@^X-aUo{sH0WZ`Exey~tF80vfeOSR%tY%J zJTAlxHuo^-6~YyHA4DYIHMz05b|8ezpsxbq#V)*wEcA1V0i=k?id!^IILja^+}DF) zcpJ75Bfuk$KL7q;=j)b3`WXxg!M)vm<8Lf}c-hnIbIMj(6NVVXgmQJKW$e8@fmRLN zD*od4ALd^^5OkQqNClKHYR<9@!+>)f^*o|pkj4IDIg1G{G8h#THlA+#hqnDVA;c$) zg`d26_~n%&_Qka@;1Yw&Li~7Ogg+?uaq(a_&@nn}I2B_9x8w?gt3vx|sS?u5wbwhp NEnd4b`9HcI{R=wS3R(aF literal 833 zcmZXS%WD%+6o)5G`Y=v4iM6k^F>1Bbhqd(yQlVn8N>(m}*+@t-)7(3gne@!ewHf7t zQjv;Ku>}`VcSTpyjRcBH-3a;*bW=fe>q3_J`nQN!$yK3jxlHOhWrPj_;h zccCSdQe`&Mc37a5LYvaqL9mBHR8d114n8``@_krgs5$7YLzuSi**fR zVlkK3ATHF{#vaaF_o4ECg*}b`Ze~C26QVE7Bj0ke&?G660`!Xjl zyyQ8+EkFSgTma`UwR<=X2<*U4Xn31=vr73}jt9)hQ66U}NJYi0Z0&J_CCA0@c! A{r~^~ diff --git a/copycat/__init__.py b/copycat/__init__.py index 4e8bc55..732e307 100644 --- a/copycat/__init__.py +++ b/copycat/__init__.py @@ -1,3 +1,4 @@ from .copycat import Copycat, Reporter # noqa +from .problem import Problem from .plot import plot_answers from .io import save_answers diff --git a/copycat/copycat.py b/copycat/copycat.py index 501a3d6..37a3dd7 100644 --- a/copycat/copycat.py +++ b/copycat/copycat.py @@ -89,8 +89,6 @@ class Copycat(object): d['avgtime'] = d.pop('sumtime') / d['count'] print('The formula {} provided:'.format(formula)) print('Average difference: {}'.format(self.temperature.getAverageDifference())) - pprint(answers) - return answers def run_forever(self, initial, modified, target): diff --git a/copycat/problem.py b/copycat/problem.py new file mode 100644 index 0000000..5a18617 --- /dev/null +++ b/copycat/problem.py @@ -0,0 +1,53 @@ +from .copycat import Copycat + +from pprint import pprint + +class Problem: + def __init__(self, initial, modified, target, iterations, distributions, formulas=None): + self.initial = initial + self.modified = modified + self.target = target + + self.iterations = iterations + self.distributions = distributions + self.formulas = formulas + if formulas is not None: + assert hasattr(Copycat().workspace, 'temperature') + + def test(self, comparison): + print('-' * 120) + print('Testing copycat problem: {} : {} :: {} : _'.format(self.initial, + self.modified, + self.target)) + print('expected:') + pprint(self.distributions) + + actual = self.solve() + print('actual:') + pprint(actual) + comparison(actual, self.distributions) + print('-' * 120) + + def solve(self): + copycat = Copycat() + answers = dict() + if self.formulas == None: + if hasattr(copycat.workspace, 'temperature'): + formula = copycat.workspace.temperature.getAdj() + else: + formula = None + answers[formula] = copycat.run(self.initial, + self.modified, + self.target, + self.iterations) + else: + for formula in self.formulas: + copycat.temperature.useAdj(formula) + answers[formulas] = copycat.run(self.initial, + self.modified, + self.target, + self.iterations) + return answers + + def generate(self): + self.distributions = self.solve() diff --git a/copycat/temperature.py b/copycat/temperature.py index ce78d96..6551571 100644 --- a/copycat/temperature.py +++ b/copycat/temperature.py @@ -147,5 +147,8 @@ class Temperature(object): print('Changing to adjustment formula {}'.format(adj)) self.adjustmentType = adj + def getAdj(self): + return self.adjustmentType + def adj_formulas(self): return self._adjustmentFormulas.keys() diff --git a/tests.py b/tests.py index 81ec806..3f633f4 100644 --- a/tests.py +++ b/tests.py @@ -5,7 +5,7 @@ import argparse import sys from pprint import pprint -from copycat import Copycat +from copycat import Problem # TODO: update test cases to use entropy @@ -26,59 +26,39 @@ _chiSquared_table = { class ChiSquaredException(Exception): pass -def chi_squared(actual, expected): - answerKeys = set(list(actual.keys()) + list(expected.keys())) - degreesFreedom = len(answerKeys) - chiSquared = 0 +def chi_squared(actualDict, expectedDict): + for key in expectedDict.keys(): + assert key in actualDict, 'The key {} was not tested'.format(key) + actual = actualDict[key] + expected = expectedDict[key] - get_count = lambda k, d : d[k]['count'] if k in d else 0 + answerKeys = set(list(actual.keys()) + list(expected.keys())) + degreesFreedom = len(answerKeys) + chiSquared = 0 - for k in answerKeys: - E = get_count(k, expected) - O = get_count(k, actual) - if E == 0: - print('Warning! Expected 0 counts of {}, but got {}'.format(k, O)) - else: - chiSquared += (O - E) ** 2 / E + get_count = lambda k, d : d[k]['count'] if k in d else 0 - if chiSquared >= _chiSquared_table[degreesFreedom]: - raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' + - 'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom)) + for k in answerKeys: + E = get_count(k, expected) + O = get_count(k, actual) + if E == 0: + print('Warning! Expected 0 counts of {}, but got {}'.format(k, O)) + else: + chiSquared += (O - E) ** 2 / E -class AnswerDistribution: - def __init__(self, initial, modified, target, iterations, distribution): - self.initial = initial - self.modified = modified - self.target = target - self.iterations = iterations - self.distribution = distribution + if chiSquared >= _chiSquared_table[degreesFreedom]: + raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' + + 'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom)) - def test(self): - print('expected:') - pprint(self.distribution) - actual = Copycat().run(self.initial, - self.modified, - self.target, - self.iterations) - print('actual:') - pprint(actual) - chi_squared(actual, self.distribution) - - def generate(self): - self.distribution = Copycat().run(self.initial, - self.modified, - self.target, - self.iterations) - -def generate(self): +def generate(): print('Generating distributions for new file') iterations = 30 distributions = [ - AnswerDistribution('abc', 'abd', 'efg', iterations, None), - AnswerDistribution('abc', 'abd', 'ijk', iterations, None), - AnswerDistribution('abc', 'abd', 'xyz', iterations, None), - AnswerDistribution('abc', 'abd', 'ijkk', iterations, None), - AnswerDistribution('abc', 'abd', 'mrrjjj', iterations, None)] + Problem('abc', 'abd', 'efg', iterations, None), + Problem('abc', 'abd', 'ijk', iterations, None), + Problem('abc', 'abd', 'xyz', iterations, None), + Problem('abc', 'abd', 'ijkk', iterations, None), + Problem('abc', 'abd', 'mrrjjj', iterations, None)] for distribution in distributions: distribution.generate() @@ -95,7 +75,7 @@ class TestCopycat(unittest.TestCase): self.longMessage = True # new in Python 2.7 def test(self): - print(TestCopycat.Filename) + print('Testing copycat with input file: {}'.format(TestCopycat.Filename)) try: with open(TestCopycat.Filename, 'rb') as infile: distributions = pickle.load(infile) @@ -105,7 +85,7 @@ class TestCopycat(unittest.TestCase): distributions = generate() for distribution in distributions: - distribution.test() + distribution.test(chi_squared) if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -116,11 +96,11 @@ if __name__ == '__main__': args = parser.parse_args() # TODO: Go do something with args.input and args.filename + TestCopycat.Filename = args.filename + if args.generate: generate() - TestCopycat.Filename = args.filename - # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone) sys.argv[1:] = args.unittest_args unittest.main()