Fine-tune (or train) a pretrained Transformer model on the labelled training data x and y. The prediction task can be classification (if regression is FALSE, the default) or regression (if regression is TRUE).
Usage
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)
# Default S3 method
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)
# S3 method for class 'corpus'
grafzahl(
  x,
  y = NULL,
  model_name = "xlm-roberta-base",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)
textmodel_transformer(...)
# S3 method for class 'character'
grafzahl(
  x,
  y = NULL,
  model_name = "xlmroberta",
  regression = FALSE,
  output_dir,
  cuda = detect_cuda(),
  num_train_epochs = 4,
  train_size = 0.8,
  args = NULL,
  cleanup = TRUE,
  model_type = NULL,
  manual_seed = floor(runif(1, min = 1, max = 721831)),
  verbose = TRUE
)
Arguments
- x
- the quanteda::corpus or character vector of texts on which the model will be trained. Depending on train_size, some texts will be held out for cross-validation.
- y
- training labels. This can be either a single string naming the quanteda::docvars column of the quanteda::corpus that holds the training labels; a character or factor vector of training labels; or NULL if the quanteda::corpus has exactly one column in quanteda::docvars and that column holds the training labels. If x is a character vector, y must be a vector of the same length.
- model_name
- string, either 1) the name of a model on the Hugging Face website or 2) the local path of a model.
- regression
- logical, if TRUE, the task is regression; otherwise it is classification.
- output_dir
- string, location of the output model. If missing, the model is stored in a temporary directory. Important: if this directory already exists, it will be overwritten.
- cuda
- logical, whether to use CUDA; defaults to detect_cuda().
- num_train_epochs
- numeric. If train_size is not exactly 1.0, the maximum number of epochs to try in the "early stop" regime is this number times 5 (i.e. 4 * 5 = 20 by default). If train_size is exactly 1.0, exactly this number of epochs is run.
- train_size
- numeric, proportion of the data in x and y actually used for training. The rest is used for cross-validation.
- args
- list, additional parameters passed to the underlying Simple Transformers library.
- cleanup
- logical, if TRUE, the generated runs directory is removed when training is done.
- model_type
- string indicating the model type of the input model. If NULL, it is inferred from model_name. Supported model types are listed in supported_model_types.
- manual_seed
- numeric, random seed 
- verbose
- logical, if TRUE, debug messages are displayed.
- ...
- parameters passed to grafzahl()
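The interplay between train_size and num_train_epochs described above can be restated as a small base-R sketch. training_plan() is a hypothetical helper, not part of grafzahl; it only encodes the documented epoch rule. Rounding the training share down with floor() is an assumption, since this page does not say how the split is rounded.

```r
# Hypothetical helper (not part of grafzahl): restates the documented
# relationship between train_size, num_train_epochs and the epoch budget.
training_plan <- function(n_docs, train_size = 0.8, num_train_epochs = 4) {
  stopifnot(train_size > 0, train_size <= 1)
  # Assumption: the training share is rounded down; the real split may differ.
  n_train <- floor(n_docs * train_size)
  early_stop <- train_size != 1
  list(
    n_train = n_train,
    n_eval = n_docs - n_train,
    early_stopping = early_stop,
    # Documented rule: with a held-out set, up to num_train_epochs * 5 epochs
    # are tried; with train_size exactly 1, exactly num_train_epochs are run.
    max_epochs = if (early_stop) num_train_epochs * 5 else num_train_epochs
  )
}

plan <- training_plan(n_docs = 1000)
plan$max_epochs  # 20
plan$n_eval      # 200
```

With the defaults and 1000 documents, this gives 800 training and 200 held-out documents and an epoch budget of 20; with train_size = 1 there is no held-out set and exactly num_train_epochs epochs are run.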
Value
A grafzahl S3 object with the following items:
- call
- original function call 
- input_data
- input data for the underlying Python function
- output_dir
- location of the output model 
- model_type
- model type 
- model_name
- model name 
- regression
- whether or not it is a regression model 
- levels
- factor levels of y 
- manual_seed
- random seed 
- meta
- metadata about the current session 
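The levels item records the factor levels of y. As a base-R illustration (not grafzahl code), a character label vector becomes a factor whose levels are the sorted distinct labels, and each label corresponds to an integer code given by its position among the levels. How grafzahl encodes labels internally is not documented on this page.

```r
# Character training labels, as they might be supplied through `y`
y <- c("positive", "negative", "neutral", "positive")

# factor() uses the sorted distinct values as its levels
y_factor <- factor(y)
levels(y_factor)           # "negative" "neutral"  "positive"

# Each label's integer code is its position among the levels
as.integer(y_factor)       # 3 1 2 3

# Integer codes map back to labels by indexing the levels
levels(y_factor)[c(3, 1)]  # "positive" "negative"
```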
Examples
library(grafzahl)
if (detect_conda() && interactive()) {
  library(quanteda)
  set.seed(20190721)
  ## Using the default cross-validation method;
  ## y is NULL, so the labels come from the corpus's only docvar
  model1 <- grafzahl(unciviltweets, model_type = "bertweet",
                     model_name = "vinai/bertweet-base")
  predict(model1)
  ## Train on the non-gold headlines, with labels in the "value" docvar
  input <- corpus(ecosent, text_field = "headline")
  training_corpus <- corpus_subset(input, !gold)
  model2 <- grafzahl(x = training_corpus,
                     y = "value",
                     model_name = "GroNLP/bert-base-dutch-cased")
  ## Predict the gold-standard headlines
  test_corpus <- corpus_subset(input, gold)
  predicted_sentiment <- predict(model2, test_corpus)
  ## Using LIME to explain predictions
  library(lime)
  sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken",
                 "Aandelenbeurzen zetten koersopmars voort")
  explainer <- lime(training_corpus, model2)
  explanations <- explain(sentences, explainer, n_labels = 1,
                          n_features = 2)
  plot_text_explanations(explanations)
}
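The example calls set.seed() before grafzahl(). The default manual_seed is the R expression floor(runif(1, min = 1, max = 721831)), so it draws from R's random number generator, and fixing R's seed fixes the value used as manual_seed (assuming nothing else consumes random numbers differently between runs). The sketch below checks this in base R; draw_default_seed() is a hypothetical helper that copies the default expression from the Usage section. It says nothing about whether training itself is deterministic, for example on a GPU.

```r
# Hypothetical helper: the same expression grafzahl uses as the default manual_seed
draw_default_seed <- function() floor(runif(1, min = 1, max = 721831))

set.seed(20190721)
seed_a <- draw_default_seed()

set.seed(20190721)
seed_b <- draw_default_seed()

# Same R seed, same drawn manual_seed
identical(seed_a, seed_b)  # TRUE

# Passing the seed explicitly avoids relying on the RNG state, e.g.
# grafzahl(x, y, manual_seed = 42)
```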