Figure 1: $M$ previous hidden states are saved to memory for each layer and input along with the representation of the current segment. Here shown with segment length $L=4$ and memory length $M=5$.
Augmenting the architecture in this way also enables significant speedups when evaluating the model autoregressively. To see why, note that it is possible to feed only the most recently generated token back into the model, since the hidden states of previous words have already been computed and saved to memory. According to the authors, this makes Transformer-XL up to 1800+ times faster on some tasks than a vanilla Transformer.
One of the primary characteristics of Transformers is permutation equivariance: permuting the tokens input to the model simply permutes the outputs in the same way, but otherwise has no effect. For instance, if inputting $(x,y,z)$ to a Transformer outputs $(\hat{x}, \hat{y}, \hat{z})$, then inputting $(z, y, x)$ will output $(\hat{z}, \hat{y}, \hat{x})$. In practical terms, this means that the location of a word in a segment has no bearing on the model’s representation of that word.
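To make this concrete, here is a small check one can run with PyTorch’s built-in `nn.MultiheadAttention`, which on its own contains no positional information:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()

x = torch.randn(1, 3, 8)        # a "segment" of 3 token embeddings
perm = torch.tensor([2, 1, 0])  # reverse the order of the tokens

out, _ = attn(x, x, x)                                   # self-attention on (x, y, z)
out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])   # self-attention on (z, y, x)

# the outputs are identical up to the same permutation; should print True
print(torch.allclose(out[:, perm], out_perm, atol=1e-6))
```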
To resolve this, Vaswani et al. introduced positional encodings4. A positional encoding for a word at position $l$ in a segment, with embedding $x^{(l)} \in \R^d$, is a vector $p^{(l)} \in \R^d$ whose entries $p^{(l)}_1, \dots, p^{(l)}_d$ alternate between sine and cosine functions of the position $l$:
$$ p^{(l)}_i = \begin{cases} \sin\left(\frac{l}{10000^{2k/d}}\right) & \text{if } i = 2k \\ \cos\left(\frac{l}{10000^{2k/d}}\right) & \text{if } i = 2k+1 \end{cases} $$
The details of why such vectors can be used to represent positions are outside the scope of this post. The important thing to note is that the final embedding input to a Transformer is simply the vector sum $x^{(l)} + p^{(l)}$. This essentially acts as a “label” that lets the model know where in the segment each word is placed.
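As a minimal sketch (assuming an even $d$ and the usual base of $10000$ from Vaswani et al.), the encoding can be computed and added to the word embeddings like this:

```python
import torch

def positional_encoding(seq_len, d_model):
    """Sinusoidal encodings p^(l) for positions l = 0 .. seq_len - 1 (d_model assumed even)."""
    positions = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)                # (L, 1)
    freqs = 1.0 / 10000 ** (torch.arange(0, d_model, 2, dtype=torch.float) / d_model)
    p = torch.zeros(seq_len, d_model)
    p[:, 0::2] = torch.sin(positions * freqs)   # even entries
    p[:, 1::2] = torch.cos(positions * freqs)   # odd entries
    return p

x = torch.randn(4, 16)                 # embeddings for a segment of L=4 words, d=16
x = x + positional_encoding(4, 16)     # "label" each embedding with its position
```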
While this works extremely well in the original Transformer, it unfortunately poses a problem for the use of the hidden state memory outlined in the previous section. To see why, consider that the $l$’th word of segment $\textbf{s}_{k-1}$ and the $l$’th word of segment $\textbf{s}_k$ share the same positional encoding. This means that the hidden states representing information of previous segments are functions of positional encodings that continually reset, i.e. $1,\dots,L,1,\dots,L,1,\dots,L$. This makes it harder for the model to determine how words in the current segment relate to words in previous segments and thus presents an obstacle for the use of the hidden states.
To resolve this problem, Dai et al. propose to use relative positional encodings. These are identical to the sinusoid encodings proposed by Vaswani et al., but instead of letting $p^{(l)}$ represent the $l$’th position in a segment, it represents a distance of $l$ between two words.
To see how we might employ such encodings, recall that the pre-softmax attention score (omitting the $\frac{1}{\sqrt{d}}$ scaling factor) is $QK^\text{T}$. Remembering the query-key-value computation in $(1)$ and assuming regular positional encodings $P_Q \in \R^{L \times d}$ and $P_K \in \R^{(M+L) \times d}$ for the query and key vectors, respectively, we can write the attention score as
$$ (H W_Q + P_QW_Q)(\tilde{H} W_K + P_KW_K)^\text{T}, $$
where we have dropped the superscript $n$ and subscript $k$ for notational clarity. We can then further decompose the above expression as follows:
$$ \underbrace{HW_Q W_K^\text{T}\tilde{H}^\text{T}}_{(a)} + \underbrace{HW_Q W_K^\text{T}P^\text{T}_K}_{(b)} + \underbrace{P_QW_Q W_K^\text{T}\tilde{H}^\text{T}}_{(c)} + \underbrace{P_QW_Q W_K^\text{T}P^\text{T}_K}_{(d)}. $$
The authors now propose the following change to the above expression, with $R$ being a matrix of relative positional encodings. (Note that this is not the notation used by Dai et al. but it captures the same idea.)
$$ \textcolor{lightgray}{\underbrace{HW_Q W_K^\text{T}\tilde{H}^\text{T}}_{(a)} +} \underbrace{\text{shift}(HW_Q W_R^\text{T}R^\text{T})}_{(b)} + \underbrace{UW_K^\text{T}\tilde{H}^\text{T}}_{(c)} + \underbrace{\text{shift}(TW_R^\text{T}R^\text{T})}_{(d)}. \tag{2} $$
First, notice that the transformed positional encodings $P_QW_Q$ for the query vectors have been replaced in $(c)$ and $(d)$ by $U$ and $T$. Both $U$ and $T$ are matrices in $\R^{L \times d}$ but have repeated rows (i.e. the rows are made up of parameters $u,t \in \R^d$ and are thus the same for every query vector5). This creates two complementary effects on the attention mechanism: $(c)$ creates a global content bias that emphasizes certain words regardless of their location, while $(d)$ does the opposite by expressing a global positional bias that emphasizes certain locations regardless of the associated content.
Removing $P_QW_Q$ is the first step toward disregarding absolute positions. To see how the relative positional encodings work, first recall that the key vectors cover both the $L$ words of the current segment as well as the $M$ previous words. The matrix $R \in \R^{(M+L) \times d}$ contains sinusoid positional encodings with the order “flipped” such that the first row represents the biggest distance of $M+L-1$ while the last row represents the smallest distance of $0$.
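Under the same assumptions as the earlier sketch (sinusoids with base $10000$, even $d$), $R$ might be built as follows; note the reversed order of the distances. This helper is reused in the sketches further down.

```python
import torch

def relative_positional_encodings(mem_len, seg_len, d_model):
    """Sinusoidal encodings for distances M+L-1 (first row) down to 0 (last row)."""
    distances = torch.arange(mem_len + seg_len - 1, -1, -1, dtype=torch.float)  # flipped order
    freqs = 1.0 / 10000 ** (torch.arange(0, d_model, 2, dtype=torch.float) / d_model)
    r = torch.zeros(mem_len + seg_len, d_model)
    r[:, 0::2] = torch.sin(distances.unsqueeze(1) * freqs)
    r[:, 1::2] = torch.cos(distances.unsqueeze(1) * freqs)
    return r
```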
Figure 2: Combining $(b)$ and $(d)$ in a single matrix multiplication. Identical colors in the middle and right matrices do not represent equal values but only serve to more easily illustrate the idea of shifting.
For the first word in the current segment, the distance to the farthest word in memory is $M$, while for the last word it is $M+L-1$. To account for this, a “circulant” left-shift is applied to the rows of the matrices in $(b)$ and $(d)$. This ensures proper relative distances to previous words, while the entries corresponding to subsequent words are set to $0$6. This is illustrated in Figure 2 above. Also note that whereas positional encodings are only used prior to the first layer in a vanilla Transformer, Transformer-XL uses relative positional encodings in every layer.
We have now covered the two primary innovations of Transformer-XL: The hidden state memory and the relative positional encodings required to make the memory work. Virtually all other parts of the model are identical to what you would find in a vanilla Transformer. We can thus proceed to an implementation of the model, starting with the multi-head attention mechanism, which constitutes the bulk of the code.
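Below is a condensed sketch of what such a module might look like. It is not the exact code from the repo (names such as `w_r`, `h_mem` and `r`, and the defaults, are my own shorthand), but it contains the pieces the walkthrough below refers to: `self.u`, `self.t`, `ac`, `bd`, `circulant_shift` and the masking in `forward`.

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head attention with Transformer-XL-style relative positional encodings.
    Queries come from the current segment (length L); keys and values come from
    the memory concatenated with the current segment (length M + L)."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.scale = self.d_head ** -0.5
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_r = nn.Linear(d_model, d_model, bias=False)   # projects the R matrix
        self.w_out = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        # u and t: the repeated rows of U and T in (2), one per attention head
        self.u = nn.Parameter(torch.zeros(n_heads, self.d_head))
        self.t = nn.Parameter(torch.zeros(n_heads, self.d_head))

    def circulant_shift(self, x, shift):
        # Circularly shift row i of the last dimension by (shift + i) positions
        # (negative = left), i.e. out[..., i, j] = x[..., i, (j - shift - i) % J].
        # With shift = -seg_len + 1 this lines column j of row i up with the
        # relative distance (M + i) - j, as illustrated in Figure 2.
        L, J = x.shape[-2], x.shape[-1]
        i = torch.arange(L, device=x.device).unsqueeze(-1)
        j = torch.arange(J, device=x.device)
        idx = ((j - shift - i) % J).expand(x.shape)
        return torch.gather(x, -1, idx)

    def forward(self, h, h_mem, r, mask):
        # h: (B, L, d_model) current segment; h_mem: (B, M, d_model) memory;
        # r: (M + L, d_model) relative encodings (last row = distance 0);
        # mask: (L, M + L) boolean, True where attention is NOT allowed.
        B, seg_len, _ = h.shape
        h_tilde = torch.cat([h_mem, h], dim=1)
        q = self.w_q(h).view(B, seg_len, self.n_heads, self.d_head).transpose(1, 2)
        k = self.w_k(h_tilde).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_v(h_tilde).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        rel = self.w_r(r).view(-1, self.n_heads, self.d_head).permute(1, 2, 0)

        # (a) + (c): content scores, with the global content bias u added to every query
        ac = (q + self.u.unsqueeze(1)) @ k.transpose(-2, -1)
        # (b) + (d): positional scores, with the global positional bias t added to every query
        bd = (q + self.t.unsqueeze(1)) @ rel
        bd = self.circulant_shift(bd, -seg_len + 1)

        att = (ac + bd) * self.scale
        att = att.masked_fill(mask, float("-inf"))
        att = self.dropout(torch.softmax(att, dim=-1))
        out = (att @ v).transpose(1, 2).reshape(B, seg_len, -1)
        return self.w_out(out)
```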
Most of what you see above you would also find in an implementation of the multi-head attention module of a regular Transformer. The first lines that differ are the definitions of `self.u` and `self.t`. These are the $\R^d$ vectors comprising the $U$ and $T$ matrices in $(2)$.
The most important lines above are the following:
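(Shown here as they appear in the sketch above; the corresponding lines in the repo may be phrased slightly differently.)

```python
# (a) + (c): content scores, with the global content bias u added to every query
ac = (q + self.u.unsqueeze(1)) @ k.transpose(-2, -1)
# (b) + (d): positional scores, with the global positional bias t added to every query
bd = (q + self.t.unsqueeze(1)) @ rel
bd = self.circulant_shift(bd, -seg_len + 1)
```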
Here, `ac` takes care of computing the sum of parts $(a)$ and $(c)$ in $(2)$ using a single matrix multiplication. The sum of parts $(b)$ and $(d)$ is computed similarly and stored in `bd`, after which the resulting $\R^{L \times (M+L)}$ matrix is shifted using the `circulant_shift` method. This operation corresponds to the left-shift illustrated in the right-hand side of Figure 2. Note that we shift by `-seg_len + 1` or, equivalently, by $L-1$ to the left. This ensures that each word has a distance of $0$ to itself, $1$ to its closest neighbor in memory, and so on. Since the attention score is masked in the `forward` method, there is no need to zero out entries in `circulant_shift`.
Most of the remaining code should be self-explanatory for anyone familiar with Transformers. The only thing left to cover is the implementation of the hidden states. We manage these in the `TransformerXL` class and, more specifically, in the following methods:
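The listing below is again a sketch rather than the repo’s exact code (everything except `TransformerXL`, `dec` and `self.mem_len` is my own naming), reusing the `MultiHeadAttention` and `relative_positional_encodings` pieces sketched earlier. It shows the memory handling described next:

```python
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    """A small wrapper: multi-head attention followed by a position-wise feed-forward layer."""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)   # as sketched earlier
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(),
                                nn.Linear(4 * d_model, d_model))

    def forward(self, x, mem, r, mask):
        x = self.norm1(x + self.attn(x, mem, r, mask))
        return self.norm2(x + self.ff(x))


class TransformerXL(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, mem_len):
        super().__init__()
        self.mem_len = mem_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads) for _ in range(n_layers)])
        self.out = nn.Linear(d_model, vocab_size)
        self.memory = None   # one (B, <=mem_len, d_model) tensor per layer

    def reset_memory(self):
        self.memory = None

    def update_memory(self, layer_inputs):
        # Append each layer's new inputs to its memory, detached so that no gradients
        # flow into previous segments, and keep only the mem_len most recent states.
        if self.mem_len == 0:
            self.memory = None
            return
        if self.memory is None:
            self.memory = [x.detach()[:, -self.mem_len:] for x in layer_inputs]
        else:
            self.memory = [torch.cat([m, x.detach()], dim=1)[:, -self.mem_len:]
                           for m, x in zip(self.memory, layer_inputs)]

    def forward(self, tokens):
        x = self.embed(tokens)                                 # (B, L, d_model)
        batch, seg_len, d_model = x.shape
        mem_len = 0 if self.memory is None else self.memory[0].size(1)
        r = relative_positional_encodings(mem_len, seg_len, d_model).to(x.device)
        # causal mask: query i may only attend to key positions 0 .. mem_len + i
        mask = torch.ones(seg_len, mem_len + seg_len,
                          device=x.device).triu(mem_len + 1).bool()
        mems = self.memory or [x.new_zeros(batch, 0, d_model)] * len(self.layers)
        layer_inputs = []
        for dec, mem in zip(self.layers, mems):
            layer_inputs.append(x)       # this layer's input is what gets saved to memory
            x = dec(x, mem, r, mask)
        self.update_memory(layer_inputs)
        return self.out(x)               # next-token logits
```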
Here, `dec` refers to a decoder layer, which is a small wrapper around a multi-head attention layer followed by a position-wise feed-forward layer. As you can see, the inputs to each layer are continuously saved to memory, keeping only the $M$ (here represented by `self.mem_len`) most recent inputs.
The code for training the model can be found in `train.py` on GitHub. We train the model to sort unordered sequences of $n$ digits from 0-9. For example, using sequences of length 4, the sequence `[4,5,3,0]` will be input to the model as `[4,5,3,0,0,3,4]`. This will produce 7 next-digit predictions, of which we care about (and optimize with respect to) the last 4. Note that we use causal masking, such that each prediction cannot incorporate information about subsequent digits in the sequence. For a simple task like this, we can safely set $M=0$ during training.
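To make the setup concrete, here is one way such an example could be constructed. The `make_example` helper is hypothetical; `train.py` builds and batches its data in its own way.

```python
import torch

def make_example(seq):
    """Hypothetical helper: build (inputs, targets) for the sorting task.
    The model sees the unordered sequence followed by all but the last digit
    of the sorted sequence, and is optimized only on the last len(seq) predictions."""
    sorted_seq = sorted(seq)
    inputs = torch.tensor(seq + sorted_seq[:-1])   # e.g. [4, 5, 3, 0, 0, 3, 4]
    targets = torch.tensor(sorted_seq)             # e.g. [0, 3, 4, 5]
    return inputs, targets

x, y = make_example([4, 5, 3, 0])
# x = tensor([4, 5, 3, 0, 0, 3, 4]) produces 7 next-digit predictions;
# only the last 4 are compared against y = tensor([0, 3, 4, 5])
```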
Once the model has been trained using `train.py`, you can evaluate it using `eval.py`. In this case, we set $M=2n - 1$ (i.e. the sequence length and all predictions but the last) and start by inputting only the unordered sequence, e.g. `[1,8,8,2]`. The last prediction for this sequence will be the model’s prediction of the first digit in the sorted sequence, here `1` (hopefully). Now, instead of inputting `[1,8,8,2,1]`, we simply input `1`, since the representations of `[1,8,8,2]` have already been saved to memory and don’t need to be recomputed. We repeat this until the model has produced $n$ predictions. For good measure, we also evaluate the model with $M=0$ by continually inputting the whole unordered sequence along with the growing list of predicted digits. We then compare the performance with and without memory augmentation.
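A stripped-down version of that memory-based decoding loop, written against the hypothetical interface sketched earlier (`eval.py` is organized differently and evaluates whole batches), might look like this:

```python
import torch

@torch.no_grad()
def sort_digits(model, digits):
    """Greedily decode the sorted sequence, feeding back one token at a time
    and relying on the model's memory for everything seen so far.
    Assumes the model was built with mem_len >= 2 * len(digits) - 1."""
    model.eval()
    model.reset_memory()
    tokens = torch.tensor([digits])                  # (1, n): the unordered sequence
    preds = []
    for _ in range(len(digits)):
        logits = model(tokens)                       # (1, L, vocab); memory is updated inside
        next_digit = logits[0, -1].argmax().item()   # prediction at the last position
        preds.append(next_digit)
        tokens = torch.tensor([[next_digit]])        # feed back only the newest token
    return preds

# e.g. sort_digits(model, [1, 8, 8, 2]) should return [1, 2, 8, 8] once trained
```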
```
Evaluating without memory...
100%|██████████████████████████████████████████████████| 5/5 [00:45<00:00,  9.03s/it]
Achieved accuracy of 0.9732000231742859 in 45.156030893325806 seconds
Evaluating with memory...
100%|██████████████████████████████████████████████████| 5/5 [00:14<00:00,  2.97s/it]
Achieved accuracy of 0.9732000231742859 in 14.850011825561523 seconds
```
Evaluating the model locally on a CPU produces the output above and shows a performance increase of more than 3x when using the model’s memory. While significant, this improvement is, of course, far from the 1800x mentioned by the authors.
Table: evaluation speedups at different attention lengths. From Dai et al.1
The reason for this is likely that the performance increase is highly dependent on the “attention length” (i.e. $M+L$), with smaller lengths resulting in less significant speedups. This is shown in the table above, where Transformer-XL is compared to the work of Al-Rfou et al.7
The methods proposed in Transformer-XL are a simple but efficient way of widening the “receptive field” of a Transformer’s attention layers. If you would like to explore what the model is capable of when it is scaled up and trained on a large corpus of text, you can play around with the pretrained XLNet model available on HuggingFace. If you are more interested in exploring the nuts and bolts of the model architecture, I encourage you to check out the associated GitHub repo and give it a star if you find it helpful ⭐️
1. Dai, Zihang and Yang, Zhilin and Yang, Yiming and Carbonell, Jaime and Le, Quoc V. and Salakhutdinov, Ruslan (2019). Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context
2. Yang, Zhilin and Dai, Zihang and Yang, Yiming and Carbonell, Jaime and Salakhutdinov, Ruslan and Le, Quoc V. (2020). XLNet: Generalized Autoregressive Pretraining for Language Understanding
3. Hochreiter, Sepp and Schmidhuber, Jürgen (1997). Long Short-Term Memory
4. Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N. and Kaiser, Lukasz and Polosukhin, Illia (2017). Attention Is All You Need
5. I use the names $u$ and $t$ where Dai et al. use $u$ and $v$. Since I use matrix notation, using $v$ for the vector and $V$ for the matrix would create ambiguity with the matrix $V$ of value vectors.
6. Note that an entry of $0$ is different from the sinusoid encoding of a distance of $0$.
7. Al-Rfou, Rami and Choe, Dokook and Constant, Noah and Guo, Mandy and Jones, Llion (2018). Character-Level Language Modeling with Deeper Self-Attention