Currently, the only way to load a model saved from Python is to re-implement the model architecture in R; all of the parameter names must be identical. A complete example going from Python to R is shown below. This is an extension of the Serialization vignette.
An artificial neural net is implemented below in Python. Note the final line, which uses torch.save().
import torch
import numpy as np
#Make up data
madeUpData_x = np.random.rand(1000,100)
madeUpData_y = np.random.rand(1000)

#Convert to categorical
madeUpData_y = madeUpData_y.round()

train_py_X = torch.from_numpy(madeUpData_x).float()
train_py_Y = torch.from_numpy(madeUpData_y).float()
#Note that this class must be replicated identically in R
class simpleMLP(torch.nn.Module):
    def __init__(self):
        super(simpleMLP, self).__init__()
        self.modelFit = torch.nn.Sequential(
            torch.nn.Linear(100, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 1),
            torch.nn.Sigmoid())

    def forward(self, x):
        x = self.modelFit(x)
        return x

model = simpleMLP()
def modelTrainer(data_X, data_Y, model):
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(100):
        optimizer.zero_grad()
        yhat = model(data_X)
        loss = criterion(yhat, data_Y.unsqueeze(1))
        loss.backward()
        optimizer.step()

modelTrainer(data_X = train_py_X, data_Y = train_py_Y, model = model)
#-----------------------------------------------------------------
#save the model
#Note that model.state_dict() comes out as an ordered dictionary
#The code below converts to a dictionary
stateDict = dict(model.state_dict())

#Note the argument _use_new_zipfile_serialization
torch.save(stateDict, f="path/babyTest.pth",
           _use_new_zipfile_serialization=True)
Once we have a saved .pth object we can load it into R. An example use case would be training a model in Python and then using Shiny to build a GUI that serves predictions from the trained model (a minimal sketch of this appears after the R example below).
library(torch)
#Make up some test data
#note that proper installation of torch will yield no errors when we run
#this code
y <- torch_tensor(array(runif(8), dim = c(2,2,2)), dtype = torch_float64())
#Note the identical names between the Python class definition and our
#class definition
simpleMLP <- torch::nn_module(
  "simpleMLP",
  initialize = function() {
    self$modelFit <- nn_sequential(nn_linear(100, 20),
                                   nn_relu(),
                                   nn_linear(20, 1),
                                   nn_sigmoid())
  },
  forward = function(x) {
    self$modelFit(x)
  }
)

model <- simpleMLP()
state_dict <- torch::load_state_dict("path/babyTest.pth")
model$load_state_dict(state_dict)
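
#Optional sanity check (a sketch): load_state_dict() matches parameters by
#name, so the names in the loaded state dict should be identical to the
#names the R module expects; any difference means the architecture was not
#replicated exactly
names(state_dict)
names(model$state_dict())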
#Note that the dtype set in R has to match the made up data from Python
#More generally if reading new data into R you must ensure that it matches the
#dtype that the model was trained with in Python
newData <- torch_tensor(array(rnorm(n = 1000), dim = c(10, 100)), dtype = torch_float32())

predictMe <- model(newData)
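
As mentioned above, one use case for loading a Python-trained model into R is serving predictions through a Shiny GUI. The sketch below is a minimal, hypothetical app built around the simpleMLP model loaded above; the layout and the simulated 100-feature input row are assumptions for illustration only, and a real app would read user-supplied data instead.

library(shiny)
library(torch)

#Minimal sketch of a Shiny app serving predictions from the loaded model.
#Assumes `model` has already been built and populated via load_state_dict()
#as shown above.
ui <- fluidPage(
  titlePanel("simpleMLP predictions"),
  numericInput("seed", "Seed for a simulated input row (100 features):", value = 1),
  actionButton("go", "Predict"),
  verbatimTextOutput("pred")
)

server <- function(input, output, session) {
  pred <- eventReactive(input$go, {
    set.seed(input$seed)
    #Simulate a single observation with 100 features; a real app would
    #collect this from the user instead
    x <- torch_tensor(matrix(rnorm(100), nrow = 1), dtype = torch_float32())
    model$eval()
    with_no_grad(as_array(model(x)))
  })
  output$pred <- renderPrint(pred())
}

shinyApp(ui, server)

model$eval() and with_no_grad() are used so the forward pass runs in inference mode without tracking gradients.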