library(torch)
library(luz)
Luz is a higher-level API for torch designed to be highly flexible: its layered API remains useful no matter how much control you need over your training loop.
In the getting started vignette we have seen the basics of luz and how to quickly modify parts of the training loop using callbacks and custom metrics. In this document we will describe how luz allows the user to get fine-grained control of the training loop.
Apart from the use of callbacks, there are three more ways that you can use luz (depending on how much control you need):
Multiple optimizers or losses: You might be optimizing two loss functions, each with its own optimizer, but you still don’t want to modify the backward(), zero_grad() and step() calls. This is common in models like GANs (Generative Adversarial Networks), where competing neural networks are trained with different losses and optimizers.
Fully flexible steps: You might want to control how backward(), zero_grad() and step() are called. You might also want more control over gradient computation. For example, you might want to use ‘virtual batch sizes’, where you accumulate the gradients for a few steps before updating the weights.
Completely flexible loops: Your training loop can be anything you want, but you still want to use luz to handle device placement of the dataloaders, optimizers and models. See vignette("accelerator").
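The ‘virtual batch sizes’ idea mentioned above can be sketched in plain torch, outside of luz. This is only an illustration under stated assumptions: the toy data, the linear model and the names batch_size, accumulate and n_updates are hypothetical and not part of luz.

```r
library(torch)

torch_manual_seed(1)

# Hypothetical toy setup: 32 samples, mini-batches of 8, one weight update
# every 2 mini-batches (a "virtual" batch size of 16).
x <- torch_randn(32, 10)
y <- torch_randn(32, 1)
model <- nn_linear(10, 1)
opt <- optim_sgd(model$parameters, lr = 0.01)

batch_size <- 8
accumulate <- 2
n_updates <- 0

opt$zero_grad()
for (i in seq_len(32 / batch_size)) {
  idx <- ((i - 1) * batch_size + 1):(i * batch_size)

  # Scale the loss so the accumulated gradient averages over the virtual batch.
  loss <- nnf_mse_loss(model(x[idx, ]), y[idx, ]) / accumulate
  loss$backward()  # gradients add up across calls until zero_grad()

  # Only update the weights every `accumulate` mini-batches.
  if (i %% accumulate == 0) {
    opt$step()
    opt$zero_grad()
    n_updates <- n_updates + 1
  }
}
```

With four mini-batches and accumulate = 2, the optimizer steps twice. The same pattern could be moved into a custom training step once you control the backward(), zero_grad() and step() calls yourself.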
Let’s consider a simplified version of the net that we implemented in the getting started vignette:
net <- nn_module(
  "Net",
  initialize = function() {
    # two fully connected layers: 100 -> 50 -> 10
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$fc2()
  }
)
Using the highest level of the luz API, we would fit it with:
fitted <- net %>%
  setup(
    loss = nn_cross_entropy_loss(),
    optimizer = optim_adam,
    metrics = list(
      luz_metric_accuracy
    )
  ) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)
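The fit() calls in this document assume that train_dl and test_dl dataloaders already exist. A minimal sketch with synthetic data follows, so the examples have something to run on. The make_dl() helper, the sample sizes and the batch size are hypothetical; only the 100 input features and 10 classes are taken from the net definition.

```r
library(torch)

torch_manual_seed(1)
set.seed(1)

# Hypothetical helper: random features with 100 columns and
# class labels in 1..10 (class indices are 1-based in torch for R).
make_dl <- function(n, batch_size = 32, shuffle = FALSE) {
  x <- torch_randn(n, 100)
  y <- torch_tensor(sample(1:10, n, replace = TRUE), dtype = torch_long())
  dataloader(tensor_dataset(x, y), batch_size = batch_size, shuffle = shuffle)
}

train_dl <- make_dl(256, shuffle = TRUE)
test_dl <- make_dl(64)
```

Each batch is a list whose first element is the input and whose second element is the target. Random labels carry no signal, so this is useful for checking that the code runs, not for judging accuracy.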
Suppose we want to run an experiment where we train the first fully connected layer using a learning rate of 0.1 and the second one using a learning rate of 0.01. We will minimize the same nn_cross_entropy_loss() for both, but for the first layer we want to add L1 regularization on the weights.
In order to use luz for this, we will implement two methods in the net module:
set_optimizers: returns a named list of optimizers depending on the ctx.
loss: computes the loss depending on the selected optimizer.
Let’s go to the code:
net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$fc2()
  },
  # one optimizer per layer, each with its own learning rate
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  loss = function(input, target) {
    pred <- ctx$model(input)

    # ctx$opt_name tells us which optimizer is currently being used
    if (ctx$opt_name == "opt_fc1")
      nnf_cross_entropy(pred, target) + torch_norm(self$fc1$weight, p = 1)
    else if (ctx$opt_name == "opt_fc2")
      nnf_cross_entropy(pred, target)
  }
)
Notice that the model optimizers will be initialized according to the set_optimizers() method’s return value (a list). In this case, we are initializing the optimizers using different model parameters and learning rates.
The loss() method is responsible for computing the loss that will then be back-propagated to compute gradients and update the weights. This loss() method can access the ctx object, which contains an opt_name field describing which optimizer is currently being used. Note that this function will be called once for each optimizer for each training and validation step. See help("ctx") for complete information about the context object.
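As a small check of the regularization term used in loss() above, the sketch below compares torch_norm(w, p = 1) with the sum of absolute values on a hand-made matrix. The matrix w and the names l1 and manual are hypothetical and are not part of the model.

```r
library(torch)

# Hypothetical 2 x 2 weight matrix with entries 1, -2, 3, -4.
w <- torch_tensor(matrix(c(1, -2, 3, -4), nrow = 2))

# Same call as in loss(): with no `dim` given, the norm is taken over
# all elements of the tensor.
l1 <- torch_norm(w, p = 1)

# The L1 norm written out by hand: the sum of absolute values.
manual <- w$abs()$sum()
```

Both values should agree. In practice an L1 penalty is commonly multiplied by a small coefficient before it is added to the loss. The code in this document adds it unscaled, so the penalty may dominate the cross-entropy term.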
We can finally setup and fit this module; however, we no longer need to specify optimizers and loss functions.
fitted <- net %>%
  setup(metrics = list(luz_metric_accuracy)) %>%
  fit(train_dl, epochs = 10, valid_data = test_dl)
Now let’s re-implement this same model using the slightly more flexible approach of overriding the training and validation step.
Instead of implementing the loss() method, we can implement the step() method. This allows us to flexibly modify what happens when training and validating for each batch in the dataset. You are now responsible for updating the weights by stepping the optimizers and back-propagating the loss.
net <- nn_module(
  "Net",
  initialize = function() {
    self$fc1 <- nn_linear(100, 50)
    self$fc2 <- nn_linear(50, 10)
  },
  forward = function(x) {
    x %>%
      self$fc1() %>%
      nnf_relu() %>%
      self$fc2()
  },
  set_optimizers = function(lr_fc1 = 0.1, lr_fc2 = 0.01) {
    list(
      opt_fc1 = optim_adam(self$fc1$parameters, lr = lr_fc1),
      opt_fc2 = optim_adam(self$fc2$parameters, lr = lr_fc2)
    )
  },
  step = function() {
    ctx$loss <- list()
    for (opt_name in names(ctx$optimizers)) {

      pred <- ctx$model(ctx$input)
      opt <- ctx$optimizers[[opt_name]]
      loss <- nnf_cross_entropy(pred, ctx$target)

      if (opt_name == "opt_fc1") {
        # we have L1 regularization in layer 1
        loss <- nnf_cross_entropy(pred, ctx$target) +
          torch_norm(self$fc1$weight, p = 1)
      }

      # only update the weights during training, not validation
      if (ctx$training) {
        opt$zero_grad()
        loss$backward()
        opt$step()
      }

      ctx$loss[[opt_name]] <- loss$detach()
    }
  }
)
The important things to notice here are:
The step() method is used for both training and validation. You need to be careful to only modify the weights when training. Again, you can get complete information regarding the context object using help("ctx").
ctx$optimizers is a named list holding each optimizer that was created when the set_optimizers() method was called.
You need to manually track the losses by saving them in a named list in ctx$loss. By convention, each loss uses the same name as the optimizer it refers to. It is good practice to detach() them before saving to reduce memory usage.
Callbacks that would be called inside the default step() method, like on_train_batch_after_pred, on_train_batch_after_loss, etc., won’t be called automatically. You can still call them manually by adding ctx$call_callbacks("<callback name>") inside your training step. See the code for fit_one_batch() and valid_one_batch() to find all the callbacks that won’t be called.
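The advice to detach() losses before storing them can be illustrated in plain torch. In this minimal sketch the tensor w and the names loss and stored are hypothetical. It shows that a detached tensor keeps its value but no longer tracks gradients, so storing it does not keep the computation graph alive.

```r
library(torch)

# Hypothetical parameter and a loss computed from it.
w <- torch_tensor(c(1, 2, 3), requires_grad = TRUE)
loss <- (w^2)$sum()

# The loss is attached to the graph that produced it.
attached <- loss$requires_grad

# The detached copy has the same value but no gradient tracking.
stored <- loss$detach()
tracked <- stored$requires_grad
```

Here attached is TRUE and tracked is FALSE, while stored$item() equals loss$item().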
In this article you learned how to customize the step() of your training loop using luz’s layered functionality.
Luz also allows more flexible modifications of the training loop, described in the Accelerator vignette (vignette("accelerator")).
You should now be able to follow the examples marked with the ‘intermediate’ and ‘advanced’ category in the examples gallery.