diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cdb6608 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.pyc +*.pyo +*~ \ No newline at end of file diff --git a/test/test_workerpool.py b/test/test_workerpool.py index 941026e..dcb865f 100644 --- a/test/test_workerpool.py +++ b/test/test_workerpool.py @@ -8,6 +8,22 @@ class TestWorkerPool(unittest.TestCase): + def __init__(self, *args, **kwargs): + self._pools = [] + super(TestWorkerPool, self).__init__(*args, **kwargs) + + def run(self, *args, **kwargs): + try: + super(TestWorkerPool, self).run(*args, **kwargs) + finally: + for pool in self._pools: + pool.shutdown() + + def get_workerpool(self, *args): + p = workerpool.WorkerPool(*args) + self._pools.append(p) + return p + def double(self, i): return i * 2 @@ -16,48 +32,42 @@ def add(self, *args): def test_map(self): "Map a list to a method to a pool of two workers." - pool = workerpool.WorkerPool(2) + pool = self.get_workerpool(2) r = pool.map(self.double, [1, 2, 3, 4, 5]) - self.assertEquals(set(r), {2, 4, 6, 8, 10}) - pool.shutdown() + self.assertEquals(r, [2, 4, 6, 8, 10]) def test_map_multiparam(self): "Test map with multiple parameters." - pool = workerpool.WorkerPool(2) + pool = self.get_workerpool(2) r = pool.map(self.add, [1, 2, 3], [4, 5, 6]) - self.assertEquals(set(r), {5, 7, 9}) - pool.shutdown() + self.assertEquals(r, [5, 7, 9]) def test_wait(self): "Make sure each task gets marked as done so pool.wait() works." - pool = workerpool.WorkerPool(5) + pool = self.get_workerpool(5) q = Queue() for i in xrange(100): pool.put(workerpool.SimpleJob(q, sum, [range(5)])) pool.wait() - pool.shutdown() def test_init_size(self): - pool = workerpool.WorkerPool(1) + pool = self.get_workerpool(1) self.assertEquals(pool.size(), 1) - pool.shutdown() def test_shrink(self): - pool = workerpool.WorkerPool(1) + pool = self.get_workerpool(1) pool.shrink() self.assertEquals(pool.size(), 0) - pool.shutdown() def test_grow(self): - pool = workerpool.WorkerPool(1) + pool = self.get_workerpool(1) pool.grow() self.assertEquals(pool.size(), 2) - pool.shutdown() def test_changesize(self): "Change sizes and make sure pool doesn't work with no workers." - pool = workerpool.WorkerPool(5) + pool = self.get_workerpool(5) for i in xrange(5): pool.grow() self.assertEquals(pool.size(), 10) @@ -77,4 +87,3 @@ def test_changesize(self): else: assert False, "Something returned a result, even though we are" "expecting no workers." - pool.shutdown() diff --git a/workerpool/jobs.py b/workerpool/jobs.py index 381dfda..f532107 100644 --- a/workerpool/jobs.py +++ b/workerpool/jobs.py @@ -31,10 +31,10 @@ class SimpleJob(Job): list, the method will execute r = method(*args) or r = method(**args), depending on args' type, and perform result.put(r). """ - def __init__(self, result, method, args=[]): + def __init__(self, result, method, args=None): self.result = result self.method = method - self.args = args + self.args = args or [] def run(self): if isinstance(self.args, list) or isinstance(self.args, tuple): @@ -46,3 +46,22 @@ def run(self): def _return(self, r): "Handle return value by appending to the ``self.result`` queue." self.result.put(r) + + +class OrderedSimpleJob(SimpleJob): + """ + Special job used in `pool.map` used to retain order of arguments + and results. + """ + def __init__(self, index, result, method, args=None): + self.index = index + self.result = result + self.method = method + self.args = args or [] + + def _return(self, r): + """ + Returns the output of the job in addition to the index it + should have in the results list. + """ + self.result.put((self.index, r)) diff --git a/workerpool/pools.py b/workerpool/pools.py index 803beb6..0f0ce9f 100644 --- a/workerpool/pools.py +++ b/workerpool/pools.py @@ -12,7 +12,7 @@ from QueueWrapper import Queue from workers import Worker -from jobs import SimpleJob, SuicideJob +from jobs import OrderedSimpleJob, SuicideJob __all__ = ['WorkerPool', 'default_worker_factory'] @@ -95,16 +95,16 @@ def map(self, fn, *seq): "block until done." results = Queue() args = zip(*seq) - for seq in args: - j = SimpleJob(results, fn, seq) + for i, seq in enumerate(args): + j = OrderedSimpleJob(i, results, fn, seq) self.put(j) # Aggregate results - r = [] - for i in xrange(len(args)): - r.append(results.get()) - - return r + self.join() + sentinel = object() + results.put(sentinel) + r = sorted(iter(results.get, sentinel)) + return [x[1] for x in r] def wait(self): "DEPRECATED: Use join() instead."