Dataset API

Overview

We can access the TensorFlow Dataset API via the tfdatasets package, which enables us to create scalable input pipelines that can be used with tfestimators. In this vignette, we demonstrate the capability to stream datasets stored on disk for training by building a classifier on the iris dataset.

Dataset Preparation

Let’s assume we’re given a dataset (which could be arbitrarily large) that has already been split into training and validation files, along with a small sample of the data. To simulate this scenario, we’ll create a few CSV files as follows:

set.seed(123)
train_idx <- sample(nrow(iris), nrow(iris) * 2/3)

iris_train <- iris[train_idx,]
iris_validation <- iris[-train_idx,]
# Small sample, used later to infer the column specification
iris_sample <- head(iris_train, 10)

write.csv(iris_train, "iris_train.csv", row.names = FALSE)
write.csv(iris_validation, "iris_validation.csv", row.names = FALSE)
write.csv(iris_sample, "iris_sample.csv", row.names = FALSE)
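As a quick sanity check, the files can be read back to confirm the split sizes. The sketch below uses only base R and repeats the split from above; writing to tempdir() and the names train_path, validation_path, train_check and validation_check are assumptions made here to keep the check self-contained.

```r
# Recreate the split from above. tempdir() is an assumption made for this
# check so that it does not touch the working directory.
set.seed(123)
train_idx <- sample(nrow(iris), nrow(iris) * 2/3)

out_dir <- tempdir()
train_path <- file.path(out_dir, "iris_train.csv")
validation_path <- file.path(out_dir, "iris_validation.csv")

write.csv(iris[train_idx, ], train_path, row.names = FALSE)
write.csv(iris[-train_idx, ], validation_path, row.names = FALSE)

# Read the files back and confirm row counts and column names.
train_check <- read.csv(train_path)
validation_check <- read.csv(validation_path)

nrow(train_check)       # 100 rows for training
nrow(validation_check)  # 50 rows for validation
identical(names(train_check), names(iris))
```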

Estimator Construction

We construct the classifier as usual – see Estimator Basics for details on feature columns and creating estimators.

library(tfestimators)
response <- "Species"
features <- setdiff(names(iris), response)
feature_columns <- feature_columns(
  column_numeric(features)
)

classifier <- dnn_classifier(
  feature_columns = feature_columns,
  hidden_units = c(16, 32, 16),
  n_classes = 3,
  label_vocabulary = c("setosa", "virginica", "versicolor")
)
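The order of label_vocabulary matters. As far as we understand the underlying TensorFlow estimators, each string label is mapped to an integer class id by its zero-based position in the vocabulary, so the ids need not follow the factor level order of iris$Species. A minimal base R sketch of that mapping follows; encode_labels() is a hypothetical helper for illustration and is not part of tfestimators.

```r
# Hypothetical helper (not part of tfestimators): map string labels to
# zero-based class ids by their position in the vocabulary.
encode_labels <- function(labels, vocabulary) {
  ids <- match(labels, vocabulary) - 1L
  if (anyNA(ids)) stop("label not found in vocabulary")
  ids
}

label_vocabulary <- c("setosa", "virginica", "versicolor")
class_ids <- encode_labels(as.character(iris$Species), label_vocabulary)

# Counts per class id: 50 observations each for ids 0, 1 and 2.
table(class_ids)
```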

Input Function

The input function is created much as in the in-memory case. However, instead of passing data frames or matrices to iris_input_fn(), we pass TensorFlow dataset objects, which internally act as iterators over the dataset files.

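library(tfdatasets)

iris_input_fn <- function(data) {
  input_fn(data, features = features, response = response)
}

# Infer column names and types from the small sample file
iris_spec <- csv_record_spec("iris_sample.csv")

# Note: from here on, iris_train and iris_validation refer to
# dataset objects rather than the data frames created earlier
iris_train <- text_line_dataset(
  "iris_train.csv", record_spec = iris_spec) %>%
  dataset_batch(10) %>%
  dataset_repeat(10)
iris_validation <- text_line_dataset(
  "iris_validation.csv", record_spec = iris_spec) %>%
  dataset_batch(10) %>%
  dataset_repeat(1)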

The csv_record_spec() function is a helper that creates a specification from a sample file; text_line_dataset() requires this specification to parse the files. Many transformations are available for dataset objects, but here we demonstrate only dataset_batch() and dataset_repeat(), which control the batch size and how many times we iterate through the dataset files, respectively.
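To make the effect of batching and repeating concrete, the following base R sketch imitates the two transformations on plain row indices. It does not call tfdatasets; simulate_batches() is a hypothetical helper, and the row counts (100 training rows, 50 validation rows) are taken from the split created earlier.

```r
# Hypothetical helper (plain R, not tfdatasets): split row indices into
# consecutive batches, then repeat the whole sequence of batches.
simulate_batches <- function(n_rows, batch_size, n_repeat) {
  one_pass <- split(seq_len(n_rows), ceiling(seq_len(n_rows) / batch_size))
  rep(unname(one_pass), times = n_repeat)
}

train_batches <- simulate_batches(100, batch_size = 10, n_repeat = 10)
validation_batches <- simulate_batches(50, batch_size = 10, n_repeat = 1)

length(train_batches)        # 10 batches per pass x 10 passes = 100
length(validation_batches)   # 5 batches in a single pass
lengths(train_batches[1:2])  # each batch holds 10 rows
```

Under these assumptions, the training pipeline yields 100 batches of 10 rows in total, while the validation pipeline yields 5 batches in one pass over the file.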

Training and Evaluation

Once the input functions and datasets are defined, the training and evaluation interface is exactly the same as in the in-memory case.

# Train on batches streamed from the training files
history <- train(classifier, input_fn = iris_input_fn(iris_train))
plot(history)

# Predict and evaluate on the validation files
predictions <- predict(classifier, input_fn = iris_input_fn(iris_validation))
predictions
evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_validation))
evaluation

Learning More

See the documentation for the tfdatasets package for additional details on using TensorFlow datasets.