Understanding LongFormer’s Sliding Window Attention Mechanism

Ahmed Salem Elhady
Jan 19, 2021

A detailed explanation of the LongFormer’s Attention Mechanism

In this article, we explain LongFormer’s attention mechanism [Paper]. The model tries to overcome the limited input sequence length of classical transformer models, such as BERT-like models, by proposing a convolution-like architecture for the attention mechanism, which the authors call Sliding Window Attention. We discuss this mechanism in detail throughout the article. This article is a sequel to our investigation of long inputs with BERT.

Problems with Long Inputs in Classical Transformers

Classical transformers owe much of their superior performance to the self-attention block they incorporate. This architecture brought significant improvements across many tasks, most notably through the famous language model BERT, which became the baseline for many SoTA models built on top of it, such as RoBERTa, SpanBERT, and others. However, BERT-like Transformers are limited to a maximum input length of 512 tokens, largely because the cost of self-attention grows quadratically with sequence length. We first describe the problem, then explain how LongFormer overcomes this limitation.

Classical Transformers’ Attention Layers

In the self-attention layers of classical transformers, the input sequence serves as the keys (the sequence representations) and also as the queries that attend to these keys; hence the input sequence attends to itself. Figure 1 shows an example of a 5-token input sequence: red nodes are keys and blue nodes are queries. Letting every query attend to every key node (making the layer fully connected) requires O(n²) memory per attention layer. This quadratic complexity makes long input sequences expensive for the model to handle.

Figure 1. Attention Layer Representation

This type of attention layer is known as full attention or quadratic attention. A good way to visualize the layer's connectivity is as an n*n matrix, as in Figure 2. A green entry at position (i, j) means that the i-th token can attend to the j-th token. The memory requirement is proportional to the number of green entries, which is n*n = O(n²) in this case.

Figure 2. Full n² attention
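To make the n*n matrix concrete, here is a minimal single-head self-attention sketch in NumPy. It is a simplified illustration under our own assumptions (one head, no batching, random projection matrices), not BERT's actual implementation, and the names `full_attention`, `w_q`, `w_k`, `w_v` are ours. The returned weight matrix has one entry per (query, key) pair, which is exactly the all-green matrix of Figure 2.

```python
import numpy as np

def full_attention(x, w_q, w_k, w_v):
    """Single-head full (quadratic) self-attention.

    x: (n, d) token representations; w_q, w_k, w_v: (d, d) projection matrices.
    Returns the (n, d) output and the (n, n) attention-weight matrix.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(x.shape[1])         # (n, n): one score per (query, key) pair
    scores -= scores.max(axis=1, keepdims=True)     # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)   # softmax over the keys of each query
    return weights @ v, weights

rng = np.random.default_rng(0)
n, d = 5, 4                                         # 5 tokens, as in Figure 1
x = rng.normal(size=(n, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
out, weights = full_attention(x, w_q, w_k, w_v)
print(weights.shape)  # (5, 5): every entry of the n*n matrix is stored
```

Doubling n quadruples the number of stored weights, which is the quadratic cost described above.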

Long Inputs with BERT

To use BERT for long documents, we usually split the document into k overlapping segments, each of length n ≤ max_seq_length, run BERT over each segment, and combine the resulting representations. However, tokens in one segment cannot attend to tokens in another, so this solution sacrifices attention across segments and limits information sharing among them [see this post for details].
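The segmentation step can be sketched as follows. The helper `overlapping_segments` and its defaults (`max_len=512`, `stride=256`) are hypothetical choices for illustration; a real pipeline would also reserve room for special tokens such as [CLS] and [SEP], which this sketch ignores.

```python
def overlapping_segments(token_ids, max_len=512, stride=256):
    """Split a token sequence into overlapping segments of at most max_len tokens.

    Consecutive segments start `stride` tokens apart, so they overlap by
    max_len - stride tokens. The last segment always reaches the end.
    """
    if not 0 < stride <= max_len:
        raise ValueError("stride must be in (0, max_len]")
    segments = []
    start = 0
    while True:
        segments.append(token_ids[start:start + max_len])
        if start + max_len >= len(token_ids):
            break
        start += stride
    return segments

segments = overlapping_segments(list(range(1000)), max_len=512, stride=256)
print([(s[0], s[-1]) for s in segments])  # [(0, 511), (256, 767), (512, 999)]
```

Each segment is then encoded independently; how the per-segment outputs are combined (for example, by pooling) is a separate design choice.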

In some cases, your output may depend on long-distance attention between tokens in the document. Figure 3 shows a toy example in which the token “pharetra” depends on the token “Ipsum”. When such tokens end up in different segments, an ordinary BERT-like model cannot capture the dependency. LongFormer proposes a solution that extends attention from local to global across the entire document without huge memory requirements.

Figure 3. Dummy long text to visualize long dependencies

Sliding Window Attention

LongFormer's approach resembles convolution. Classical attention layers need quadratic memory because each query node attends to every key node, including its peer, so there are n attention weights per query node. LongFormer instead defines a window of width W, so that each query node may attend only to its peer key node and that key's immediate neighbors inside the window. Figure 4 shows an attention window of size 3: node Q, highlighted in green, may attend to its peer key (the one in the middle) and the immediate neighbor on each side (W/2 keys on each side, rounded down). The authors chose windows of direct neighbors on the assumption that the most important context for a word comes from its local neighbors.

Figure 4. Sliding Attention Window of Size 3.
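The band of allowed (query, key) pairs can be built as a boolean mask. This NumPy sketch assumes an odd window size W, so each query sees W//2 keys on each side; `sliding_window_mask` is a hypothetical helper name, not part of any LongFormer codebase.

```python
import numpy as np

def sliding_window_mask(n, w):
    """Boolean (n, n) mask: query i may attend to key j iff |i - j| <= w // 2.

    With an odd window size w, interior queries get exactly w keys;
    queries near the sequence edges get fewer.
    """
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w // 2

mask = sliding_window_mask(8, 3)
for row in mask:
    print("".join("#" if allowed else "." for allowed in row))
print(mask.sum(), "allowed pairs vs", 8 * 8, "for full attention")  # 22 allowed pairs vs 64
```

Printed with `#` for allowed pairs, the mask shows the diagonal band of Figure 6: the first row is `##......` and every interior row has exactly three allowed keys.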

LongFormer then slides this window over the query nodes, much as a convolutional filter slides over an image. This reduces memory to O(n*W), which is far better than O(n²) when W << n (a saving by a factor of roughly n/W).
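The O(n*W) saving comes from never materializing the n*n score matrix. The sketch below is an illustration under our own assumptions (NumPy, one head, odd W, edge positions masked with -inf): it stores scores in an (n, W) array and checks the result against full attention with everything outside the band masked out. `sliding_window_attention` is a hypothetical name; production implementations typically rely on specialized kernels or chunked computation rather than this gather-based approach.

```python
import numpy as np

def sliding_window_attention(q, k, v, w):
    """Single-head sliding-window attention storing scores in an (n, w) array.

    Row i holds scores for keys i - w//2 ... i + w//2 (odd w assumed);
    key positions outside the sequence are masked with -inf.
    """
    n, d = q.shape
    half = w // 2
    offsets = np.arange(-half, half + 1)                     # (w,) relative key offsets
    key_idx = np.arange(n)[:, None] + offsets[None, :]       # (n, w) absolute key positions
    valid = (key_idx >= 0) & (key_idx < n)
    safe_idx = np.clip(key_idx, 0, n - 1)                    # in-range indices for gathering
    k_win, v_win = k[safe_idx], v[safe_idx]                  # (n, w, d) each: linear in n
    scores = np.einsum("nd,nwd->nw", q, k_win) / np.sqrt(d)  # (n, w), never (n, n)
    scores = np.where(valid, scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)              # the diagonal is always valid
    weights = np.exp(scores)                                 # exp(-inf) = 0 for padded slots
    weights /= weights.sum(axis=1, keepdims=True)
    return np.einsum("nw,nwd->nd", weights, v_win), weights

rng = np.random.default_rng(0)
n, d, w = 16, 8, 3
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
out, weights = sliding_window_attention(q, k, v, w)
print(weights.shape)  # (16, 3): 48 stored weights instead of 16 * 16 = 256

# Reference: full (n, n) attention with everything outside the band masked out.
idx = np.arange(n)
band = np.abs(idx[:, None] - idx[None, :]) <= w // 2
full = np.where(band, q @ k.T / np.sqrt(d), -np.inf)
full = np.exp(full - full.max(axis=1, keepdims=True))
full /= full.sum(axis=1, keepdims=True)
print(np.allclose(out, full @ v))  # True
```

Both versions produce the same output; only the banded one avoids allocating n*n scores.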

But here is the catch: doesn't restricting attention to this window mean the query node loses information from key nodes outside it? Looking at a single attention layer, you might think so. But when multiple layers are stacked, the query node at higher layers eventually receives information from distant tokens, indirectly, through the intermediate representations of the layers below. This is exactly what stacked convolutional layers do with images! Figure 5 animates what happens over two consecutive layers with a window size of 3.

Figure 5. Animated sliding window attention for two consecutive layers

In a single attention layer, the query node (highlighted in green) attends only to its peer and its immediate neighbors. In the second layer, however, the query node also gains information from its second-nearest neighbors, because it attends to first-layer outputs that have already attended to them (marked by the orange paths). Each token's attention therefore forms a cone: at the very bottom it covers only nearby tokens, while at higher layers it draws on tokens far away from it (global attention).
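The cone can be checked numerically by composing the per-layer masks with boolean matrix products: token j reaches position i after two layers if i attends to some token that attends to j. In this sketch (our own helpers, odd W assumed), the number of input tokens visible from the middle position grows as L*(W-1)+1.

```python
import numpy as np

def sliding_window_mask(n, w):
    """Boolean (n, n) mask: query i may attend to key j iff |i - j| <= w // 2."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w // 2

def receptive_field(masks):
    """Compose per-layer masks (bottom layer first) into input-to-output reachability.

    reach[i, j] is True if input token j can influence position i at the top layer.
    """
    reach = np.eye(masks[0].shape[0], dtype=bool)
    for m in masks:
        # i reaches j through this layer if i attends to some k that already reaches j
        reach = (m.astype(int) @ reach.astype(int)) > 0
    return reach

n, w = 21, 3
for num_layers in (1, 2, 3, 4):
    reach = receptive_field([sliding_window_mask(n, w)] * num_layers)
    print(num_layers, reach[n // 2].sum())  # tokens visible from the middle position
```

With W = 3 it prints 3, 5, 7 and 9 visible tokens for 1 to 4 layers: each extra layer widens the cone by one token on each side.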

The memory complexity per layer is therefore O(n*W); Figure 6 visualizes this requirement for W=3. Green entries mark the attention weights that must be computed and stored in each layer.

Figure 6. Sliding Window Attention with a window size of 3

Dilated Sliding Window Attention

Note that for very long documents, many attention layers are needed to cover long-distance relations between tokens: each layer extends a token's reach by only about W/2 positions in each direction, so spanning a sequence of length n takes on the order of n/W layers. This raises the memory requirement of the entire attention block to O(n*W*L), where n is the input sequence length, W is the window size, and L is the number of attention layers in the block. To keep the memory savings of the sliding window while also preserving long-distance relationships, we want n*W*L << n². Since L ≈ n/W would bring n*W*L back to roughly n², the way to achieve this is to reduce L.
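A quick calculation shows why reducing L matters. It assumes plain sliding windows and treats full coverage as a receptive field of L*(W-1)+1 ≥ n tokens; the sequence length n = 4096 and the window sizes are arbitrary illustrative values, and `layers_for_full_span` is a hypothetical helper.

```python
import math

def layers_for_full_span(n, w):
    """Smallest L such that the receptive-field width L*(w-1)+1 reaches n (odd w).

    Assumes plain (non-dilated) sliding windows, one window size for every layer.
    """
    return math.ceil((n - 1) / (w - 1))

n = 4096
for w in (3, 65, 513):
    L = layers_for_full_span(n, w)
    print(f"W={w:>3}  L={L:>4}  n*W*L={n * w * L:>12,}  n^2={n * n:,}")
```

For every window size, n*W*L lands at or slightly above n² = 16,777,216: with plain windows, a wider window buys fewer layers but no total saving.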

The number of layers needed depends on how far each layer's window reaches: the wider the span covered per layer, the fewer layers are needed. The authors came up with the idea of dilated windows: instead of taking W consecutive neighbors, we take W neighbors separated by gaps (for example, every other token), as shown in Figure 7. The memory requirement per layer is still O(n*W), since we still store attention weights for only W elements per query, so there is no memory increase per layer. However, fewer layers are needed to cover larger spans of the sequence.

Figure 7. Dilated Sliding Window Attention Layer
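A dilated window can be expressed by adding a divisibility condition to the sliding-window mask. In this sketch the dilation `d` is the spacing between attended keys (d = 1 gives the ordinary window, d = 2 the every-other-token pattern); the function name and the odd-W assumption are ours.

```python
import numpy as np

def dilated_window_mask(n, w, d):
    """Boolean (n, n) dilated sliding-window mask (odd w).

    Query i attends to keys i + k*d for k in [-(w//2), w//2]: w keys with
    d - 1 skipped tokens between neighbors. d = 1 is the ordinary window.
    """
    idx = np.arange(n)
    diff = idx[None, :] - idx[:, None]          # key position minus query position
    return (np.abs(diff) <= (w // 2) * d) & (diff % d == 0)

mask = dilated_window_mask(9, 3, 2)
print("".join("#" if a else "." for a in mask[4]))  # ..#.#.#..
print(mask.sum(axis=1).max())                       # 3 keys per query at most
```

Each row still holds at most W = 3 allowed keys, so per-layer memory stays O(n*W), but those keys now span 5 positions instead of 3.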

Although this idea conserves memory well, dilated windows skip over local information in the shallow attention layers. This can hurt attention performance rather than help it, and may even cancel out the improvements. That is why the LongFormer authors, inspired by ConvNets, suggested combining normal and dilated sliding window layers: inside an attention block of L attention layers, the first m layers use a normal sliding window and the remaining L-m layers use a dilated sliding window. The intuition is that the shallow layers capture local information, and the top layers then move from local to global representations faster, reducing the total number of layers needed. They used this combination in their pre-trained model available on GitHub.
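Reusing the mask-composition idea, this sketch compares three 4-layer stacks: all sliding, the first m = 2 sliding and the rest dilated with d = 2, and all dilated. The sizes and helper names are illustrative assumptions, not the paper's configuration.

```python
import numpy as np

def window_mask(n, w, d=1):
    """Sliding-window mask with dilation d (d = 1: ordinary window). Odd w assumed."""
    idx = np.arange(n)
    diff = idx[None, :] - idx[:, None]
    return (np.abs(diff) <= (w // 2) * d) & (diff % d == 0)

def receptive_field(masks):
    """Compose per-layer masks (bottom layer first) into input-to-output reachability."""
    reach = np.eye(masks[0].shape[0], dtype=bool)
    for m in masks:
        reach = (m.astype(int) @ reach.astype(int)) > 0
    return reach

n, w, num_layers, m = 65, 3, 4, 2   # illustrative sizes only
stacks = {
    "all sliding": [window_mask(n, w)] * num_layers,
    "mixed (m=2)": [window_mask(n, w)] * m + [window_mask(n, w, d=2)] * (num_layers - m),
    "all dilated": [window_mask(n, w, d=2)] * num_layers,
}
mid = n // 2
for name, masks in stacks.items():
    reach = receptive_field(masks)[mid]
    print(f"{name}: {reach.sum()} tokens visible, immediate neighbor visible: {reach[mid + 1]}")
```

The mixed stack sees 13 tokens around the middle position with no gaps, while the all-dilated stack sees only 9, none of them at odd offsets, so it never receives information from the immediate neighbor. This is the loss of local information described above.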

Global+Sliding Window

The last type of attention layer the paper suggests is a bit tricky. Do not confuse the word “global” in its name with the global representations mentioned at the end of the Dilated Sliding Window section. This layer allows certain tokens to attend to all tokens at all times, and all tokens to attend to them, because, depending on the task, you may want a token to see the entire sequence.

For example, if you use BERT's [CLS] token to classify the document, you may want this specific token to attend to the entire sequence at all times to improve performance. The choice of such tokens (global tokens) is entirely an engineering decision, and it gives an extra degree of freedom for improving results. Figure 8 shows the memory matrix for an example layer; tokens whose rows and columns are entirely green are the ones selected to have global attention.

Figure 8. Global tokens + Sliding Window Custom Attention Layer
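Adding global tokens to the mask means setting their full rows and columns, matching the all-green rows and columns of Figure 8. In this sketch the [CLS] token is assumed to sit at position 0, and `global_sliding_mask` is a hypothetical helper. (For reference, the Hugging Face `transformers` implementation of LongFormer exposes a `global_attention_mask` input for marking such tokens.)

```python
import numpy as np

def global_sliding_mask(n, w, global_positions):
    """Sliding-window mask (odd w) plus symmetric global attention: each global
    token attends to every token, and every token attends to each global token."""
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w // 2
    g = np.asarray(global_positions, dtype=int)
    mask[g, :] = True    # global tokens attend to the whole sequence (full rows)
    mask[:, g] = True    # every token attends to the global tokens (full columns)
    return mask

mask = global_sliding_mask(8, 3, global_positions=[0])   # e.g. [CLS] at position 0
for row in mask:
    print("".join("#" if a else "." for a in row))
print(mask.sum(), "allowed pairs vs", 8 * 8, "for full attention")  # 34 allowed pairs vs 64
```

With g global tokens the number of allowed pairs grows as O(n*W + g*n), which stays linear in n as long as g is small.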

Written by Ahmed Salem Elhady

NLP - Applied Data Scientist || @ Microsoft. MSc and TA @Zewail City of Science, Technology, and Innovation.
