Draft 1, 5/26/24
Draft 2, 6/8/24
Recommended Background:
Linear algebra, basic machine learning theory (loss, function
approximation, generalization), multilayer perceptrons/neural
networks.
Key Takeaways:
Mathematical intuition of the attention mechanism and the transformer
architecture, in an informal, interpretable way.
Quick Intro:
Generative AI and LLMs have become the next big craze, especially among
people with a very limited understanding of how these models actually
work – but that's a separate rant about general ignorance and a lack of
critical thinking. From an ML perspective, the mathematical
fundamentals remain the same: all of these models are built on the same
core mathematical intuition.
Not sure how attention and transformers are taught nowadays, but my mental image comes from the perspective of word embeddings. (Primarily influenced by one of my past professors, Dr. Mohammad Zaki, and by my work as a research data scientist on topic models.)
Preliminaries:
Perhaps the most important step in any sort of ML work is being able
to mathematically model the problem effectively. If you’re unfamiliar
with the NLP space, we need to somehow numerically represent words and
sentences to actually use them as data. The simplest approach would be to
one-hot encode a vocabulary and use that as the data itself, i.e., represent
each word as a standard basis vector in its own unique dimension. However,
this obviously doesn’t encode a lot of the important ties and
information between words.
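If that feels abstract, here’s a minimal sketch (the vocabulary is made up purely for illustration):

```python
import numpy as np

# Toy vocabulary (made up for illustration); each word gets its own dimension.
vocab = ["the", "river", "bank", "basketball", "financial"]
word_to_index = {w: i for i, w in enumerate(vocab)}

def one_hot(word: str) -> np.ndarray:
    """Return the standard basis vector e_i for the word's index i."""
    vec = np.zeros(len(vocab))
    vec[word_to_index[word]] = 1.0
    return vec

print(one_hot("bank"))                          # [0. 0. 1. 0. 0.]
# Every pair of distinct words is equally (dis)similar under this encoding:
print(one_hot("river") @ one_hot("financial"))  # 0.0
```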
This was the motivation for a popular word embedding algorithm from around 2013, Word2Vec. The algorithm tries to learn vector representations of an entire vocabulary, with "semantic similarity" embedded into the space via a distance or similarity measure between the vectors. I believe the approach used an autoencoder-esque model, which can be thought of as a feed-forward neural network doing some nonlinear compression via an encoder-to-decoder approach, but even if it didn’t, ultimately some feed-forward network is chopped at the end (the prediction/softmax layer). The layer left at the end defines a high-dimensional distribution of word vectors, where numerically close vectors (by cosine similarity) are considered "semantically similar". This is effectively done by training the model to predict a word given its context, and the context given a word (the paper calls these two approaches the Continuous Bag of Words model, CBOW, and the Skip-Gram model, respectively).
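To make the "numerically close" part concrete, here’s a small sketch with made-up embedding vectors (not actual Word2Vec outputs):

```python
import numpy as np

def cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
    """cos(theta) = <u, v> / (||u|| * ||v||), in [-1, 1]."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

# Hypothetical "learned" embeddings; the numbers are invented for illustration.
embeddings = {
    "king":  np.array([0.8, 0.1, 0.7]),
    "queen": np.array([0.7, 0.2, 0.8]),
    "apple": np.array([-0.5, 0.9, 0.1]),
}

print(cosine_similarity(embeddings["king"], embeddings["queen"]))  # close to 1
print(cosine_similarity(embeddings["king"], embeddings["apple"]))  # much lower
```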
Minutiae about Word2Vec: I don’t really like having blog posts linking to other blog posts unless they’re worthwhile topics, but if there’s interest, I can talk more about the details of Word2Vec.
Motivation:
Word2Vec is a static model in the sense that there is only one vector
representation of each word, learned/trained off its contexts. Certainly
we can keep retraining the model on new contexts to account for
data shift, but there will still be scenarios where a word has
multiple meanings and is semantically similar/distant depending on the
context. Consider the word bank, which can be used in a variety
of contexts: the river bank, banking a basketball, a financial
bank. The dimensionality of the vector space is likely too
limited to capture all the possible word-to-word interactions, so if
these varied contexts are equally distributed in the data, the model
will have to learn a vector space where the bank vector sits roughly
equidistant from river, basketball, and financial (assuming these
are contextually unrelated to one another). This motivates learning some
dynamic, contextual representation of word vectors, where we
pay attention to the context.
Attention Mechanism:
Attention is essentially a new name for the well-used mathematical idea of
averaging, applied to word vectors (think statistical expectation
and weighting by probabilities, or something along the lines of AdaBoost's
weighted combinations if you’re familiar).
(I believe the initial idea was introduced in the RNNSearch paper, where they learned an encoder-decoder model with a bidirectional RNN (think of an RNN learning sequential dependencies not only from the left, but also from the right). They use an encoder-decoder architecture for a translation task: the encoder maps 1-hot vectors representing words into a shared space for the decoder, and the decoder then transforms those representations into the corresponding translated words. They used a (nonlinear) feed-forward network to learn attention weights, depending on the current context vector (hidden state), the previous word guessed, and the bidirectional hidden states concatenated together (the previous hidden state, and the future hidden state from the reverse direction). I believe the outputs were then softmaxed to produce a valid probability distribution, to be used as the averaging weights for the next context vector.)
In the paper introducing transformers (Attention Is All You Need), the idea remains similar: weigh a vector representation by some similarity measure on the context. First, here’s a high level overview: I’ll talk about the input, the transformations, and the output of what we call a self-attention block.
All parameters are user defined and not set in stone unless specified otherwise. The input to self-attention is pretty standard for NLP: a block of token vectors \({\mathbf{x}}_i \in \mathbb{R}^{d_{model}}\) (the paper uses \(d_{model} = 512\)), with a block length of \(L\). A token is some split component of a word produced by a preprocessing step, tokenization. Given that we have control over how we want to preprocess our data, tokenization takes words and often splits them into subwords with some base meaning. For example, we can take the word "prejudice" and tokenize it into "pre" and "judice". A block is just a user-defined span of \(L\) tokens designating the context for a single token; for example, \(L\) could be the average or max number of subwords in a sentence in your corpus.
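Here’s a toy sketch of the subword idea, using a greedy longest-match over a made-up vocabulary; real tokenizers (BPE, WordPiece, etc.) learn their subword vocabularies from data, so this is just for intuition:

```python
# A toy greedy longest-match subword tokenizer over a made-up vocabulary.
SUBWORD_VOCAB = {"pre", "judice", "jud", "ice", "un", "break", "able"}

def tokenize(word: str) -> list[str]:
    tokens, start = [], 0
    while start < len(word):
        # Take the longest vocabulary entry that matches at this position.
        for end in range(len(word), start, -1):
            piece = word[start:end]
            if piece in SUBWORD_VOCAB:
                tokens.append(piece)
                start = end
                break
        else:
            # Unknown character: emit it as its own token.
            tokens.append(word[start])
            start += 1
    return tokens

print(tokenize("prejudice"))    # ['pre', 'judice']
print(tokenize("unbreakable"))  # ['un', 'break', 'able']
```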
These inputs are projected (linearly) into 3 different vector spaces: Query, Key, and Value (the paper uses \(d_k = d_v = 64\)). So for each \({\mathbf{x}}_i\) in the block, \({\mathbf{W}}_q{\mathbf{x}}_i \mapsto {\mathbf{q}}_i, {\mathbf{W}}_k{\mathbf{x}}_i \mapsto {\mathbf{k}}_i, {\mathbf{W}}_v{\mathbf{x}}_i \mapsto {\mathbf{v}}_i\), where \({\mathbf{W}}_q, {\mathbf{W}}_k \in \mathbb{R}^{d_k \times d_{model}}, {\mathbf{W}}_v \in \mathbb{R}^{d_v \times d_{model}}\).
The idea behind this is similar to what Word2Vec proposes: to learn some latent space where a distance or numerical similarity measure corresponds to semantic similarity. The key and query vectors function as the token representatives in this semantic space, where the dot-product is used as the measure of similarity. The query \({\mathbf{q}}_i\) can be thought of as the main context word vector (think of a block as a sliding window surrounding a main context word in a sentence, and we do this for each word), whose similarity we want to measure against the key vectors corresponding to the context (block).
So we do \(L\) dot products for each \({\mathbf{q}}_i\), which corresponds to a matrix product \(QK^{\top}\), with \(Q, K \in \mathbb{R}^{L\times d_k}\), where \({\mathbf{q}}_i\) is row \(i\) of \(Q\) and \({\mathbf{k}}_j\) is row \(j\) of \(K\) (so column \(j\) of \(K^{\top}\)). This results in an \(L \times L\) matrix, where row \(i\) holds the similarities of the current context word \({\mathbf{q}}_i\) against each \({\mathbf{k}}_j\) across the columns. For numerical stability, they scale down the values by \(\sqrt{d_k}\), and then apply a row-wise softmax to normalize these similarity weights into a valid probability distribution. What we end up with is called the attention matrix: \[A = \mbox{softmax}_{row-wise}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)\]
If linear algebra isn’t your strong suit: for each query vector \({\mathbf{q}}_i\), we evaluated how much the surrounding context words (represented as key vectors \({\mathbf{k}}_j\)) contextually contribute or have some semantic similarity to the main context word (represented by the query vector \({\mathbf{q}}_i\)), a total of \(L\) similarity measurements (including the measurement with its own key vector \({\mathbf{k}}_i\)). These pre-softmax attention values are then condensed to lie within \([0, 1]\) via the row-wise softmax (row-wise corresponding to each query in the block), so that we have a fractional/probabilistic representation of the strength of each context.
For each row of attention values corresponding to a query, we then compute a weighted sum of the block’s value vectors \({\mathbf{v}}_j\), i.e., \({\mathbf{v}}_i' = \sum_{j=1}^{L}\alpha_{ij}{\mathbf{v}}_j\), weighted by the corresponding attention values \(\alpha_{ij}\) (entry \((i,j)\) of \(A\)). In matrix form, this gives us: \[V' = Attention(Q, K, V) = \mbox{softmax}_{row-wise}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V = AV\] where \(V \in \mathbb{R}^{L \times d_v}\) (so the rows are the value vectors \({\mathbf{v}}_j\)).
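If it helps, here’s a minimal numpy sketch of a single self-attention head, with randomly initialized matrices standing in for learned weights (row-vector convention, so the projections are written \(XW\) rather than \({\mathbf{W}}{\mathbf{x}}_i\)):

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_model, d_k, d_v = 4, 512, 64, 64       # block length and the paper's sizes

X = rng.normal(size=(L, d_model))           # one block of token vectors (rows)

# Randomly initialized projections stand in for learned weights here.
W_q = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
W_k = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
W_v = rng.normal(size=(d_model, d_v)) / np.sqrt(d_model)

Q, K, V = X @ W_q, X @ W_k, X @ W_v         # row i is q_i, k_i, v_i

scores = Q @ K.T / np.sqrt(d_k)             # L x L scaled similarities
A = np.exp(scores - scores.max(axis=1, keepdims=True))
A = A / A.sum(axis=1, keepdims=True)        # row-wise softmax: attention matrix

V_prime = A @ V                             # row i is v_i' = sum_j a_ij * v_j
print(A.shape, V_prime.shape)               # (4, 4) (4, 64)
```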
So each query vector has been "contextualized" by assigning more weight to the key vectors in the context that it’s close to. The distinction from Word2Vec is apparent at this step: whereas Word2Vec generates a single space and measures similarity on the vectors themselves, attention does the similarity measure in separate key and query spaces, and then uses those similarity measures to weight the actual value vectors. We can expect a different contextualized value vector \({\mathbf{v}}_i'\) for the same word whenever the tokens in its context block differ.
(In the original RNNSearch paper, they use what’s called "Additive Attention", where the attention values are computed from a learned feedforward neural network dependent on the bidirectional hidden states. Here we use "multiplicative" or dot-product attention).
Finally, the set of \(L\) value vectors is then projected back up to the dimensionality of \(d_{model}\): \(V'W_o \in \mathbb{R}^{L \times d_{model}}\), \(W_o \in \mathbb{R}^{d_v \times d_{model}}\), which is effectively the output of the attention head in the single-head case.
(The reason it’s called *self* attention is that the query, key, and value vectors all come from the same context. There are instances of attention (cross-attention) where the key and value vectors come from another domain (e.g., in translation, encoder representations of English words, encoded from 1-hot English vectors), while the queries come from the attention module in the other domain (e.g., the French side of a translation model).)
Given that we understand attention for a "single" head, Multi-Headed Self-Attention (MHSA) consists of multiple attention modules run in parallel (the idea is similar to how CNNs use multiple filters/kernels to learn different features). Words often have multiple important but nuanced patterns of context. With only one attention head, we might only capture some of the information about a word's context dependencies, which motivates multiple heads: with more attention heads, we can capture a variety of the contexts that words can appear in. The model basically learns multiple (\(h\) heads) randomly initialized linear projection matrices for \(Q, K, \& V\) in parallel: \[\mbox{head}_i = \mbox{Attention}(QW_q^i, KW_k^i, VW_v^i)\] This corresponds to \(h\) different attention matrices for an input block.
How does this affect the dimensionality with respect to the outputs? To account for this, they set \(d_v = 64\), and \(h = 8\), so that they can concatenate the value vectors/attention heads corresponding to different contexts \(V_i'\), where \(d_v * h = d_{model}\). Concatenating preserves the differences in the learned attention heads (compared to averaging, where subtle yet important outlier information can be lost). The final projection matrix \(W_o\) now projects the final concatenated heads to \(d_{model}, W_o \in \mathbb{R}^{(h*d_v) \times d_{model}}\), similar to the single headed case.
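Continuing the numpy sketch from the single-head case above, multi-head attention just repeats the same computation \(h\) times with separate projection matrices, concatenates the head outputs, and projects with \(W_o\):

```python
# Continuing the sketch above: h heads in parallel, then concatenate and project.
h = 8
heads = []
for _ in range(h):
    W_q_i = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
    W_k_i = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)
    W_v_i = rng.normal(size=(d_model, d_v)) / np.sqrt(d_model)
    Q, K, V = X @ W_q_i, X @ W_k_i, X @ W_v_i
    scores = Q @ K.T / np.sqrt(d_k)
    A = np.exp(scores - scores.max(axis=1, keepdims=True))
    A = A / A.sum(axis=1, keepdims=True)
    heads.append(A @ V)                      # each head output is L x d_v

concat = np.concatenate(heads, axis=1)       # L x (h * d_v) = L x 512
W_o = rng.normal(size=(h * d_v, d_model)) / np.sqrt(h * d_v)
output = concat @ W_o                        # projected back to L x d_model
print(output.shape)                          # (4, 512)
```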
So while Word2Vec creates "static" semantic vectors learned off of context, the attention mechanism produces dynamic semantic vectors that depend not only on the previously learned weights, but also on the current block context. Notice that a single head of attention does \(L^2\) dot-products, which can be done efficiently in parallel, and the sequential order does not mathematically matter – this is an issue that will be addressed by the positional encoding of the broader transformer architecture.
Broader Transformer Architecture:
The MHSA is the key feature of the transformer, but there are a few
other details to complete the architecture.
1.) Residual Connections:
Residual connections are an architectural design between (skip) layers in
a neural network. They function as a simple way to improve
connectivity/information flow across the model. Basically, the input is
passed directly to the output of the next layer, on top of whatever that
layer computes. For example, suppose our MHSA module
produces nothing worthwhile, maybe in the initial steps of training.
With a residual connection, the "residual" or remaining data is still
passed to the next layer. This is usually a sum of the input with the
output vectors of the layer. Thus, we don’t directly lose anything
if the MHSA fails to contribute effectively. Residual connections also help
with vanishing gradients in deeper networks.
Backpropagation takes repeated products of (Jacobian) matrices to relate the
loss to specific weights via the chain rule, which amplifies the
dominant eigenvalue \(\lambda_1\), driving gradients
either to \(\infty\) (exploding
gradients) or to 0 (vanishing gradients). Residual connections provide an
additional path for data to move across the network, which alleviates some of this.
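As a minimal sketch of the idea (any sublayer, not specifically MHSA):

```python
import numpy as np

def residual(x: np.ndarray, sublayer) -> np.ndarray:
    """Skip connection: the input is added back onto the sublayer's output."""
    return x + sublayer(x)

# Even if the sublayer contributes nothing useful (say it outputs ~0 early in
# training), the input still flows through to the next layer unchanged.
x = np.ones(4)
print(residual(x, lambda t: np.zeros_like(t)))   # [1. 1. 1. 1.]
```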
From a numerical linear algebra perspective, this is similar to the power method/dominant eigenvalue/eigenvector methods. Consider a matrix \({\mathbf{A}}\in \mathbb{R}^{n\times n}\) with a dominant eigenvalue \(\lambda_1\). If we take some random vector \({\mathbf{x}}\in \mathbb{R}^n\), then with probability 1, \({\mathbf{x}}\) can be represented as a combination of the eigenvectors of \({\mathbf{A}}\) with non-zero coefficients, \({\mathbf{x}}= \sum_{i=1}^{n}c_i{\mathbf{v}}_i, c_i \neq 0, \forall i = 1, \dots, n\). Then \({\mathbf{A}}^k{\mathbf{x}}= \sum_{i=1}^{n}c_i{\mathbf{A}}^k{\mathbf{v}}_i = \sum_{i=1}^{n}c_i\lambda_i^k{\mathbf{v}}_i\), where the \(\lambda_1^k\) term dominates. Notice that if \(|\lambda_i| < 1\), \(\lambda_i^k \rightarrow 0\), and if \(|\lambda_i| > 1\), \(\lambda_i^k \rightarrow \infty\), as \(k \rightarrow \infty\). This actually brings us to normalizing, which is next.
2.) Layer Normalization:
Layer normalization is as it sounds – normalizing the neuron values in a
layer to be mathematically nice, with mean \(\mu = 0\) and standard deviation \(\sigma = 1\). This helps keep the
gradients and transformations across layers within the same
scale. Given that we tend to use one learning rate for all the weight
updates, keeping the layers numerically consistent tends to improve the
stability of the training procedure, reaching the optimum much more
quickly. Note that the residual is first added to the sublayer's output
token-wise (element-wise) before normalizing.
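A small sketch of this "Add & Norm" step (omitting the learned gain/bias parameters that full layer normalization also has):

```python
import numpy as np

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize each token vector (row) to mean 0 and standard deviation 1.
    (The learned gain/bias parameters of full layer normalization are omitted.)"""
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)

def add_and_norm(x: np.ndarray, sublayer_output: np.ndarray) -> np.ndarray:
    """The transformer's "Add & Norm" step: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer_output)
```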
3.) Feed Forward Network:
The transformer applies a standard feed-forward neural network to the
output (after layer normalization & residual connections) of the
MHSA. This FFN is applied token-wise (position-wise), i.e., the same
network acts on each token vector independently.
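Per the paper, this position-wise FFN is just two linear layers with a ReLU in between (inner dimension 2048); a quick numpy sketch with random weights standing in for learned ones:

```python
import numpy as np

d_model, d_ff = 512, 2048                    # the paper's inner dimension is 2048
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(d_model, d_ff)) / np.sqrt(d_model), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)) / np.sqrt(d_ff), np.zeros(d_model)

def ffn(x: np.ndarray) -> np.ndarray:
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to each token (row) alike."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

tokens = rng.normal(size=(4, d_model))       # a block of 4 contextualized tokens
print(ffn(tokens).shape)                     # (4, 512)
```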
4.) Positional Encodings:
Recall that attention computes the context weights via dot products, which
are not positionally dependent and can be done in parallel. However, the
positioning of words in a sentence can drastically affect its meaning – it’s an
important piece of data that would be beneficial to learn. Positional
encodings are introduced at the inputs of the transformer model to
encode sequential or positional information numerically, by adding some
trigonometric function values to the token embeddings.
One way to think of the positional encoding is from a discrete/binary perspective. Think of adding a vector represented in binary for each token, to represent its position.
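For reference, the paper's sinusoidal encodings look like this (a small numpy sketch):

```python
import numpy as np

def positional_encoding(L: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encodings from the paper:
    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))"""
    pos = np.arange(L)[:, None]              # (L, 1)
    i = np.arange(0, d_model, 2)[None, :]    # (1, d_model / 2)
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((L, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

# These are added element-wise to the token embeddings at the model's input.
print(positional_encoding(4, 512).shape)     # (4, 512)
```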
5.) Encoder-Decoder Architecture:
Finally, the encoder-decoder architecture joins two separate components
of MHSA+FFN w/ residual connections and layer normalization together.
The encoder is pretty much the "transformer" we described;
however, the true "transformer" consists of this joint encoder-decoder
architecture. The decoder is very similar to the encoder, with two
differences. First, in the decoder's encoder-decoder attention, the key
and value vectors are passed from the outputs of the encoder, while the
queries come from the decoder itself (so these attention heads are not
fully "self"). Second, in the decoder's self-attention, the output tokens
after position \(t\) are masked/hidden at each "timestep", to ensure the
decoder only uses what it has predicted thus far, and the tokens it knows
thus far. Just to make this whole idea more concrete, think of the encoder
token block as an English sentence. We first contextualize the tokens into a
dynamic embedding space. In the decoder section, we want to take these
embeddings and "translate", or rather pick a certain set of words in
another domain language, e.g., French (choosing whichever word has the
highest probability in this word distribution), or in the training step,
learn to maximize the probability of the next word in the true label
sentence in French.
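To make the decoder's masking concrete, here's a minimal numpy sketch of a causal mask applied to the pre-softmax attention scores (toy values only):

```python
import numpy as np

L = 4
scores = np.random.default_rng(0).normal(size=(L, L))   # pre-softmax scores

# Causal mask: token t may only attend to positions <= t, so future positions
# get -inf before the softmax (their attention weight becomes exactly 0).
mask = np.triu(np.ones((L, L), dtype=bool), k=1)
scores[mask] = -np.inf

A = np.exp(scores - scores.max(axis=1, keepdims=True))
A = A / A.sum(axis=1, keepdims=True)
print(np.round(A, 2))   # everything above the diagonal is 0
```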