Adds problem class

This commit is contained in:
LSaldyt
2017-11-14 18:05:58 -07:00
parent 20d754faa7
commit d16e347f04
6 changed files with 87 additions and 52 deletions

Binary file not shown.

View File

@ -1,3 +1,4 @@
from .copycat import Copycat, Reporter # noqa from .copycat import Copycat, Reporter # noqa
from .problem import Problem
from .plot import plot_answers from .plot import plot_answers
from .io import save_answers from .io import save_answers

View File

@ -89,8 +89,6 @@ class Copycat(object):
d['avgtime'] = d.pop('sumtime') / d['count'] d['avgtime'] = d.pop('sumtime') / d['count']
print('The formula {} provided:'.format(formula)) print('The formula {} provided:'.format(formula))
print('Average difference: {}'.format(self.temperature.getAverageDifference())) print('Average difference: {}'.format(self.temperature.getAverageDifference()))
pprint(answers)
return answers return answers
def run_forever(self, initial, modified, target): def run_forever(self, initial, modified, target):

53
copycat/problem.py Normal file
View File

@ -0,0 +1,53 @@
from .copycat import Copycat
from pprint import pprint
class Problem:
def __init__(self, initial, modified, target, iterations, distributions, formulas=None):
self.initial = initial
self.modified = modified
self.target = target
self.iterations = iterations
self.distributions = distributions
self.formulas = formulas
if formulas is not None:
assert hasattr(Copycat().workspace, 'temperature')
def test(self, comparison):
print('-' * 120)
print('Testing copycat problem: {} : {} :: {} : _'.format(self.initial,
self.modified,
self.target))
print('expected:')
pprint(self.distributions)
actual = self.solve()
print('actual:')
pprint(actual)
comparison(actual, self.distributions)
print('-' * 120)
def solve(self):
copycat = Copycat()
answers = dict()
if self.formulas == None:
if hasattr(copycat.workspace, 'temperature'):
formula = copycat.workspace.temperature.getAdj()
else:
formula = None
answers[formula] = copycat.run(self.initial,
self.modified,
self.target,
self.iterations)
else:
for formula in self.formulas:
copycat.temperature.useAdj(formula)
answers[formulas] = copycat.run(self.initial,
self.modified,
self.target,
self.iterations)
return answers
def generate(self):
self.distributions = self.solve()

View File

@ -147,5 +147,8 @@ class Temperature(object):
print('Changing to adjustment formula {}'.format(adj)) print('Changing to adjustment formula {}'.format(adj))
self.adjustmentType = adj self.adjustmentType = adj
def getAdj(self):
return self.adjustmentType
def adj_formulas(self): def adj_formulas(self):
return self._adjustmentFormulas.keys() return self._adjustmentFormulas.keys()

View File

@ -5,7 +5,7 @@ import argparse
import sys import sys
from pprint import pprint from pprint import pprint
from copycat import Copycat from copycat import Problem
# TODO: update test cases to use entropy # TODO: update test cases to use entropy
@ -26,59 +26,39 @@ _chiSquared_table = {
class ChiSquaredException(Exception): class ChiSquaredException(Exception):
pass pass
def chi_squared(actual, expected): def chi_squared(actualDict, expectedDict):
answerKeys = set(list(actual.keys()) + list(expected.keys())) for key in expectedDict.keys():
degreesFreedom = len(answerKeys) assert key in actualDict, 'The key {} was not tested'.format(key)
chiSquared = 0 actual = actualDict[key]
expected = expectedDict[key]
get_count = lambda k, d : d[k]['count'] if k in d else 0 answerKeys = set(list(actual.keys()) + list(expected.keys()))
degreesFreedom = len(answerKeys)
chiSquared = 0
for k in answerKeys: get_count = lambda k, d : d[k]['count'] if k in d else 0
E = get_count(k, expected)
O = get_count(k, actual)
if E == 0:
print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
else:
chiSquared += (O - E) ** 2 / E
if chiSquared >= _chiSquared_table[degreesFreedom]: for k in answerKeys:
raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' + E = get_count(k, expected)
'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom)) O = get_count(k, actual)
if E == 0:
print('Warning! Expected 0 counts of {}, but got {}'.format(k, O))
else:
chiSquared += (O - E) ** 2 / E
class AnswerDistribution: if chiSquared >= _chiSquared_table[degreesFreedom]:
def __init__(self, initial, modified, target, iterations, distribution): raise ChiSquaredException('Significant difference between expected and actual answer distributions: \n' +
self.initial = initial 'Chi2 value: {} with {} degrees of freedom'.format(chiSquared, degreesFreedom))
self.modified = modified
self.target = target
self.iterations = iterations
self.distribution = distribution
def test(self): def generate():
print('expected:')
pprint(self.distribution)
actual = Copycat().run(self.initial,
self.modified,
self.target,
self.iterations)
print('actual:')
pprint(actual)
chi_squared(actual, self.distribution)
def generate(self):
self.distribution = Copycat().run(self.initial,
self.modified,
self.target,
self.iterations)
def generate(self):
print('Generating distributions for new file') print('Generating distributions for new file')
iterations = 30 iterations = 30
distributions = [ distributions = [
AnswerDistribution('abc', 'abd', 'efg', iterations, None), Problem('abc', 'abd', 'efg', iterations, None),
AnswerDistribution('abc', 'abd', 'ijk', iterations, None), Problem('abc', 'abd', 'ijk', iterations, None),
AnswerDistribution('abc', 'abd', 'xyz', iterations, None), Problem('abc', 'abd', 'xyz', iterations, None),
AnswerDistribution('abc', 'abd', 'ijkk', iterations, None), Problem('abc', 'abd', 'ijkk', iterations, None),
AnswerDistribution('abc', 'abd', 'mrrjjj', iterations, None)] Problem('abc', 'abd', 'mrrjjj', iterations, None)]
for distribution in distributions: for distribution in distributions:
distribution.generate() distribution.generate()
@ -95,7 +75,7 @@ class TestCopycat(unittest.TestCase):
self.longMessage = True # new in Python 2.7 self.longMessage = True # new in Python 2.7
def test(self): def test(self):
print(TestCopycat.Filename) print('Testing copycat with input file: {}'.format(TestCopycat.Filename))
try: try:
with open(TestCopycat.Filename, 'rb') as infile: with open(TestCopycat.Filename, 'rb') as infile:
distributions = pickle.load(infile) distributions = pickle.load(infile)
@ -105,7 +85,7 @@ class TestCopycat(unittest.TestCase):
distributions = generate() distributions = generate()
for distribution in distributions: for distribution in distributions:
distribution.test() distribution.test(chi_squared)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -116,11 +96,11 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
# TODO: Go do something with args.input and args.filename # TODO: Go do something with args.input and args.filename
TestCopycat.Filename = args.filename
if args.generate: if args.generate:
generate() generate()
TestCopycat.Filename = args.filename
# Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone) # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
sys.argv[1:] = args.unittest_args sys.argv[1:] = args.unittest_args
unittest.main() unittest.main()