I have a non-trivial input pipeline that `from_generator` is perfect for...

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

Where `complex_img_label_generator` dynamically generates images, returning a numpy array representing an (H, W, 3) image and a simple string label. The processing is not something I can represent as reading from files and `tf.image` operations.

My question is about how to parallelise the generator. How do I have N of these generators running in their own threads?

One thought was to use `dataset.map` with `num_parallel_calls` to handle the threading; but the map operates on tensors... Another thought was to create multiple generators, each with its own `prefetch`, and somehow join them, but I can't see how I'd join N generator streams.

Any canonical examples I could follow?


3 Answers


Turns out I can use `Dataset.map` if I make the generator super lightweight (only generating metadata) and then move the actual heavy lifting into a stateless function. This way I can parallelise just the heavy-lifting part with `.map` using a `py_func`.

It works, but feels a tad clumsy... Would be great to be able to just add `num_parallel_calls` to `from_generator` :)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # some complex PIL and numpy work, nothing to do with tf
  ...

dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label

def wrapped_complex_calculation(metadata, label):
  return tf.py_func(func=pure_numpy_and_pil_complex_calculation,
                    inp=(metadata, label),
                    Tout=(tf.uint8,    # (H,W,3) img
                          tf.string))  # label

dataset = dataset.map(wrapped_complex_calculation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
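
For illustration, here is a minimal sketch of what `lightweight_generator` and `pure_numpy_and_pil_complex_calculation` could look like; the metadata format and the random-image body are just placeholders standing in for the real work:

import numpy as np
from PIL import Image

def lightweight_generator():
  # Only yields cheap metadata (a string describing what to build) plus the
  # label; no heavy image work happens in the generator itself.
  for i in range(10000):
    yield "seed=%d;size=64" % i, "label_%d" % (i % 10)

def pure_numpy_and_pil_complex_calculation(metadata, label):
  # py_func hands the strings over as bytes; parse the metadata back.
  params = dict(kv.split("=") for kv in metadata.decode().split(";"))
  size = int(params["size"])
  rng = np.random.RandomState(int(params["seed"]))
  # Stand-in for the real PIL/numpy work: build an (H, W, 3) uint8 image.
  img = Image.fromarray(rng.randint(0, 255, (size, size, 3), dtype=np.uint8))
  return np.asarray(img, dtype=np.uint8), label
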
  • FYI, parallelism with `tf.py_func()` might not speed things up by itself, see [this answer](https://stackoverflow.com/a/48781036/5471520). – mikkola Feb 17 '18 at 12:30
  • good point. empirically though i can say this made a huge speed up. – mat kelcey Feb 19 '18 at 10:27
  • Has TensorFlow added `num_parallel_calls` to `from_generator` since your answer? – Rylan Schaeffer Jun 09 '18 at 14:12
  • @mikkola If it doesn't speed things up, any other suggestions? Thanks – crafet Aug 24 '18 at 12:43
  • Do you know why you got a "huge speed up"? Does it mean that, despite the point @mikkola made, your code is actually running in parallel? – hipoglucido Jan 03 '19 at 08:37
  • Have you used this with an estimator? I get an error complaining about output_shapes not being defined when I do this with a keras model estimator. Any help is appreciated! `from_generator > dataset.output_shapes: (TensorShape([Dimension(10), Dimension(32)]), TensorShape([Dimension(10), Dimension(1)])); after map function > dataset.output_shapes: (TensorShape(None), TensorShape(None))` – krishnakamathk Mar 13 '19 at 20:25
  • @RylanSchaeffer Generators in Python can't be parallelized; for a discussion see https://github.com/tensorflow/tensorflow/issues/13101#issuecomment-501196097 – Mr_and_Mrs_D Jun 12 '19 at 10:34
  • Such a setup could be beneficial in low memory environments. Maybe the speed-up was the consequence of the computer not swapping anymore. – gw0 Feb 17 '20 at 11:21

I am working on a `from_indexable` for `tf.data.Dataset`: https://github.com/tensorflow/tensorflow/issues/14448

The advantage of `from_indexable` is that it can be parallelized, while a Python generator cannot.

The function `from_indexable` makes a `tf.data.Dataset.range`, wraps the indexable in a generalized `tf.py_func`, and calls `map`.

For those who want a `from_indexable` now, here is the lib code:

import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func, 
                inp=args, 
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is necessary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

and here is an example (note: `from_indexable` has a `num_parallel_calls` argument):

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))

ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)

Update June 10, 2018: Since https://github.com/tensorflow/tensorflow/pull/15121 has been merged, the code for `from_indexable` simplifies to:

import tensorflow as tf

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func, 
                args=args, kwargs=kwargs, 
                output_types=output_types, output_shapes=output_shapes, 
                stateful=stateful, name=name
            )
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
  • Unfortunately this doesn't stand the test of time, because TF2 doesn't have contrib and `py_func` has been replaced by `py_function`, which doesn't have output_shapes, args, kwargs, stateful. And finally, the output of `py_function` has unknown shape, which can't be used inside a graph. – Anton Dec 14 '19 at 21:22
  • It is true that tf 2.x doesn't have contrib anymore but you can always set the shapes of tensors in a function with the `set_shape` function. You can see an example in the docs: https://www.tensorflow.org/guide/data#applying_arbitrary_python_logic – Zaccharie Ramzi Feb 10 '20 at 16:57
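
For readers on TF 2.x, a rough sketch of how the same indexable idea could be expressed without contrib, using `tf.py_function` plus `set_shape` as suggested in the comment above; `from_indexable_v2` and its arguments are illustrative names, not part of the original answer:

import tensorflow as tf
import numpy as np

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item + 1, 10))

def from_indexable_v2(indexable, output_type, output_shape,
                      num_parallel_calls=tf.data.AUTOTUNE):
    # tf.data.AUTOTUNE requires TF >= 2.4; on older 2.x use
    # tf.data.experimental.AUTOTUNE instead.
    ds = tf.data.Dataset.range(len(indexable))

    def fetch(index):
        # tf.py_function runs the Python __getitem__ eagerly, outside the graph.
        entry = tf.py_function(lambda i: indexable[i.numpy()], [index], output_type)
        # py_function loses static shape information; restore what we know.
        entry.set_shape(output_shape)
        return entry

    return ds.map(fetch, num_parallel_calls=num_parallel_calls)

ds = from_indexable_v2(PyDataSet(), tf.float64, [None, 10])
for entry in ds.take(2):
    print(entry.shape)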

Limiting the work done in the generator to a minimum and parallelizing the expensive processing using a map is sensible.

Alternatively, you can "join" multiple generators using parallel_interleave as follows:

def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_length=N))

# where N is the number of generators you use
  • Your code is not valid python code and you did not define `ds` in the first place. – Merlin1896 Jan 18 '18 at 17:00
  • I really like this. However, `generator(n)` should return the n-th generator, and n is a Tensor here. How to get the n-th generator? – Derk May 11 '18 at 14:53
  • You can give args to `from_generator` now: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator – Derk Jan 04 '19 at 12:54
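
Building on the last comment, a hedged sketch of how the `args` parameter of `from_generator` could be used to pass the (tensor) stream index n into each generator inside the interleave; the generator body here is only a placeholder:

import tensorflow as tf

N = 4  # number of parallel generator streams

def generator(n):
  # n arrives as a numpy scalar because it is forwarded via `args`.
  # Placeholder body: yield a few values tagged with the stream index.
  for i in range(100):
    yield n * 1000 + i

def dataset(n):
  # args=(n,) converts the tensor n and passes it to generator at runtime.
  return tf.data.Dataset.from_generator(generator, output_types=tf.int64, args=(n,))

ds = tf.data.Dataset.range(N).apply(
    tf.contrib.data.parallel_interleave(dataset, cycle_length=N))

On newer TensorFlow versions, `tf.contrib.data.parallel_interleave` is superseded by `Dataset.interleave(..., num_parallel_calls=...)`, which achieves the same parallel joining of the N generator streams.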