Does anyone know how to split a dataset created by the dataset API (`tf.data.Dataset`) in TensorFlow into test and train?
-
`take()`, `skip()`, and `shard()` all have their own problems. I just posted my answer over [here](https://stackoverflow.com/a/58452268/5462608). I hope it better answers your question. – Nick Lee Oct 18 '19 at 13:55
8 Answers
Assuming you have an `all_dataset` variable of type `tf.data.Dataset`:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
The test dataset now holds the first 1000 elements, and the rest go to training.
-
As also mentioned in [ted's answer](https://stackoverflow.com/a/51258695/5299750), adding `all_dataset.shuffle()` allows for a shuffled split. Possibly add as code comment in answer like so? `# all_dataset = all_dataset.shuffle() # in case you want a shuffled split` – Christian Steinmeyer Oct 02 '20 at 14:15
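A minimal sketch of that shuffled variant (the sizes here are illustrative); note `reshuffle_each_iteration=False`, which, as discussed in comments elsewhere in this thread, keeps the two halves disjoint:

```python
import tensorflow as tf

all_dataset = tf.data.Dataset.range(5000)  # illustrative stand-in

# Shuffle once and keep the order fixed across iterations, so that
# take() and skip() see the same permutation and the splits stay disjoint.
all_dataset = all_dataset.shuffle(buffer_size=5000, reshuffle_each_iteration=False)

test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)

test_ids = {int(x) for x in test_dataset}
train_ids = {int(x) for x in train_dataset}
assert not (test_ids & train_ids)  # no element lands in both splits
```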
You may use `Dataset.take()` and `Dataset.skip()`:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
For more generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need a test or a val set, just ignore the last 2 lines.
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
You may also want to look into `Dataset.shard()`:
Creates a Dataset that includes only 1/num_shards of this dataset.
Disclaimer: I stumbled upon this question after answering this one, so I thought I'd spread the love.
-
Thank you very much @ted! Is there a way to divide the dataset in a stratified way? Or, alternatively, how can we have an idea of the class proportions (suppose a binary problem) after the train/val/test split? Thanks a lot in advance! – Tommaso Di Noto Aug 27 '19 at 13:21
-
Have a look at this blogpost I wrote; even though it's for multilabel datasets, it should be easily usable for single label, multiclass datasets -> https://vict0rs.ch/2018/06/17/multilabel-text-classification-tensorflow/ – ted Aug 27 '19 at 13:50
-
This causes my train, validation and test datasets to have overlap between them. Is this supposed to happen and not a big deal? I would assume it's not a good idea to have the model train on validation and test data. – c_student Jan 24 '20 at 19:51
-
@c_student I had the same problem and I figured out what I was missing: when you shuffle, use the option `reshuffle_each_iteration=False`, otherwise elements could be repeated in train, test and val – xdola Apr 15 '20 at 17:57
-
This is very true @xdola, and in particular when using `list_files` you should use `shuffle=False` and then shuffle with the `.shuffle` with `reshuffle_each_iteration=False`. – Zaccharie Ramzi May 27 '20 at 09:48
-
Most of the answers here use `take()` and `skip()`, which requires knowing the size of your dataset beforehand. This isn't always possible, or is difficult/intensive to ascertain.
Instead, what you can do is essentially slice the dataset up so that 1 in every N records becomes a validation record.
To accomplish this, let's start with a simple dataset of 0-9:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Now for our example, we're going to slice it so that we have a 3/1 train/validation split. Meaning 3 records will go to training, then 1 record to validation, then repeat.
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
So the first `dataset.window(split, split + 1)` says to grab `split` number (3) of elements, then advance `split + 1` elements, and repeat. That `+ 1` effectively skips the 1 element we're going to use in our validation dataset.
The `flat_map(lambda ds: ds)` is because `window()` returns the results in batches, which we don't want. So we flatten it back out.
Then for the validation data we first `skip(split)`, which skips over the first `split` number (3) of elements that were grabbed in the first training window, so we start our iteration on the 4th element. The `window(1, split + 1)` then grabs 1 element, advances `split + 1` (4), and repeats.
Note on nested datasets:
The above example works well for simple datasets, but `flat_map()` will generate an error if the dataset is nested. To address this, you can swap out the `flat_map()` with a more complicated version that can handle both simple and nested datasets:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
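As a quick illustration, the nested-safe variant behaves like the simple one on a made-up zipped `(features, labels)` dataset (all names and values here are for demonstration only):

```python
import tensorflow as tf

# A made-up nested (features, labels) dataset for demonstration.
features = tf.data.Dataset.range(10)
labels = tf.data.Dataset.range(10).map(lambda x: x * 10)
dataset = tf.data.Dataset.zip((features, labels))

split = 3
# Handles both simple and nested datasets, as described above.
unbatch_window = lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds)

dataset_train = dataset.window(split, split + 1).flat_map(unbatch_window)
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(unbatch_window)

print([int(x) for x, _ in dataset_train])                 # [0, 1, 2, 4, 5, 6, 8, 9]
print([(int(x), int(y)) for x, y in dataset_validation])  # [(3, 30), (7, 70)]
```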
-
Doesn't `window` just use `skip` under the hood? How is the disadvantage `The other disadvantage is that with skip() it has to read, and then discard, all the skipped records, which if your data source is slow means you might have a large spool-up time before results are emitted.` addressed? – Frederik Bode Mar 03 '20 at 09:06
-
If you have a dataset of 1000 records, and you want a 10% for validation, you would have to skip the first 900 records before a single validation record is emitted. With this solution, it only has to skip 9 records. It does end up skipping the same amount overall, but if you use `dataset.prefetch()`, it can read in the background while doing other things. The difference is just saving the initial spool-up time. – phemmer Mar 03 '20 at 09:11
-
Thinking about it a bit more, I removed the statement. There's probably a dozen ways to solve that problem, and it's probably minute, if present at all, for most people. – phemmer Mar 03 '20 at 09:40
-
You should probably set the *without knowing the dataset size beforehand* to boldface, or like a header or something, it's pretty important. This should really be the accepted answer, as it fits into the premise of `tf.data.Dataset` treating data like infinite streams. – Frederik Bode Mar 03 '20 at 09:43
@ted's answer will cause some overlap. Try this.
train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)
train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
Use the code below to test:
tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in TF 2.x
dataset = tf.data.Dataset.range(100)
train_size = 20
valid_size = 30
test_size = 50
train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)
for i in train:
    print(i)
for i in valid:
    print(i)
for i in test:
    print(i)
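Rather than eyeballing the printed values, the same toy setup can assert disjointness and full coverage directly; a sketch:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
train_size, valid_size = 20, 30

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

train_ids = {int(x) for x in train}
valid_ids = {int(x) for x in valid}
test_ids = {int(x) for x in test}

# The three splits should be pairwise disjoint and cover all 100 elements.
assert not (train_ids & valid_ids | train_ids & test_ids | valid_ids & test_ids)
assert len(train_ids | valid_ids | test_ids) == 100
```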
-
I love how everyone assumes you know the `full_ds_size` but no one explains how to find it – Bersan Mar 30 '21 at 15:28
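One way to find `full_ds_size`, as asked above, is the dataset's cardinality, falling back to counting a full pass when it is unknown; a sketch on a toy dataset:

```python
import tensorflow as tf

full_ds = tf.data.Dataset.range(100)  # toy stand-in

# Fast path: many pipelines know their cardinality statically.
size = tf.data.experimental.cardinality(full_ds).numpy()

# Fallback: after e.g. filter(), cardinality becomes unknown and the
# only option is to count by iterating (one full pass over the data).
if size == tf.data.experimental.UNKNOWN_CARDINALITY:
    size = full_ds.reduce(0, lambda count, _: count + 1).numpy()

print(size)  # 100
```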
Currently, TensorFlow doesn't include any tools for that.
You could use `sklearn.model_selection.train_test_split` to generate train/eval/test datasets, then create a `tf.data.Dataset` for each.
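A minimal sketch of that approach; the in-memory arrays `X` and `y` are made up for the example:

```python
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Made-up in-memory data: 100 samples, one feature, binary labels.
X = np.arange(100, dtype=np.float32).reshape(100, 1)
y = (np.arange(100) % 2).astype(np.int64)

# Stratified splitting, fixed seeds, etc. are also available here.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test))
```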
You can use `shard`:
dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=False)  # optional; buffer_size should cover the dataset
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)
See: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
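Note that `shard(num_shards, index)` keeps every `num_shards`-th element starting at position `index`, so this gives a 50/50 interleaved split; on a toy dataset:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
trainset = dataset.shard(2, 0)  # elements at even positions
testset = dataset.shard(2, 1)   # elements at odd positions

print([int(x) for x in trainset])  # [0, 2, 4, 6, 8]
print([int(x) for x in testset])   # [1, 3, 5, 7, 9]
```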
In case the size of the dataset is known:
from typing import Tuple
import tensorflow as tf


def split_dataset(dataset: tf.data.Dataset,
                  dataset_size: int,
                  train_ratio: float,
                  validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    assert (train_ratio + validation_ratio) < 1

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    # Fixed shuffle order so take()/skip() produce disjoint splits.
    dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=False)

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(validation_count + train_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset
Example:
size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2
ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
Can't comment, but the answer above has overlap and is incorrect. Set `BUFFER_SIZE` to `DATASET_SIZE` for a perfect shuffle. Try different val/test sizes to verify. The answer should be:
DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = full_dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.take(val_size)
test_dataset = test_dataset.skip(val_size)
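As the comments on ted's answer point out, keeping the splits disjoint across iterations also needs `reshuffle_each_iteration=False` on the shuffle; a toy check of the whole recipe:

```python
import tensorflow as tf

full_dataset = tf.data.Dataset.range(100)  # toy stand-in
DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)

# BUFFER_SIZE = DATASET_SIZE gives a perfect shuffle; the fixed order
# keeps take()/skip() consistent between iterations.
full_dataset = full_dataset.shuffle(DATASET_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
rest = full_dataset.skip(train_size)
val_dataset = rest.take(val_size)
test_dataset = rest.skip(val_size)

splits = [{int(x) for x in ds} for ds in (train_dataset, val_dataset, test_dataset)]
assert sum(map(len, splits)) == DATASET_SIZE  # full coverage
assert not (splits[0] & splits[1] | splits[0] & splits[2] | splits[1] & splits[2])
```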