First, install the `blurr` Python package (pip package `ohmeow-blurr`),
which provides the Hugging Face Transformers integration.
# Install the blurr Python package into the Python environment used by
# reticulate; pip = TRUE forces installation via pip.
reticulate::py_install("ohmeow-blurr", pip = TRUE)
Load the libraries and download the IMDB sample data for binary classification:
# fastai: R interface to fastai (DataBlock, Learner, HF_* helpers, ...).
library("fastai")
# magrittr: provides the %>% pipe used throughout this tutorial.
library("magrittr")
# zeallot: provides the %<-% destructuring assignment.
library("zeallot")

# Download the IMDB sample dataset (read later from imdb_sample/texts.csv).
URLs_IMDB_SAMPLE()
Define the task and load the pretrained Hugging Face objects:
# HF_TASKS_AUTO() returns an object whose $SequenceClassification member
# selects the sequence-classification task. Rebinding the result to the same
# name does not break later HF_TASKS_AUTO() calls: when resolving a function
# call, R skips bindings that are not functions.
HF_TASKS_AUTO <- HF_TASKS_AUTO()
task <- HF_TASKS_AUTO$SequenceClassification

# Alternatives: "distilbert-base-uncased", "bert-base-uncased".
pretrained_model_name <- "roberta-base"

# get_hf_objects() returns four values; zeallot's %<-% destructures them
# into hf_arch, hf_config, hf_tokenizer and hf_model.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-%
  get_hf_objects(pretrained_model_name, task = task)
Downloading: 100%|██████████| 481/481 [00:00<00:00, 277kB/s]
Downloading: 100%|██████████| 899k/899k [00:01<00:00, 580kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 471kB/s]
Downloading: 100%|██████████| 501M/501M [03:11<00:00, 2.62MB/s]
Create the DataLoaders
with Hugging Face data blocks:
# Read the sample data; the columns used below are text, label and is_valid.
imdb_df <- data.table::fread("imdb_sample/texts.csv")

# Input block: text handled by the Hugging Face tokenizer/architecture.
# Target block: categorical label.
blocks <- list(
  HF_TextBlock(hf_arch = hf_arch, hf_tokenizer = hf_tokenizer),
  CategoryBlock()
)

dblock <- DataBlock(
  blocks = blocks,
  get_x = ColReader("text"),
  get_y = ColReader("label"),
  # Split into training/validation sets using the `is_valid` column.
  splitter = ColSplitter(col = "is_valid")
)

# Batch size must be an integer; 4L is passed to Python as an int.
dls <- dblock %>% dataloaders(imdb_df, bs = 4L)

# Inspect one batch: inputs (input_ids / attention_mask) and targets.
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[ 0, 4833, 3009, ..., 1916, 6, 2],
[ 0, 1876, 13856, ..., 7, 47, 2],
[ 0, 2647, 6, ..., 6, 61, 2],
[ 0, 20, 2091, ..., 5779, 30, 2]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='cuda:0')
[[2]]
TensorCategory([0, 1, 0, 0], device='cuda:0')
Wrap the model and create the Learner:
# Wrap the Transformers model so a fastai Learner can drive it.
model <- HF_BaseModelWrapper(hf_model)

learn <- Learner(
  dls,
  model,
  # Adam with decoupled weight decay.
  opt_func = partial(Adam, decouple_wd = TRUE),
  loss_func = CrossEntropyLossFlat(),
  metrics = accuracy,
  cbs = HF_BaseModelCallback(),
  # hf_splitter() defines the parameter groups that freeze() operates on.
  splitter = hf_splitter()
)

learn$create_opt()
# Freeze the earlier parameter groups so only the last group trains at first.
learn$freeze()

# Print layer shapes, parameter counts and trainability.
learn %>% summary()
epoch train_loss valid_loss accuracy time
------ ----------- ----------- --------- ------
HF_BaseModelWrapper (Input shape: 4 x 512)
================================================================
Layer (type) Output Shape Param # Trainable
================================================================
Embedding 4 x 512 x 768 38,603,520 False
________________________________________________________________
Embedding 4 x 512 x 768 394,752 False
________________________________________________________________
Embedding 4 x 512 x 768 768 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 768 590,592 True
________________________________________________________________
Dropout 4 x 768 0 False
________________________________________________________________
Linear 4 x 2 1,538 True
________________________________________________________________
Total params: 124,647,170
Total trainable params: 630,530
Total non-trainable params: 124,016,640
Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fd850db18c8>, decouple_wd=True)
Loss function: FlattenedLoss of CrossEntropyLoss()
Model frozen up to parameter group #2
Callbacks:
- TrainEvalCallback
- Recorder
- ProgressCallback
- HF_BaseModelCallback
Train and predict:
# Train for 3 epochs using the one-cycle schedule with a peak LR of 1e-3.
# The epoch count is passed as an integer (3L) so Python receives an int.
result <- learn %>% fit_one_cycle(3L, lr_max = 1e-3)

# Predict on the first four reviews; head() never yields NA entries,
# unlike [1:4], even if the table has fewer than four rows.
learn %>% predict(head(imdb_df$text, 4L))