Fine-tune (or train) a pretrained Transformer model on your labelled training data x and y. The prediction task can be classification (if regression is FALSE, the default) or regression (if regression is TRUE).
Usage
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

# S3 method for default
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

# S3 method for corpus
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)

textmodel_transformer(...)

# S3 method for character
grafzahl(
  x,
  y = NULL,
  model_name = "xlmroberta",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)
Arguments
- x: the corpus or character vector of texts on which the model will be trained. Depending on train_size, some texts will be used for cross-validation.
- y: training labels. It can be a single string indicating which docvars column of the corpus contains the training labels; a vector of training labels (character or factor); or NULL if the corpus contains exactly one docvars column, in which case that column is used as the training labels. If x is a character vector, y must be a vector of the same length.
- model_name: string indicating either 1) the model name on the Hugging Face website; or 2) the local path of the model
- regression: logical, if TRUE, the task is regression; otherwise, classification.
- output_dir: string, location of the output model. If missing, the model will be stored in a temporary directory. Important: if this directory exists, it will be overwritten.
- cuda: logical, whether to use CUDA; defaults to detect_cuda().
- num_train_epochs: numeric; if train_size is not exactly 1.0, the maximum number of epochs to try in the "early stop" regime will be this number times 5 (i.e. 4 * 5 = 20 by default). If train_size is exactly 1.0, the number of epochs is exactly this number.
- train_size: numeric, proportion of the data in x and y actually used for training. The rest will be used for cross-validation.
- args: list, additional parameters to be passed to the underlying simpletransformers trainer
- cleanup: logical, if TRUE, the "runs" directory generated during training will be removed when the training is done
- model_type: a string indicating the model_type of the input model. If NULL, it will be inferred from model_name. Supported model types are available in supported_model_types.
- manual_seed: numeric, random seed
- verbose: logical, if TRUE, debug messages will be displayed
- ...: parameters passed to grafzahl()
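The three ways of supplying y described above can be sketched as follows (a sketch with hypothetical objects — my_corpus, texts, labels, and the output directory "out" are assumptions, and actually running any of these calls requires a working Python backend):

```r
library(quanteda)
library(grafzahl)

## 1) y as the name of a docvars column of the corpus
model <- grafzahl(my_corpus, y = "label", output_dir = "out")

## 2) y as a label vector parallel to a character vector x
model <- grafzahl(texts, y = labels, output_dir = "out")

## 3) y = NULL: the corpus has exactly one docvars column,
##    which is then taken as the training labels
model <- grafzahl(my_corpus, output_dir = "out")
```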
Value
a grafzahl S3 object with the following items:
- call: original function call
- input_data: input data for the underlying Python function
- output_dir: location of the output model
- model_type: model type
- model_name: model name
- regression: whether or not it is a regression model
- levels: factor levels of y
- manual_seed: random seed
- meta: metadata about the current session
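The items above can be read off the returned object with the usual $ operator (a sketch, assuming model is a fitted grafzahl object):

```r
model$model_name   # e.g. "xlm-roberta-base"
model$levels       # factor levels of y; not meaningful for regression
model$output_dir   # where the fine-tuned model was saved
```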
Examples
if (detect_conda() && interactive()) {
  library(quanteda)
  set.seed(20190721)
  ## Using the default cross validation method
  model1 <- grafzahl(unciviltweets, model_type = "bertweet",
                     model_name = "vinai/bertweet-base")
  predict(model1)
  ## Using LIME
  input <- corpus(ecosent, text_field = "headline")
  training_corpus <- corpus_subset(input, !gold)
  model2 <- grafzahl(x = training_corpus,
                     y = "value",
                     model_name = "GroNLP/bert-base-dutch-cased")
  test_corpus <- corpus_subset(input, gold)
  predicted_sentiment <- predict(model2, test_corpus)
  require(lime)
  sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken",
                 "Aandelenbeurzen zetten koersopmars voort")
  explainer <- lime(training_corpus, model2)
  explanations <- explain(sentences, explainer, n_labels = 1,
                          n_features = 2)
  plot_text_explanations(explanations)
}
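The examples above only cover classification; a regression call differs only in passing a numeric outcome and regression = TRUE. A minimal sketch (reusing the objects from the example above; treating the numeric "value" docvar as a continuous outcome here is an assumption for illustration only):

```r
if (detect_conda() && interactive()) {
  ## regression mode: y must be numeric and regression = TRUE
  model3 <- grafzahl(x = training_corpus,
                     y = "value",
                     regression = TRUE,
                     model_name = "GroNLP/bert-base-dutch-cased")
  predict(model3, test_corpus)  # returns numeric predictions
}
```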