Truncated Backpropagation Through Time (BPTT)
Truncated BPTT is a useful technique for training language models on very long sequences. Typically, a long sequence is split into chunks and a language model is trained over the chunks sequentially. The LM may condition on previous chunks, but gradients only flow through the current chunk. This technique was the basis for the paper Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, which achieved state-of-the-art language modeling results at the time of publication.
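As a rough illustration of the chunking described above, the following sketch walks a long token sequence in order and pairs each chunk with the tokens that precede it. The function name `chunk_with_memory` and its parameters are hypothetical and are not part of fairseq; in a real model the memory would hold cached hidden states, detached so that no gradient flows into them.

```python
def chunk_with_memory(tokens, chunk_len, mem_len):
    """Yield (memory, chunk) pairs in order over a long token sequence."""
    for start in range(0, len(tokens), chunk_len):
        # The current chunk: the only part gradients would flow through.
        chunk = tokens[start:start + chunk_len]
        # Up to mem_len tokens immediately before the chunk. The model may
        # condition on these, but they are treated as constants.
        memory = tokens[max(0, start - mem_len):start]
        yield memory, chunk


pairs = list(chunk_with_memory(list(range(10)), chunk_len=4, mem_len=3))
for memory, chunk in pairs:
    print("memory:", memory, "chunk:", chunk)
```

The first chunk has no memory, and the last chunk may be shorter than `chunk_len`.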
It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since
we need to iterate over the data sequentially and disable any batch shuffling
logic. The code provided in this example illustrates how to implement Truncated
BPTT in fairseq by overriding FairseqTask::get_batch_iterator to iterate
over the data sequentially. Crucially, this example supports batching and
multi-GPU (data parallel) training.
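To make the "sequential, unshuffled" requirement concrete, here is a minimal sketch of one common layout: the corpus is cut into one contiguous stream per batch row, and every batch advances all streams by the same number of tokens. This is an assumption for illustration only, not the code in this example's task; `sequential_batches` is a hypothetical helper.

```python
def sequential_batches(tokens, batch_size, tokens_per_sample):
    """Yield batches whose row i always continues row i of the previous batch."""
    # Split the corpus into batch_size contiguous streams (remainder dropped).
    stream_len = len(tokens) // batch_size
    streams = [
        tokens[i * stream_len:(i + 1) * stream_len] for i in range(batch_size)
    ]
    # Step through all streams in lockstep, with no shuffling.
    for start in range(0, stream_len, tokens_per_sample):
        yield [s[start:start + tokens_per_sample] for s in streams]


batches = list(
    sequential_batches(list(range(12)), batch_size=2, tokens_per_sample=3)
)
for batch in batches:
    print(batch)
```

Because row order is fixed across batches, any memory cached for a row at one step belongs to the same stream at the next step. Shuffling batches would break that correspondence.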
0. Setup
First, see the general language modeling README for instructions on preprocessing the WikiText-103 data.
1. Train a Transformer-XL model on WikiText-103
We will train a 16-layer Transformer-XL model following the hyperparameters used in the original paper.
The following command assumes 4 GPUs, so that the total batch size is 60 sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
--user-dir examples/truncated_bptt \
data-bin/wikitext-103/ \
--task truncated_bptt_lm --tokens-per-sample 150 \
--batch-size 15 --max-update 200000 \
--arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
--d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
--optimizer adam --clip-norm 0.25 \
--lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \
--log-format json --log-interval 25 \
--fp16
If training on a single GPU, set --update-freq=4 to accumulate gradients over 4 batches before each update and simulate training on 4 GPUs.
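The batch-size arithmetic behind these two setups can be checked with a short sketch. The helper `effective_batch_size` is hypothetical and only multiplies the three quantities named in the text.

```python
def effective_batch_size(num_gpus, batch_size, update_freq=1):
    """Sequences contributing to each optimizer update."""
    return num_gpus * batch_size * update_freq


four_gpu = effective_batch_size(num_gpus=4, batch_size=15)
single_gpu = effective_batch_size(num_gpus=1, batch_size=15, update_freq=4)
print(four_gpu, single_gpu)
```

Both configurations give 60 sequences per update.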
2. Evaluate
fairseq-eval-lm data-bin/wikitext-103/ \
--path checkpoints/checkpoint_best.pt \
--user-dir examples/truncated_bptt/ \
--task truncated_bptt_lm \
--batch-size 1 --required-batch-size-multiple 1 \
--model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \
--tokens-per-sample 64
# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537
# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s)
# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70
# Compare to 24.0 test perplexity from the paper
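The perplexity in the log follows directly from the reported loss: since the loss is given in base 2, perplexity is 2 raised to that loss. A quick check using the logged value:

```python
loss_base2 = 4.5668  # "Loss (base 2)" from the log above

# Perplexity is the exponentiated average per-token loss, in the same base.
perplexity = 2 ** loss_base2
print(f"{perplexity:.2f}")
```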
Note: During training the model saw 150 tokens of context (--tokens-per-sample=150) and 150 extra memory tokens (--mem-len=150). During evaluation we measure perplexity on sequences of 64 tokens (--tokens-per-sample=64) and increase the memory length (--model-overrides='{"mem_len":640}'). These settings match the evaluation settings from the original paper.