Do a pile of work

Wednesday 19 August 2020

I had a large pile of data to feed through an expensive function. The concurrent.futures module in the Python standard library has always worked well for me as a simple way to farm out work across threads or processes.

For example, if my work function is “workfn”, and it takes tuples of arguments as produced by “argsfn()”, this is how you could run them all:

for args in argsfn():
    workfn(*args)

This is how you would run them on a number of threads:

import concurrent.futures as cf

with cf.ThreadPoolExecutor(max_workers=nthreads) as executor:
    for args in argsfn():
        executor.submit(workfn, *args)

But this submits all of the work up-front: the loop runs through every argument tuple and queues a future for each one, without waiting for any of them to finish. If I have millions of work invocations, this could be a problem. I wanted a way to feed the tasks in as they are processed, to keep the queue small. And I wanted a progress bar.
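Executor.map doesn't avoid this either. In the sketch below, a hypothetical generator (`counting_args`) records how many values have been taken from it, and a hypothetical `slow_square` stands in for the expensive function. In CPython, `map()` submits every call before it returns, so the whole generator is drained before the first result is requested. (Python 3.14 added a `buffersize` argument to `Executor.map` that limits how far ahead it submits; earlier versions have no such option.)

```python
import concurrent.futures as cf
import time

pulled = []  # argument values the executor has taken from the generator so far


def counting_args(n):
    """Hypothetical argument generator that records how far it has been consumed."""
    for i in range(n):
        pulled.append(i)
        yield i


def slow_square(x):
    time.sleep(0.01)  # stand-in for expensive work
    return x * x


with cf.ThreadPoolExecutor(max_workers=4) as executor:
    results = executor.map(slow_square, counting_args(100))
    # map() has already submitted all 100 calls before any result is requested.
    pulled_before_results = len(pulled)
    squares = list(results)

print(pulled_before_results, len(squares))  # 100 100
```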

I started from this Stack Overflow answer, added in tqdm for a progress bar, and made this:

import concurrent.futures as cf
from tqdm import tqdm

def wait_first(futures):
    """Wait for the first future to complete.

    Returns:
        (done, not_done): two sets of futures.
    """
    return cf.wait(futures, return_when=cf.FIRST_COMPLETED)

def do_work(nthreads, argsfn, workfn):
    """Do a pile of work, maybe in threads, with a progress bar.

    Two callables are provided: `workfn` is the unit of work to be done,
    many times.  Its arguments are provided by calling `argsfn`, which
    must produce a sequence of tuples.  `argsfn` will be called a few
    times, and must produce the same sequence each time.

    Args:
        nthreads: the number of threads to use.
        argsfn: a callable that produces tuples, the arguments to `workfn`.
        workfn: a callable that does work.
    """
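    total = sum(1 for _ in argsfn())
    with tqdm(total=total, smoothing=0.1) as progressbar:
        if nthreads:
            # Keep at most `limit` futures in flight, so argsfn() is
            # only consumed as fast as the work gets done.
            limit = 2 * nthreads
            not_done = set()
            with cf.ThreadPoolExecutor(max_workers=nthreads) as executor:
                for args in argsfn():
                    if len(not_done) >= limit:
                        done, not_done = wait_first(not_done)
                        progressbar.update(len(done))
                    not_done.add(executor.submit(workfn, *args))
                # Everything is submitted; wait for the stragglers.
                while not_done:
                    done, not_done = wait_first(not_done)
                    progressbar.update(len(done))
        else:
            # No threads: do the work serially in this thread.
            for args in argsfn():
                workfn(*args)
                progressbar.update(1)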

There might be a better way. I don’t like the duplication of the wait_first call, but this works, and produces the right results.
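One possible way to call the wait only once (a sketch, not the post's code): keep a single iterator over the arguments, top up the in-flight set to the limit with itertools.islice, then wait once per pass through the loop. `do_work_single_wait` and its `on_progress` callback are hypothetical names; `on_progress` stands in for tqdm's `progressbar.update` so the sketch runs without tqdm, and `nthreads` is assumed to be at least 1.

```python
import concurrent.futures as cf
import itertools


def do_work_single_wait(nthreads, argsfn, workfn, on_progress=lambda n: None):
    """Hypothetical variant: run workfn(*args) for each tuple from argsfn(),
    keeping at most 2 * nthreads futures in flight, with one wait call."""
    limit = 2 * nthreads
    args_iter = iter(argsfn())
    not_done = set()
    with cf.ThreadPoolExecutor(max_workers=nthreads) as executor:
        while True:
            # Top up to `limit` futures; islice yields nothing once argsfn() runs dry.
            for args in itertools.islice(args_iter, limit - len(not_done)):
                not_done.add(executor.submit(workfn, *args))
            if not not_done:
                break  # nothing running and nothing left to submit
            done, not_done = cf.wait(not_done, return_when=cf.FIRST_COMPLETED)
            for future in done:
                future.result()  # re-raise any exception from workfn
            on_progress(len(done))  # e.g. progressbar.update with tqdm
```

With tqdm you would pass `on_progress=progressbar.update`. Unlike the code above, this version calls `result()` on each finished future, so an exception raised by `workfn` propagates instead of being silently discarded.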

BTW: my actual work function spawned subprocesses, which is why a thread pool gave me parallelism. A pure-Python, CPU-bound work function wouldn’t get a speed-up this way, because CPython’s GIL lets only one thread run Python code at a time, but a ProcessPoolExecutor could help.
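A small sketch of why threads are enough when the work is a subprocess: here each task is a hypothetical `run_child` that starts a Python child process which just sleeps (a stand-in for the real subprocesses, which the post doesn't show). The threads spend their time blocked in `subprocess.run` waiting for their children, so four one-second children should finish in roughly one second in total, not four.

```python
import concurrent.futures as cf
import subprocess
import sys
import time


def run_child(seconds):
    """Hypothetical work function: spawn a Python subprocess that sleeps."""
    subprocess.run(
        [sys.executable, "-c", f"import time; time.sleep({seconds})"],
        check=True,
    )


start = time.perf_counter()
with cf.ThreadPoolExecutor(max_workers=4) as executor:
    # Each thread only waits on its child process, so the children overlap.
    list(executor.map(run_child, [1.0] * 4))
elapsed = time.perf_counter() - start
print(f"4 one-second subprocesses took {elapsed:.1f}s")
```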


lemon24 12:29 PM on 21 Aug 2020

Here's a different take using multiprocessing.dummy, which provides the multiprocessing API (including Pool) backed by threads instead of processes.

I really, really wanted it to be shorter, but without with_backpressure, the imap_unordered() call will consume the whole iterator immediately (it's lazy only in some ways, see for details).
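A quick way to watch the eager consumption (a sketch with hypothetical helper names): count how many arguments the pool has taken from a generator at the moment the first result comes back. As far as I can tell from CPython's source, the thread pool's task-handling thread copies tasks into an unbounded internal queue, so it drains the generator almost at once, even though every task takes 0.2 seconds. On CPython this typically reports all 20 arguments pulled.

```python
import time
from multiprocessing.dummy import Pool

pulled = []  # arguments the pool has taken from the generator so far


def gen_args(n):
    """Hypothetical argument generator that records how far it has been consumed."""
    for i in range(n):
        pulled.append(i)
        yield i


def slow_identity(x):
    time.sleep(0.2)  # stand-in for expensive work
    return x


with Pool(4) as pool:
    results = pool.imap_unordered(slow_identity, gen_args(20))
    first = next(results)
    # Only one result is back, but the generator has already been drained.
    pulled_after_first = len(pulled)
    rest = list(results)  # collect everything before the pool is terminated

print(f"{pulled_after_first} of 20 arguments pulled after the first result")
```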

Similar to your example, pure-Python code can be sped up by using multiprocessing.Pool instead.

from multiprocessing.dummy import Pool
from queue import Queue
from tqdm import tqdm

def with_backpressure(it, maxsize):
    """Consume at most maxsize elements at a time from an iterator,
    then block until additional elements can be consumed.

    Returns:
        (wrapped iterator, advance function):
            An iterator with the same elements as the original one,
            and an advance() function; call advance() when another
            element can be consumed.
    """
    queue = Queue(maxsize)
    sentinel = object()
    it = iter(it)

    def advance():
        queue.put(next(it, sentinel))

    # "prime" the queue
    for _ in range(maxsize):
        advance()

    # Yield queued elements until advance() has queued the sentinel.
    return iter(queue.get, sentinel), advance

def do_work(nthreads, argsfn, workfn):
    total = sum(1 for _ in argsfn())

    def worker(args):
        return workfn(*args)

    pool = Pool(nthreads)
    args, advance = with_backpressure(argsfn(), nthreads * 2)
    progressbar = tqdm(total=total, smoothing=0.1)
    with pool, progressbar:
        for _ in pool.imap_unordered(worker, args):
            progressbar.update()
            advance()  # a task finished, so let one more argument in

if __name__ == '__main__':
    import time
    def argsfn():
        for i in range(10):
            yield 1,

    do_work(3, argsfn, time.sleep)
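The with_backpressure wrapper above leans on two standard-library idioms: `next(it, default)` returns `default` instead of raising StopIteration once `it` is exhausted, and the two-argument form `iter(callable, sentinel)` keeps calling `callable` until it returns `sentinel`. A small single-threaded illustration (the values are arbitrary):

```python
from queue import Queue

sentinel = object()  # a unique object that can't collide with real data
source = iter([10, 20, 30])
queue = Queue()

# next(source, sentinel) hands back the sentinel once the source runs out.
for _ in range(4):
    queue.put(next(source, sentinel))

# iter(queue.get, sentinel) stops as soon as queue.get() returns the sentinel.
items = list(iter(queue.get, sentinel))
print(items)  # [10, 20, 30]
```

If the loop ran only three times, the sentinel would never be queued and `list(...)` would block forever inside `queue.get`; in the wrapper, continued calls to `advance()` are what eventually put the sentinel on the queue.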

