First, we need to install the blurr module for the Transformers integration:
reticulate::py_install('git+https://github.com/ohmeow/blurr.git', pip = TRUE)
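To confirm the installation worked, you can check that the module is importable from R (a quick sanity check; it assumes reticulate is bound to the same Python environment you installed into):

if (reticulate::py_module_available('blurr')) {
  message('blurr is available')
}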
Grab the data and take 1% of it for fast training:
library(fastai)
library(magrittr)
library(zeallot)

df = HF_load_dataset('civil_comments', split='train[:1%]')
Select the label columns and binarize them:
df = data.table::as.data.table(df)

lbl_cols = c('severe_toxicity',
             'obscene',
             'threat',
             'insult',
             'identity_attack',
             'sexual_explicit')

df <- df[, (lbl_cols) := round(.SD, 0), .SDcols=lbl_cols]
df <- df[, (lbl_cols) := lapply(.SD, as.integer), .SDcols=lbl_cols]
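Before building the model it is worth eyeballing the prepared labels. A small sketch using standard data.table syntax (not part of blurr):

# Peek at the text alongside the binarized label columns
head(df[, c('text', lbl_cols), with = FALSE])

# Label prevalence: toxic labels are rare, so expect means close to zero
df[, lapply(.SD, mean), .SDcols = lbl_cols]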
Load DistilRoBERTa:
task = HF_TASKS_ALL()$SequenceClassification

pretrained_model_name = "distilroberta-base"

config = AutoConfig()$from_pretrained(pretrained_model_name)
config$num_labels = length(lbl_cols)

c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name,
                                                                  task=task,
                                                                  config=config)
Downloading: 100%|██████████| 899k/899k [00:00<00:00, 961kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 597kB/s]
Downloading: 100%|██████████| 331M/331M [03:26<00:00, 1.61MB/s]
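The %<-% operator from zeallot unpacks the four returned objects in one step. You can inspect them to confirm the setup; the commented values below are illustrative:

hf_arch               # architecture string, here "roberta"
hf_config$num_labels  # 6, matching length(lbl_cols)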
Create data blocks:
blocks = list(
  HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer),
  MultiCategoryBlock(encoded=TRUE, vocab=lbl_cols)
)

dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('text'),
                   get_y=ColReader(lbl_cols),
                   splitter=RandomSplitter())

dls = dblock %>% dataloaders(df, bs=8)

dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[ 0, 24268, 5257, ..., 1, 1, 1],
[ 0, 287, 4505, ..., 1, 1, 1],
[ 0, 38, 437, ..., 1, 1, 1],
...,
[ 0, 152, 1129, ..., 1, 1, 1],
[ 0, 85, 18, ..., 1, 1, 1],
[ 0, 22014, 31, ..., 1, 1, 1]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
...,
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0]], device='cuda:0')
[[2]]
TensorMultiCategory([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]], device='cuda:0')
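Each batch is a list: the tokenized inputs (with attention masks) and a multi-hot label tensor. A quick sketch for checking the shapes, assuming the structure printed above:

batch = dls %>% one_batch()
batch[[1]]$input_ids$shape   # (batch size, sequence length), e.g. 8 x 391
batch[[2]]$shape             # (8, 6): one column per label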
model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls,
                model,
                opt_func=partial(Adam),
                loss_func=BCEWithLogitsLossFlat(),
                metrics=partial(accuracy_multi(), thresh=0.2),
                cbs=HF_BaseModelCallback(),
                splitter=hf_splitter())

learn$loss_func$thresh = 0.2
learn$create_opt()   # creates the layer groups based on the "splitter" function
learn$freeze()
Inspect the summary to see which layers are trainable:

learn %>% summary()
HF_BaseModelWrapper (Input shape: 8 x 391)
================================================================
Layer (type) Output Shape Param # Trainable
================================================================
Embedding 8 x 391 x 768 38,603,520 False
________________________________________________________________
Embedding 8 x 391 x 768 394,752 False
________________________________________________________________
Embedding 8 x 391 x 768 768 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 768 590,592 True
________________________________________________________________
Dropout 8 x 768 0 False
________________________________________________________________
Linear 8 x 6 4,614 True
________________________________________________________________
Total params: 82,123,014
Total trainable params: 615,174
Total non-trainable params: 81,507,840
Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fee7e8166a8>)
Loss function: FlattenedLoss of BCEWithLogitsLoss()
Model frozen up to parameter group #2
Callbacks:
- TrainEvalCallback
- Recorder
- ProgressCallback
- HF_BaseModelCallback
Finally, fit the model:
lrs = learn %>% lr_find(suggestions=TRUE)
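With suggestions=TRUE, lr_find runs the LR range test and returns candidate learning rates; the exact fields vary across fastai versions, so inspect the object before picking lr_max:

str(lrs)   # examine the suggested learning rates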
learn %>% fit_one_cycle(1, lr_max=1e-2)
epoch train_loss valid_loss accuracy_multi time
------ ----------- ----------- --------------- ------
0 0.040617 0.034286 0.993257 01:21
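If the frozen run underfits, a hypothetical next step is to unfreeze and fine-tune the full backbone at a smaller learning rate, then persist the learner (the file name below is illustrative):

learn$unfreeze()                           # make all parameter groups trainable
learn %>% fit_one_cycle(1, lr_max=1e-5)    # smaller LR for the full backbone
learn$save('toxic_classifier')             # saved under models/ by default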
Lower the prediction threshold and predict on a new comment:
learn$loss_func$thresh = 0.02

learn %>% predict("Those damned affluent white people should only eat their own food, like cod cakes and boiled potatoes. No enchiladas for them!")
$probabilities
  severe_toxicity     obscene       threat     insult identity_attack sexual_explicit
1    9.302437e-07 0.004268706 0.0007849637 0.02687055     0.003282947      0.00232468
$labels
[1] "insult"