
I'm using the TensorFlow Dataset API to prepare my data for input into my network. During this process, I have some custom Python functions which are mapped onto the dataset using tf.py_function. I want to be able to debug the data going into these functions and what happens to that data inside them. When a py_function is called, it calls back to the main Python process (according to this answer). Since this function is in Python, and runs in the main process, I would expect a regular IDE breakpoint to be able to stop in it. However, this doesn't seem to be the case (example below, where the breakpoint does not halt execution). Is there a way to drop into a breakpoint within a py_function used by Dataset.map?

Example where the breakpoint does not halt execution

import tensorflow as tf

def add_ten(example, label):
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))
golmschenk
  • Which TensorFlow version are you using? It seems to be working on 1.12.0. P.S.: make sure the Python function returns the correct type (e.g. `return np.int32(example_plus_ten), np.int32(label)`). – gab Dec 14 '19 at 15:27
    @gabriele: 2.0 (currently latest stable version). – golmschenk Dec 15 '19 at 00:01

1 Answer


TensorFlow 2.0's implementation of tf.data.Dataset runs each call on a thread created from C, without notifying your debugger. Use pydevd to manually install a trace function that connects to your default debugger server and starts feeding it debug data:

import pydevd
pydevd.settrace()
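
For background (a minimal standard-library illustration, not part of the original answer): Python trace functions, which debuggers rely on, are installed per thread, so code running on a thread the debugger did not create is invisible to it until a trace function is installed there explicitly, which is what pydevd.settrace does.

```python
import sys
import threading

result = {}

def worker():
    # sys.gettrace() returns the trace function installed for the
    # *current* thread, or None if no debugger/tracer has hooked it.
    result["trace"] = sys.gettrace()

# A freshly spawned thread does not inherit the main thread's trace
# function, so a debugger attached only to the main thread cannot
# break inside it until something installs a trace hook there.
t = threading.Thread(target=worker)
t.start()
t.join()
```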

Example with your code:

import tensorflow as tf
import pydevd

def add_ten(example, label):
    pydevd.settrace(suspend=False)
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))

Note: If you are using an IDE that already bundles pydevd (such as PyDev or PyCharm), you do not have to install pydevd separately; it will be picked up during the debug session.
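
As a debugger-free sanity check (my suggestion, not part of the answer above): the mapped function is plain Python, so it can also be exercised directly with ordinary values before being wired into the pipeline, which often makes the bug visible without any breakpoints:

```python
def add_ten(example, label):
    example_plus_ten = example + 10
    return example_plus_ten, label

# Call the map function directly with plain ints to inspect its behavior
# outside the tf.data pipeline.
result = add_ten(10, 0)
```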

Daniel Braun
  • Hi, I am using flat_map in TF 2.2 and my breakpoint is hit in some strange "interpreted/compiled" temporary code (some `tmpcXXXXXXX.py` file) with `pydevd.settrace(suspend=True)`. Do you know why? Is there any way to avoid that while debugging (using PyCharm)? Thank you in advance, BR gfkri – gfkri Oct 20 '20 at 08:02
  • @DanielBraun: I am using PyCharm with tensorflow 2.4.1 (on Windows). Setting a breakpoint in the example above does not work. However, if I understand you correctly, it should stop since PyCharm already bundles pydevd, or? – gebbissimo Apr 09 '21 at 08:16