
Does anyone know how to split a dataset created by the dataset API (tf.data.Dataset) in Tensorflow into Test and Train?

desertnaut
Dani
  • `take()`, `skip()`, and `shard()` all have their own problems. I just posted my answer over [here](https://stackoverflow.com/a/58452268/5462608). I hope it better answers your question. – Nick Lee Oct 18 '19 at 13:55

8 Answers


Assuming you have a variable all_dataset of type tf.data.Dataset:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

The test dataset now has the first 1,000 elements, and everything else goes to training.
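On a deterministic (unshuffled) source, take() and skip() behave like list slicing. A minimal pure-Python sketch of the split above, with no TensorFlow needed (`records` is a stand-in for all_dataset, and the 1,200-element size is just an illustrative assumption):

```python
# `records` stands in for all_dataset; 1,200 elements is an assumed size.
records = list(range(1200))

test = records[:1000]    # like all_dataset.take(1000)
train = records[1000:]   # like all_dataset.skip(1000)

assert len(test) == 1000 and len(train) == 200
assert set(test).isdisjoint(train)  # deterministic source: no overlap
```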

apatsekin
  • As also mentioned in [ted's answer](https://stackoverflow.com/a/51258695/5299750), adding `all_dataset.shuffle()` allows for a shuffled split. Possibly add as code comment in answer like so? `# all_dataset = all_dataset.shuffle() # in case you want a shuffled split` – Christian Steinmeyer Oct 02 '20 at 14:15

You may use Dataset.take() and Dataset.skip():

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle(DATASET_SIZE, reshuffle_each_iteration=False)  # see comments: reshuffle_each_iteration=False keeps the splits disjoint
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

For generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need a test or a val set, just ignore the last 2 lines.

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.

You may also want to look into Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.
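shard(num_shards, index) keeps every num_shards-th element, starting at position index. A pure-Python sketch of that semantics (the `shard` helper here is illustrative, not the tf.data API):

```python
# Illustrative stand-in for Dataset.shard(num_shards, index):
# keep every num_shards-th element, starting at position index.
def shard(records, num_shards, index):
    return records[index::num_shards]

data = list(range(10))
train = shard(data, 2, 0)  # elements at even positions
test = shard(data, 2, 1)   # elements at odd positions
assert train == [0, 2, 4, 6, 8]
assert test == [1, 3, 5, 7, 9]
```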


Disclaimer: I stumbled upon this question after answering this one, so I thought I'd spread the love.

ted
  • Thank you very much @ted! Is there a way to divide the dataset in a stratified way? Or, alternatively, how can we get an idea of the class proportions (suppose a binary problem) after the train/val/test split? Thanks a lot in advance! – Tommaso Di Noto Aug 27 '19 at 13:21
  • Have a look at this blog post I wrote; even though it's for multilabel datasets, it should be easily usable for single-label, multiclass datasets -> https://vict0rs.ch/2018/06/17/multilabel-text-classification-tensorflow/ – ted Aug 27 '19 at 13:50
  • This causes my train, validation and test datasets to have overlap between them. Is this supposed to happen, and not a big deal? I would assume it's not a good idea to have the model train on validation and test data. – c_student Jan 24 '20 at 19:51
  • @c_student I had the same problem and I figured out what I was missing: when you shuffle, use the option `reshuffle_each_iteration=False`, otherwise elements could be repeated in train, test and val – xdola Apr 15 '20 at 17:57
  • This is very true @xdola, and in particular when using `list_files` you should use `shuffle=False` and then shuffle with `.shuffle` with `reshuffle_each_iteration=False`. – Zaccharie Ramzi May 27 '20 at 09:48
  • @xdola, Thank you for your comment, a potential disaster avoided! – Tim Mironov Jan 17 '21 at 16:01

Most of the answers here use take() and skip(), which requires knowing the size of your dataset beforehand. This isn't always possible, or can be difficult or expensive to ascertain.

Instead, what you can do is essentially slice the dataset up so that 1 in every N records becomes a validation record.

To accomplish this, let's start with a simple dataset of 0-9:

dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Now for our example, we're going to slice it so that we have a 3/1 train/validation split. Meaning 3 records will go to training, then 1 record to validation, then repeat.

split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]

So the first dataset.window(split, split + 1) says to grab split (3) elements, then advance split + 1 elements, and repeat. That + 1 effectively skips the one element we're going to use in our validation dataset.
The flat_map(lambda ds: ds) is needed because window() returns the results in batches (nested datasets), which we don't want, so we flatten it back out.

Then for the validation data we first skip(split), which skips over the first split number (3) of elements that were grabbed in the first training window, so we start our iteration on the 4th element. The window(1, split + 1) then grabs 1 element, advances split + 1 (4), and repeats.
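To double-check which indices land where, the window()/flat_map() split above can be emulated in plain Python (`window_split` is a hypothetical helper, not part of the tf.data API):

```python
# Emulate window(size, shift).flat_map(...) on a plain list:
# from each shift-sized stride, keep the first `size` elements.
def window_split(records, size, shift):
    out = []
    for start in range(0, len(records), shift):
        out.extend(records[start:start + size])
    return out

data = list(range(10))
split = 3
train = window_split(data, split, split + 1)
validation = window_split(data[split:], 1, split + 1)  # skip(split) first
assert train == [0, 1, 2, 4, 5, 6, 8, 9]
assert validation == [3, 7]
```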

 

Note on nested datasets:
The above example works well for simple datasets, but flat_map() will generate an error if the dataset is nested. To address this, you can swap out the flat_map() with a more complicated version that can handle both simple and nested datasets:

.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
phemmer
  • Doesn't `window` just use `skip` under the hood? How is the disadvantage addressed? ("The other disadvantage is that with skip() it has to read, and then discard, all the skipped records, which if your data source is slow means you might have a large spool-up time before results are emitted.") – Frederik Bode Mar 03 '20 at 09:06
  • If you have a dataset of 1000 records, and you want 10% for validation, you would have to skip the first 900 records before a single validation record is emitted. With this solution, it only has to skip 9 records. It does end up skipping the same amount overall, but if you use `dataset.prefetch()`, it can read in the background while doing other things. The difference is just saving the initial spool-up time. – phemmer Mar 03 '20 at 09:11
  • Thinking about it a bit more, I removed the statement. There are probably a dozen ways to solve that problem, and it's probably minute, if present at all, for most people. – phemmer Mar 03 '20 at 09:40
  • You should probably set *without knowing the dataset size beforehand* in boldface, or make it a header or something; it's pretty important. This should really be the accepted answer, as it fits the premise of `tf.data.Dataset` treating data like infinite streams. – Frederik Bode Mar 03 '20 at 09:43

@ted's answer can cause some overlap between the splits. Try this instead.

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)  
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)

Use the code below to test:

tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in TF 2

dataset = tf.data.Dataset.range(100)

train_size = 20
valid_size = 30
test_size = 50

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

for i in train:
    print(i)

for i in valid:
    print(i)

for i in test:
    print(i)
Hank
  • I love how everyone assumes you know the `full_ds_size` but no one explains how to find it – Bersan Mar 30 '21 at 15:28

TensorFlow doesn't currently include any built-in tools for that.
You could use sklearn.model_selection.train_test_split to generate train/eval/test splits, then create a tf.data.Dataset for each.
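The idea is to split shuffled indices first and build the datasets afterwards. A pure-Python sketch (`train_test_split_indices` is a hypothetical stand-in for sklearn.model_selection.train_test_split, which you would normally call on the arrays directly):

```python
import random

# Hypothetical stand-in for sklearn.model_selection.train_test_split:
# shuffle the indices once, then cut them into two disjoint groups.
def train_test_split_indices(n, test_fraction, seed=0):
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    cut = int(n * (1 - test_fraction))
    return indices[:cut], indices[cut:]

train_idx, test_idx = train_test_split_indices(100, 0.2)
assert len(train_idx) == 80 and len(test_idx) == 20
assert set(train_idx).isdisjoint(test_idx)
# Each index list can then become its own dataset, e.g. via
# tf.data.Dataset.from_tensor_slices on the selected rows.
```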

Lunar_one

You can use shard:

dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=False)  # optional; buffer_size ideally the dataset size
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)

See: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard

Yoav

In case the size of the dataset is known:

from typing import Tuple
import tensorflow as tf

def split_dataset(dataset: tf.data.Dataset, 
                  dataset_size: int, 
                  train_ratio: float, 
                  validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    assert (train_ratio + validation_ratio) < 1

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=False)  # avoid overlap between the splits

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(validation_count + train_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset

Example:

size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2

ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
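As a sanity check on this example, the counts work out as follows (plain Python, just the arithmetic from split_dataset):

```python
# Split arithmetic for size_of_ds = 1001, ratios 0.6 / 0.2:
dataset_size = 1001
train_count = int(dataset_size * 0.6)
validation_count = int(dataset_size * 0.2)
test_count = dataset_size - (train_count + validation_count)  # remainder

assert (train_count, validation_count, test_count) == (600, 200, 201)
```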
Daniel Braun

Can't comment, but the answer above has overlap and is incorrect as written. Set the shuffle buffer size to DATASET_SIZE for a perfect shuffle. Try different val/test sizes to verify. The answer should be:

DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = full_dataset.shuffle(DATASET_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.take(val_size)
test_dataset = test_dataset.skip(val_size)