
Mechanisms of Attention & Transformers

Duke Kwon

Draft 1, 5/26/24

Draft 2, 6/8/24

Recommended Background:
Linear algebra, basic machine learning theory (loss functions, function approximation, generalization), and multilayer perceptrons/neural networks.

Key Takeaways:
Mathematical intuition of the attention mechanism and the transformer architecture, in an informal, interpretable way.

Quick Intro:
Generative AI and LLMs have become the next big craze, especially among people with a very limited understanding of how these models actually work – however, that’s another topic on general ignorance and lack of critical thinking in this world. From an ML perspective, the mathematical fundamentals remain the same, as all these models are built on the same mathematical intuition.

Not sure how attention and transformers are taught nowadays, but my mental image comes from the perspective of word embeddings. (Primarily influenced by one of my past professors, Dr. Mohammad Zaki, and my work as a research data scientist working on topic models.)

Preliminaries:
Perhaps the most important step in doing any sort of ML work is being able to mathematically model the problem effectively. If you’re unfamiliar with the NLP space, we need to somehow numerically represent words and sentences to actually use them as data. The simple approach would be to one-hot encode a vocabulary to use as the data itself, i.e., represent each word as a standard basis vector in a unique dimension. However, this obviously doesn’t encode a lot of the important ties and information between words.
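
To make the one-hot idea concrete, here’s a minimal sketch (the toy vocabulary and the NumPy representation are just my own illustration):

```python
import numpy as np

# Toy vocabulary: each word gets its own dimension (standard basis vector).
vocab = ["river", "bank", "basketball", "financial"]
word_to_index = {word: i for i, word in enumerate(vocab)}

def one_hot(word: str) -> np.ndarray:
    """Return the standard basis vector for a word in the toy vocabulary."""
    vec = np.zeros(len(vocab))
    vec[word_to_index[word]] = 1.0
    return vec

print(one_hot("bank"))                     # [0. 1. 0. 0.]
# Every pair of distinct words is orthogonal, so dot products carry no
# information about how related the words actually are.
print(one_hot("river") @ one_hot("bank"))  # 0.0
```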

This was the motivation for a popular word embedding algorithm from about 2013, Word2Vec. The algorithm tries to learn vector representations of an entire vocabulary, with "semantic similarity" embedded into the space via a distance or similarity measure between the vectors. I believe the approach uses an autoencoder-esque model, which can be thought of as a feed forward neural network doing some nonlinear compression via an encoder-to-decoder approach; either way, ultimately some feed forward network is trained and chopped at the end (the prediction/softmax layer). What remains defines a high dimensional distribution of word vectors, where numerically close vectors (by cosine similarity) are considered "semantically similar". This is effectively done by training the model to predict the word given the context, and the context given a word (the paper calls these two approaches the Continuous Bag of Words model, CBOW, and the Skip-Gram model, respectively).
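
As a quick illustration of the "numerically close means semantically similar" idea, here’s a sketch of cosine similarity on some made-up vectors (these are not actual Word2Vec outputs, just toy numbers):

```python
import numpy as np

def cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine of the angle between two vectors: ~1 = same direction, ~0 = orthogonal."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

# Hypothetical 4-dimensional embeddings (real Word2Vec vectors are ~100-300 dims).
king  = np.array([0.8, 0.6, 0.1, 0.0])
queen = np.array([0.7, 0.7, 0.2, 0.0])
apple = np.array([0.0, 0.1, 0.9, 0.8])

print(cosine_similarity(king, queen))  # high (~0.98): "semantically similar"
print(cosine_similarity(king, apple))  # low (~0.12): unrelated in this toy space
```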

Minutiae about Word2Vec: I don’t really like having blog posts linking to other blog posts unless they’re worthwhile topics, but if there’s interest, I can talk more about the details of Word2Vec.

Motivation:
Word2Vec is a static model in the sense that there is only one vector representation of each word, learned/trained off the context. Certainly we can keep retraining the model to learn new contexts and account for data shift, but there will still be scenarios where words have multiple meanings and are semantically similar/distant depending on the context. Consider the word bank, which can be used in a variety of contexts: the river bank, banking a basketball, a financial bank. The dimensionality of the vector space is likely too limited to learn all the possible word-to-word interactions, and thus if these varied contexts are equally distributed in the data, the model will have to learn a vector space where the bank vector is roughly equidistant from river, basketball, and financial (assuming these are contextually unrelated). This motivates learning some dynamic, contextual representation of word vectors, where we pay attention to the context.

Attention Mechanism:
Attention is just a new term coined for the well-used mathematical idea of averaging, applied in the context of word vectors (think statistical expectation and weighting by probabilities, or along the lines of AdaBoost if you’re familiar).

(I believe the initial idea was introduced in the RNNSearch paper, where they learned an encoder-decoder model with a bidirectional RNN (think of an RNN learning sequential dependencies not only from the right, but also from the left). They use an encoder-decoder architecture for a translation task, where basically the encoder encodes 1-hot vectors representing words into a shared space for the decoder, and the decoder then transforms the words into corresponding translated words. They used a feedforward network (nonlinear) to learn attention weights, depending on the current context vector (hidden state), the previous word guessed, and the bidirectional hidden states concatenated together (the previous hidden state, and the future hidden state from the reverse direction). The outputs, I believe, were then softmaxed to produce a valid probability distribution and used as the averaging weights for the next context vector.)

In the paper introducing transformers (Attention Is All You Need), the idea remains similar: weigh a vector representation by some similarity measure on the context. First, here’s a high level overview: I’ll talk about the input, the transformations, and the output of what we call a self-attention block.

All parameters are user-defined and not set in stone unless specified otherwise. The input to the self-attention is pretty standard for NLP: a block of token vectors \({\mathbf{x}}_i \in \mathbb{R}^{d_{model}}\) (the paper uses 512), with a block length of \(L\). A token is some split component of a word, produced by a preprocessing step called tokenization. Given that we have control of how we want to preprocess our data, tokenization takes words and often splits them into subwords with some base meaning. For example, we can take the word "prejudice" and tokenize it into "pre" and "judice". The block length is just a user-defined parameter designating the length of the context for a single token. For example, it could be the average or max number of subwords in a sentence in your corpus.

These inputs are projected (linearly) into 3 different vector spaces: Query, Key, and Value (the paper uses \(d_k = d_v = 64\)). So for each \({\mathbf{x}}_i\) in the block, \({\mathbf{W}}_q{\mathbf{x}}_i \mapsto {\mathbf{q}}_i, {\mathbf{W}}_k{\mathbf{x}}_i \mapsto {\mathbf{k}}_i, {\mathbf{W}}_v{\mathbf{x}}_i \mapsto {\mathbf{v}}_i\), where \({\mathbf{W}}_q, {\mathbf{W}}_k \in \mathbb{R}^{d_k \times d_{model}}, {\mathbf{W}}_v \in \mathbb{R}^{d_v \times d_{model}}\).
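
Here’s a rough sketch of these projections, with random matrices standing in for learned weights and the dimensions following the paper’s defaults (note the rows-as-tokens convention, so the weight shapes are transposed relative to the prose):

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_model, d_k, d_v = 8, 512, 64, 64    # block length and dimensions from the paper

X = rng.normal(size=(L, d_model))        # block of token vectors, one row per token

# Learned projection matrices (randomly initialized here).
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_v))

Q, K, V = X @ W_q, X @ W_k, X @ W_v      # queries, keys, values: one row per token
print(Q.shape, K.shape, V.shape)         # (8, 64) (8, 64) (8, 64)
```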

The idea behind this is similar to what Word2Vec proposes: to learn some latent space where some distance or numerical similarity measure corresponds to semantic similarity. The key and query vectors function as the token representatives in this semantic space, where they use the dot-product as a measure of similarity. The query \({\mathbf{q}}_i\) can be thought of as the main context word vector (think about a block being a sliding window surrounding a main context word in a sentence, and we do this for each word), whose similarity we want to measure against the key vectors corresponding to the context (block).

So we do \(L\) dot products for each \({\mathbf{q}}_i\), which corresponds to a matrix product \(QK^{\top}\), \(Q, K \in \mathbb{R}^{L\times d_k}\), where \({\mathbf{q}}_i\) corresponds to row \(i\) of matrix \(Q\), and column \(j\) of \(K^{\top}\) corresponds to \({\mathbf{k}}_j\). This results in an \(L \times L\) matrix, where a row corresponds to the current context word \({\mathbf{q}}_i\) similarity against each \({\mathbf{k}}_j\) across the columns. For numerical stability, they scale down the values by \(\sqrt{d_k}\), and then apply a row-wise softmax to normalize these similarity weights to a valid probability distribution. What we end up with is called the attention matrix: \[A = \mbox{softmax}_{row-wise}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)\]

If linear algebra isn’t your strong suit: basically, for each query vector \({\mathbf{q}}_i\), we evaluated how much the surrounding context words (represented as key vectors \({\mathbf{k}}_j\)) contextually contribute or have some semantic similarity to the main context word (represented by the query vector \({\mathbf{q}}_i\)), a total of \(L\) similarity measurements (including the measurement with its own key vector \({\mathbf{k}}_i\)). These pre-softmax attention values are then condensed to be within \([0, 1]\) via a row-wise softmax (row-wise corresponding to each query in the block), so that we have a fractional/probabilistic representation of the strength of each context.

For each row of attention values corresponding to a query, we then compute a weighted sum of the block’s value vectors \({\mathbf{v}}_j\), i.e., \({\mathbf{v}}_i' = \sum_{j=1}^{L}\alpha_{ij}{\mathbf{v}}_j\), weighted by the corresponding attention values \(\alpha_{ij}\). In matrix form, this gives us: \[V' = Attention(Q, K, V) = \mbox{softmax}_{row-wise}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V = AV\] where \(V \in \mathbb{R}^{L \times d_v}\) (so rows are value vectors \({\mathbf{v}}_j\)).
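
Putting the single-head computation together, here’s a minimal NumPy sketch of scaled dot-product attention (random inputs, shapes following the paper’s defaults):

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # (L, L) attention matrix, rows sum to 1
    return A @ V                                  # (L, d_v) contextualized value vectors

rng = np.random.default_rng(0)
L, d_k, d_v = 8, 64, 64
Q, K, V = rng.normal(size=(L, d_k)), rng.normal(size=(L, d_k)), rng.normal(size=(L, d_v))
print(attention(Q, K, V).shape)  # (8, 64)
```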

So each query vector has been "contextualized" by assigning more weight to key vectors in the context that it’s close to. The distinction from Word2Vec is apparent at this step: whereas Word2Vec generates a single space and does similarity measurements on the vectors themselves, attention does the similarity measure in separate key and query spaces, and then applies these similarity weights to the actual value vectors. We can expect a different contextualized value vector \({\mathbf{v}}_i'\) for the same word whenever the other tokens in the context block differ.

(In the original RNNSearch paper, they use what’s called "Additive Attention", where the attention values are computed from a learned feedforward neural network dependent on the bidirectional hidden states. Here we use "multiplicative" or dot-product attention).

Finally, the set of \(L\) contextualized value vectors is then projected back up to the dimensionality of \(d_{model}\), \(V'W_o \in \mathbb{R}^{L \times d_{model}}, W_o \in \mathbb{R}^{d_v \times d_{model}}\), which is effectively the output of the attention head in the single-head case.

(The reason it’s called *self* attention is that the query, key, and value vectors all come from the same context. There are instances of attention (cross-attention) where we pass key and value vectors from one domain (e.g., in translation, the encoder’s representations of English words, encoded from 1-hot vectors) to an attention module in another domain, whose query vectors correspond to the French words being generated.)

Given that we understand attention for a "single" head, Multi-Headed Self-Attention (MHSA) consists of multiple attention modules within itself (the idea is similar to how CNNs learn multiple masks/kernels over the same input). Words often have multiple important but nuanced patterns of context. With only one attention head, we might only capture some of the information about the context dependencies of a word, thus motivating multiple heads. With more attention heads, we can capture a variety of contexts that words can appear in. The model basically learns multiple (\(h\) heads) randomly initialized linear projection matrices for \(Q, K, \& V\) in parallel. \[\mbox{head}_i = \mbox{Attention}(QW_q^i, KW_k^i, VW_v^i)\] This corresponds to \(h\) different attention matrices for an input block.

How does this affect the dimensionality with respect to the outputs? To account for this, they set \(d_v = 64\), and \(h = 8\), so that they can concatenate the value vectors/attention heads corresponding to different contexts \(V_i'\), where \(d_v * h = d_{model}\). Concatenating preserves the differences in the learned attention heads (compared to averaging, where subtle yet important outlier information can be lost). The final projection matrix \(W_o\) now projects the final concatenated heads to \(d_{model}, W_o \in \mathbb{R}^{(h*d_v) \times d_{model}}\), similar to the single headed case.
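
And a rough self-contained sketch of MHSA under those defaults (random matrices in place of learned weights; not the paper’s actual implementation):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Single-head scaled dot-product attention."""
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1]), axis=-1) @ V

def multi_head_self_attention(X, W_q, W_k, W_v, W_o):
    """Project X into h (Q, K, V) triples, attend in each head, concatenate, project back."""
    heads = [attention(X @ Wq, X @ Wk, X @ Wv) for Wq, Wk, Wv in zip(W_q, W_k, W_v)]
    return np.concatenate(heads, axis=-1) @ W_o  # (L, h*d_v) @ (h*d_v, d_model) -> (L, d_model)

rng = np.random.default_rng(0)
L, d_model, d_k, d_v, h = 8, 512, 64, 64, 8
X = rng.normal(size=(L, d_model))

W_q = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_k = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_v = [rng.normal(size=(d_model, d_v)) for _ in range(h)]
W_o = rng.normal(size=(h * d_v, d_model))

print(multi_head_self_attention(X, W_q, W_k, W_v, W_o).shape)  # (8, 512)
```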

So while Word2Vec creates "static" semantic vectors learned off of context, the attention mechanism learns dynamic semantic vectors depending not only on the previously learned context, but also on the current block context. We can notice that a single head of attention does \(L^2\) dot-products which can be done in parallel efficiently, where the sequential order does not mathematically matter – this is an issue that will be addressed by the positional encoding of the broader transformer architecture.

Broader Transformer Architecture:
The MHSA is the key feature of the transformer, but there are a few other details to complete the architecture.

1.) Residual Connections:
Residual connections are an architectural design between (skip) layers in a neural network. They function as a simple way to improve connectivity/information flow across the model. Basically, data is passed directly to the output of the next layer, on top of whatever was passed into the current layer. For example, suppose our MHSA module produces nothing worthwhile, maybe in the initial steps of training. With a residual connection, the "residual" or remaining data is still passed to the next layer. This is usually implemented as summing the layer’s input with its output. Thus, we don’t directly lose anything if MHSA fails to contribute effectively. Residual connections also help with vanishing gradients in deeper neural networks. Backpropagation multiplies many matrices together via the chain rule to relate the loss to specific weights, which repeatedly amplifies the dominant eigenvalue \(\lambda_1\), driving gradients toward \(\infty\) (exploding gradients) or 0 (vanishing gradients). Residual connections provide a direct path for data (and gradients) across the network to alleviate some of this.

From a numerical linear algebra perspective, this is similar to the power method/dominant eigenvalue/eigenvector methods. Let’s consider a diagonalizable matrix \({\mathbf{A}}\in \mathbb{R}^{n\times n}\) with a dominant eigenvalue \(\lambda_1\). If we take some random vector \({\mathbf{x}}\in \mathbb{R}^n\), with probability 1, \({\mathbf{x}}\) can be represented as a combination of the eigenvectors of \({\mathbf{A}}\) with all coefficients non-zero, \({\mathbf{x}}= \sum_{i=1}^{n}c_i{\mathbf{v}}_i, c_i \neq 0, \forall i = 1, \dots, n\). Then \({\mathbf{A}}^k{\mathbf{x}}= \sum_{i=1}^{n}c_i{\mathbf{A}}^k{\mathbf{v}}_i = \sum_{i=1}^{n}c_i\lambda_i^k{\mathbf{v}}_i\), where the \(\lambda_1^k\) term dominates. Notice how if \(|\lambda_i| < 1\), \(\lambda_i^k \rightarrow 0\), and if \(|\lambda_i| > 1\), \(|\lambda_i^k| \rightarrow \infty\), as \(k \rightarrow \infty\). This will actually bring us to normalizing, which is next.
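
A tiny numerical illustration of this growth/decay, using a diagonal matrix so the eigenvalues are visible by inspection (purely a toy example of my own):

```python
import numpy as np

# Diagonal matrix: eigenvalues are 1.1 and 0.9, eigenvectors are the standard basis.
A = np.diag([1.1, 0.9])
x = np.array([1.0, 1.0])   # non-zero components along both eigenvectors

for k in [1, 10, 50, 100]:
    y = np.linalg.matrix_power(A, k) @ x
    print(k, y)

# The 1.1-component explodes (1.1**100 is roughly 1.4e4) while the 0.9-component
# vanishes (0.9**100 is roughly 2.7e-5), the same effect repeated matrix products
# have on gradients in a deep network without residual connections.
```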

2.) Layer Normalization:
Layer normalization is as it sounds – normalizing the neuron values in a layer to be mathematically nice, with mean \(\mu = 0\) and standard deviation \(\sigma = 1\). This helps keep the gradients and transformations across layers within the same scale. Given that we tend to use one learning rate for all the weight updates, keeping the layers numerically consistent tends to improve the stability of the training procedure, reaching the optimum much quicker. Note that the residual connection is first added to the output of each sublayer (MHSA or the feed forward network) token-wise before normalizing, i.e., \(\mbox{LayerNorm}({\mathbf{x}}+ \mbox{Sublayer}({\mathbf{x}}))\).
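
Here’s a sketch of the "add & norm" step for a single token vector (the learnable gain and bias of real layer normalization are omitted for simplicity):

```python
import numpy as np

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize a single token vector to mean 0 and (roughly) standard deviation 1."""
    return (x - x.mean()) / (x.std() + eps)

def add_and_norm(x: np.ndarray, sublayer_output: np.ndarray) -> np.ndarray:
    """Residual connection followed by layer normalization: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer_output)

rng = np.random.default_rng(0)
x = rng.normal(size=512)                      # input token vector
sub = rng.normal(size=512)                    # e.g., the MHSA output for this token
y = add_and_norm(x, sub)
print(round(y.mean(), 6), round(y.std(), 4))  # ~0.0, ~1.0
```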

3.) Feed Forward Network:
The transformer integrates a standard feed forward neural network after the output (residual connections & layer normalization) of the MHSA. This FFN is applied position-wise, i.e., the same two-layer network is applied independently to each token vector.
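
A sketch of the position-wise FFN (the paper uses an inner dimension of 2048 with a ReLU in between; random weights here):

```python
import numpy as np

def position_wise_ffn(X, W1, b1, W2, b2):
    """Apply the same 2-layer ReLU network independently to every token (row) of X."""
    return np.maximum(0, X @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
L, d_model, d_ff = 8, 512, 2048
X = rng.normal(size=(L, d_model))
W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)

print(position_wise_ffn(X, W1, b1, W2, b2).shape)  # (8, 512): same shape in, same shape out
```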

4.) Positional Encodings:
Recall that attention computes context weights via dot products, which aren’t positionally dependent and can be done in parallel. However, the positioning of words in a sentence can drastically affect its meaning – it’s an important piece of information that would be beneficial to learn. Positional encodings are introduced at the inputs of the transformer model to encode sequential or positional information numerically, by adding some trigonometric function values to the token embeddings.

One way to think of the positional encoding is from a discrete/binary perspective: think of adding, to each token, a vector whose entries spell out the token’s position in binary. The sinusoids are essentially a smooth, continuous analogue of those alternating bits.
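
Here’s a sketch of the sinusoidal encoding from the paper, \(PE_{(pos, 2i)} = \sin(pos/10000^{2i/d_{model}})\), \(PE_{(pos, 2i+1)} = \cos(pos/10000^{2i/d_{model}})\), added to the token embeddings:

```python
import numpy as np

def positional_encoding(L: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encodings: one d_model-dimensional vector per position."""
    pos = np.arange(L)[:, None]                       # (L, 1) positions
    i = np.arange(d_model // 2)[None, :]              # (1, d_model/2) dimension pairs
    angles = pos / np.power(10000, 2 * i / d_model)   # (L, d_model/2)
    pe = np.zeros((L, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions
    pe[:, 1::2] = np.cos(angles)   # odd dimensions
    return pe

L, d_model = 8, 512
X = np.random.default_rng(0).normal(size=(L, d_model))  # token embeddings
X = X + positional_encoding(L, d_model)                 # inject position information
print(X.shape)  # (8, 512)
```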

5.) Encoder-Decoder Architecture:
Finally, the encoder-decoder architecture joins two separate stacks of MHSA+FFN w/ residual connections and layer normalization together. The encoder is pretty much the "transformer" we described; however, the full "transformer" consists of this joint encoder-decoder architecture. The decoder is very similar to the encoder, with two differences. First, in the decoder’s second attention sublayer, the key and value vectors are passed from the outputs of the encoder, while the queries come from the decoder itself (so these attention heads are not fully "self"). Second, in the decoder’s own self-attention, the output tokens beyond the current "timestep" \(t\) are masked/hidden, to ensure the decoder only uses what it has predicted thus far, and the tokens it knows thus far. Just to make this whole idea more concrete, think of the encoder token block as an English sentence. We first contextualize it into a dynamic embedding space. In the decoder section, we want to take these embeddings and "translate", or rather pick a certain set of words in another language, e.g., French (whichever words have the highest probability in the predicted distribution), or, in the training step, learn to maximize the probability of the next word in the true French sentence.
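
To make the masking concrete, here’s a toy sketch of causal (look-ahead) masking in the decoder’s self-attention: positions after the current token get \(-\infty\) before the softmax, so they receive zero attention weight (again my own NumPy toy, not the paper’s code):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def masked_self_attention(Q, K, V):
    """Decoder self-attention: token t may only attend to tokens 0..t."""
    L, d_k = Q.shape[0], Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    mask = np.triu(np.ones((L, L), dtype=bool), k=1)  # True above the diagonal (future tokens)
    scores = np.where(mask, -np.inf, scores)          # -inf becomes zero weight after softmax
    # The resulting attention matrix is lower-triangular: each query position puts
    # zero weight on all future positions.
    return softmax(scores, axis=-1) @ V

rng = np.random.default_rng(0)
L, d = 5, 64
Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))
print(masked_self_attention(Q, K, V).shape)  # (5, 64)
```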