Notes on Transformer Architecture

Mar 20, 2024

Model Architecture

  • The Transformer model consists of an encoding component and a decoding component
  • The encoding component is a stack of six encoders
  • The decoding component is a stack of six decoders
  • All encoders have the identical structure
  • Encoders do NOT share weights with each other
  • Each encoder has two layers within it: self attention layer --> feed forward network
  • The self attention layer helps the encoder focus on the other words while encoding the current word
  • Each decoder has three layers: self attention --> encoder-decoder attention --> feed forward network
  • Encoder-decoder attention helps the decoder focus on the relevant parts of the input sequence while producing the next word ![[transformer.png]]

High-level encoder

  • Each word is embedded into a $512$ dimensional vector
  • Embedding only happens at the bottom most encoder
  • Each encoder receives a list of vectors of size $512$
  • For the bottom most encoder, we give a list of embedding vectors. This list represents a sentence with its words represented as vectors.
  • Size of the list depends on the longest sentence in the training corpus.
  • In the encoder:
    • the self attention layer produces the same number of output vectors $(z_1, z_2, ..., z_n)$ as the number of vectors $(x_1, x_2, ..., x_n)$ we feed to it
    • this output is then fed through the feed forward network to produce $(r_1, r_2, ..., r_n)$
  • Each word flows through its own path in the encoder
    • in the self attention layer, these paths are dependent on each other, thus words cannot be processed in parallel
    • but in the feed forward network these paths are independent and parallel processing of the words is possible

Encoder Internals

  • In an RNN the hidden state vector is passed along the input sequence to store the meaning of the input sequence
  • At time step $t$,
    • the hidden state contains the meaning of words that occurred at previous time steps $[1, t-1]$
    • at time step $t$, the model incorporates the meaning of the current word into the hidden state
    • this hidden state is supposed to learn the inter-word dependencies in the input sequence
    • but it does not perform that well
  • self attention is used to incorporate the understanding of different words at different locations into the encoding of the current word

Self attention in detail

First Step

  • create three vectors from each of the embedding vectors
    1. key vector ($k$)
    2. query vector ($q$)
    3. value vector ($v$)
  • these vectors are created by multiplying the embedding vector by three matrices that we train during the training process
$$\begin{align} k_i &= e_i \times W^K \\ q_i &= e_i \times W^Q \\ v_i &= e_i \times W^V \end{align}$$
  • $e_i$ is the embedding vector of $512$ dimensions
  • $W^K$, $W^Q$ and $W^V$ have shape $512 \times 64$
  • this means $k_i$, $q_i$ and $v_i$ have $64$ dimensions
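
A minimal NumPy sketch of this first step, assuming an untrained (randomly initialized) set of weight matrices and the $512$/$64$ shapes above:

```python
import numpy as np

d_model, d_k = 512, 64

# stand-ins for the learned W^K, W^Q, W^V (randomly initialized, untrained)
W_K = np.random.randn(d_model, d_k)
W_Q = np.random.randn(d_model, d_k)
W_V = np.random.randn(d_model, d_k)

e_i = np.random.randn(d_model)  # embedding vector of one word

k_i = e_i @ W_K  # key vector, shape (64,)
q_i = e_i @ W_Q  # query vector, shape (64,)
v_i = e_i @ W_V  # value vector, shape (64,)
```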

Second Step

  • calculate an attention score for each word against the current word
  • the attention score of word $x_j$ for the word $x_i$ is calculated by a dot product between the key vector of $x_j$ and the query vector of $x_i$
$$score_{ij} = q_i \cdot k_j$$
  • $score_{ij}$ tells you how much attention should be given to the word $x_j$ for the encoding of word $x_i$

Third Step

  • divide the attention scores by the square root of the dimension of the key vector
$$score\_new_{ij} = \frac{score_{ij}}{\sqrt{d_k}}$$
  • $d_k$ is the number of dimensions in the key vector
  • the softmax function which is applied after this step can be sensitive to very large scores
  • this can kill off gradients
  • dividing by $\sqrt{d_k}$ normalizes the scores

Fourth Step

  • apply softmax to all $score\_new_{ij}$
$$\begin{align} softmax\_score_{ij} &= softmax(score\_new_{ij}) \\ &= \frac{e^{score\_new_{ij}}}{\sum_k e^{score\_new_{ik}}} \end{align}$$

Fifth Step

  • create multiple new value vectors for the word $x_i$ by multiplying the value vector $v_j$ of each word $x_j$ by $softmax\_score_{ij}$
$$v\_new_{ij} = softmax\_score_{ij} \cdot v_j$$
  • think of the value vector as information about a word
  • by multiplying the softmax score with a value vector we are determining how much of that word's information is needed to encode the current word

Sixth Step

  • Create a new vector $z_i$ which encodes the word $x_i$ by summing the new value vectors created from each word in the sentence
$$z_i = \sum_j v\_new_{ij}$$
  • $z_i$ encodes all the understanding from different words at different locations for the current word $x_i$
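
Putting steps two through six together, a small NumPy sketch that encodes one word $x_i$ from per-word query, key and value vectors (toy sentence length of 4, random vectors standing in for the outputs of the first step):

```python
import numpy as np

n, d_k = 4, 64               # toy sentence length, key dimension
q = np.random.randn(n, d_k)  # q_i for each word (from step one)
k = np.random.randn(n, d_k)  # k_j for each word
v = np.random.randn(n, d_k)  # v_j for each word

i = 0                                                        # encode word x_i
score = np.array([q[i] @ k[j] for j in range(n)])            # step 2: q_i . k_j
score_new = score / np.sqrt(d_k)                             # step 3: scale by sqrt(d_k)
softmax_score = np.exp(score_new) / np.exp(score_new).sum()  # step 4: softmax
v_new = softmax_score[:, None] * v                           # step 5: weight each v_j
z_i = v_new.sum(axis=0)                                      # step 6: sum into z_i
```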

Matrix Calculation of Attention

  • suppose we have a $4 \times 512$ matrix $X$
    • sentence length is $4$
    • the embedding vector has $512$ dimensions
  • we have $W^K$, $W^Q$ and $W^V$ matrices of $512 \times 64$ dimensions
  • now we'll create key, query and value vectors of all words in the sentence
$$\begin{align} K &= X \times W^K \\ Q &= X \times W^Q \\ V &= X \times W^V \end{align}$$
  • $K$, $Q$ and $V$ each have dimensions of $4 \times 64$
  • now find the attention scores of every word by multiplying the query vector of a word with all the key vectors
$$A = Q \times K^T$$
  • $A$ has $4 \times 4$ dimensions
    • each $i^{th}$ row represents the attention scores of all words against the $i^{th}$ word in the sentence
  • now we divide every value by the square root of the number of dimensions in the key vector
$$A = \frac{A}{\sqrt{d_k}}$$
  • apply softmax row wise on matrix $A$ to convert the attention scores into weights
$$S = softmax(A)$$
  • now we find the $Z$ matrix containing the final encoding of each word by multiplying $S$ by the value matrix $V$
    • $Z$ has $4 \times 64$ dimensions
$$Z = S \times V$$
  • all the above operations can be written in one formula
$$Z = softmax\left( \frac{Q \times K^T}{\sqrt{d_k}}\right) \times V$$
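
The same one-formula version as a NumPy sketch, with the softmax applied row wise and random matrices standing in for trained weights:

```python
import numpy as np

def attention(Q, K, V):
    # Z = softmax(Q x K^T / sqrt(d_k)) x V
    d_k = K.shape[-1]
    A = Q @ K.T / np.sqrt(d_k)                             # 4 x 4 scaled scores
    S = np.exp(A) / np.exp(A).sum(axis=-1, keepdims=True)  # row-wise softmax
    return S @ V                                           # 4 x 64 encodings

X = np.random.randn(4, 512)                                # 4 words, 512-dim embeddings
W_Q, W_K, W_V = (np.random.randn(512, 64) for _ in range(3))
Z = attention(X @ W_Q, X @ W_K, X @ W_V)                   # shape (4, 64)
```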

Multiheaded Attention

  • in single headed attention, the $z$ vector contains a little bit of every word's encoding, but it is possible that an irrelevant word's encoding got a higher attention score than anyone else
    • For example, consider the sentence "plane crashed into the sea"
    • it is possible that the attention score for the word "crashed" against the word "sea" is higher compared to the word "plane"
    • this means the model is thinking that it is the "sea" that "crashed" and not the "plane", which is meaningless
  • with multiheaded attention the possibility of a certain word's encoding dominating the $z$ vector is reduced
  • each head has its own set of randomly initialized $W^Q$, $W^K$ and $W^V$ weight matrices
  • each set will project the input embeddings into a different subspace
  • each head will then create a different encoding for the same input embeddings
  • different encodings of the same input are like assigning multiple meanings to a sentence
  • each encoding will focus on different aspects of the same sentence
  • in the actual transformer model the number of heads is $8$ (this could be any other value)
  • this creates $8$ encoding matrices $Z_1, Z_2, ..., Z_8$
  • each $Z_1, Z_2, ..., Z_8$ has dimensions $4 \times 64$
  • but the feed forward network after this attention layer expects a single matrix
  • so we concatenate all $Z$ matrices to create a single matrix of dimension $4 \times 512$ and multiply it with another weight matrix $W^O$ of dimension $512 \times 512$
$$Z = cat(Z_1, Z_2, ..., Z_8) \times W^O$$
  • the $Z$ matrix has dimension $4 \times 512$
  • now this $Z$ matrix goes through a feed forward network to create a new $R$ matrix which is then forwarded to the next encoder block
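
A sketch of the multi-head computation under the same toy shapes (8 heads, each producing a $4 \times 64$ matrix, concatenated and projected by $W^O$); the `attention` helper repeats the single-head formula above:

```python
import numpy as np

def attention(Q, K, V):
    A = Q @ K.T / np.sqrt(K.shape[-1])
    S = np.exp(A) / np.exp(A).sum(axis=-1, keepdims=True)
    return S @ V

n, d_model, h, d_k = 4, 512, 8, 64
X = np.random.randn(n, d_model)

# one independent, randomly initialized set of W^Q, W^K, W^V per head
heads = []
for _ in range(h):
    W_Q, W_K, W_V = (np.random.randn(d_model, d_k) for _ in range(3))
    heads.append(attention(X @ W_Q, X @ W_K, X @ W_V))  # each Z_i is 4 x 64

W_O = np.random.randn(h * d_k, d_model)                 # 512 x 512
Z = np.concatenate(heads, axis=-1) @ W_O                # 4 x 512
```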

Efficient Multiheaded Attention

  • creating multiple heads of attention also comes with multiple sets of $W^Q$, $W^K$ and $W^V$ weight matrices
  • suppose $k$ is the number of dimensions in the embedding vector of a word
  • therefore the input $X$ will have shape $sentence\_len \times k$
  • a single head in a multiheaded attention can have weight matrices of shape $k \times \frac{k}{h}$
    • every head now has $3\frac{k^2}{h}$ parameters
    • $h$ heads will have $h \times 3\frac{k^2}{h} = 3k^2$ parameters
    • which is the same as having single head attention with weight matrices of shape $k \times k$, because that will also have $3k^2$ parameters
  • each head $i$ produces a $Z_i$ matrix of shape $sentence\_len \times \frac{k}{h}$
  • concatenating the $Z_i$s gives another matrix of shape $sentence\_len \times k$
  • then we pass this concatenated matrix through the feed forward network
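
A quick numeric check of the parameter-count argument, assuming $k = 512$ and $h = 8$:

```python
k, h = 512, 8

per_head = 3 * k * (k // h)   # W^Q, W^K, W^V of shape k x (k/h) per head
multi_head = h * per_head     # 786,432 parameters in total
single_head = 3 * k * k       # single head with k x k matrices

assert multi_head == single_head == 3 * k**2
```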

Using Positional Encoding

  • in LSTM or classical RNN techniques the model learns the relative positions of words by itself because we process sentences sequentially
  • so during processing of a word, an RNN model knows at what position the word arrives and what words it has processed before the current word
  • but in transformers, since we process the whole sentence in parallel, it becomes difficult to learn about the relative positions of words within the sentence
  • that's why positional encodings are used to encode the position of a word into the embedding vector of the word
  • positional encodings follow a specific pattern that the model learns to use by the end of training
  • positional encodings are added to the embedding vectors at the bottom most encoder layer
  • two approaches to initialize positional encodings:
    • fixed positional encodings: using sine and cosine functions
    • randomly initialized: positional encodings are learned by the model itself, but this is more computationally expensive
$$X' = X + P$$
  • $P$ is the positional embedding matrix of shape $sentence\_len \times k$
  • $X$ is the input matrix of shape $sentence\_len \times k$
  • $X'$ is the modified input matrix which also contains the positional information of each word embedding vector
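
A sketch of the fixed (sine/cosine) variant, following the usual sinusoidal scheme, with toy shapes of a 4-word sentence and $k = 512$:

```python
import numpy as np

def positional_encoding(sentence_len, k):
    # even embedding dimensions get sine, odd dimensions get cosine,
    # with wavelengths that grow geometrically across the dimensions
    pos = np.arange(sentence_len)[:, None]  # (sentence_len, 1)
    i = np.arange(0, k, 2)[None, :]         # (1, k/2)
    angles = pos / np.power(10000, i / k)
    P = np.zeros((sentence_len, k))
    P[:, 0::2] = np.sin(angles)
    P[:, 1::2] = np.cos(angles)
    return P

X = np.random.randn(4, 512)                # word embeddings
X_prime = X + positional_encoding(4, 512)  # X' = X + P
```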

Residual Connection

  • encoders and decoders are deep neural networks, therefore there's a risk of vanishing gradients
  • residual connections prevent the vanishing gradient problem
  • layer normalization prevents inputs from being too small or too large, which in turn improves stability
  • after each self-attention layer and feed forward network layer, there is a residual connection
  • layer normalization is applied on the residual connection
$$\begin{align} Z &= self\_attention(X) \\ Z' &= layernorm(X + Z) \\ R &= FFN(Z') \\ R' &= layernorm(R + Z') \end{align}$$
  • $R'$ is then fed to the next encoder block
  • a similar structure is followed by each decoder block
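
A minimal sketch of one encoder block's residual-plus-layernorm wiring; `self_attention` and `FFN` are placeholders for the sub-layers described above, and the learnable scale/shift of real layer normalization is omitted:

```python
import numpy as np

def layernorm(x, eps=1e-6):
    # normalize each word's vector to zero mean and unit variance
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def encoder_block(X, self_attention, FFN):
    Z = self_attention(X)
    Z_prime = layernorm(X + Z)        # residual connection around attention
    R = FFN(Z_prime)
    R_prime = layernorm(R + Z_prime)  # residual connection around the FFN
    return R_prime

X = np.random.randn(4, 512)
R_prime = encoder_block(X, self_attention=lambda x: x, FFN=lambda x: x)  # identity stand-ins
```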

Decoder Internals

Using target mask

  • during training, we don't want the decoder to know about the next tokens in the target sequence
  • to do this we use a target mask to zero out attention scores for future tokens
$$target\_mask = \begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix}$$
  • now suppose that we have 3 words in the sentence and we calculate the attention scores as follows
$$attention\_score = \begin{bmatrix} 5 & 2 & 1 \\ 4 & 8 & 3 \\ 7 & 6 & 9 \end{bmatrix}$$
  • each $i^{th}$ row represents the attention scores for the $i^{th}$ word
  • for the first word we should not know the attention scores for the second and third words, i.e. future tokens, so we zero out those values, and similarly we do this for the second and third words
$$\begin{align} attention &= attention\_score \odot target\_mask \\ &= \begin{bmatrix} 5 & 0 & 0 \\ 4 & 8 & 0 \\ 7 & 6 & 9 \end{bmatrix} \end{align}$$
  • this ensures that the $i^{th}$ word's decoding does not use information about future words at locations $[i+1, i+2, ..., i+n]$
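
A sketch of the masking step exactly as described here (element-wise multiplication of the scores by a lower-triangular mask); note that many implementations instead set future positions to a large negative value before the softmax:

```python
import numpy as np

attention_score = np.array([[5., 2., 1.],
                            [4., 8., 3.],
                            [7., 6., 9.]])

target_mask = np.tril(np.ones((3, 3)))  # lower-triangular matrix of ones

attention = attention_score * target_mask
# [[5. 0. 0.]
#  [4. 8. 0.]
#  [7. 6. 9.]]
```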

Encoder-Decoder Attention Layer

  • this attention layer uses the encodings produced by the encoder stack to generate the key and value matrices
  • so we have target word embeddings $T$ of shape $4 \times 512$, which means the target sentence contains only $4$ words, and encodings from the encoder $E$ of shape $4 \times 512$
$$\begin{align} Q &= T \times W^Q \\ K &= E \times W^K \\ V &= E \times W^V \end{align}$$
  • $W^Q$, $W^K$ and $W^V$ have shape $512 \times 64$
  • $Q$, $K$ and $V$ have shape $4 \times 64$
  • then we calculate the $Z$ matrix and pass it through the layer normalization layer to the feed forward layer
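
A sketch of this layer under the shapes above: queries come from the target embeddings $T$, while keys and values come from the encoder output $E$ (random matrices stand in for trained weights):

```python
import numpy as np

T = np.random.randn(4, 512)  # target word embeddings (decoder side)
E = np.random.randn(4, 512)  # encodings from the encoder

W_Q, W_K, W_V = (np.random.randn(512, 64) for _ in range(3))
Q, K, V = T @ W_Q, E @ W_K, E @ W_V                    # each 4 x 64

A = Q @ K.T / np.sqrt(64)
S = np.exp(A) / np.exp(A).sum(axis=-1, keepdims=True)  # row-wise softmax
Z = S @ V                                              # 4 x 64
```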

After last decoder layer

  • now we have a $Z$ matrix of shape $4 \times 64$
  • we want to find the next predicted word for each word in the $Z$ matrix
  • predicting the next word is simply assigning a score to each word in the vocabulary, and whichever word has the highest score is our predicted next word
  • suppose we have $1000$ words in our vocabulary, so we want to assign a score to each word
  • we do this by passing $Z$ through a linear layer which has output dimension $1000$
$$O = Z \times W$$
  • here $W$ has shape $64 \times 1000$, which generates an output matrix $O$ of shape $4 \times 1000$
  • in $O$, the $i^{th}$ row has $1000$ numbers, and the index at which the maximum number occurs is the index of the predicted next word for the $i^{th}$ word
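
A sketch of this final projection, assuming the toy vocabulary of $1000$ words and a randomly initialized linear layer:

```python
import numpy as np

vocab_size = 1000
Z = np.random.randn(4, 64)           # output of the last decoder layer
W = np.random.randn(64, vocab_size)  # linear layer weights

O = Z @ W                            # 4 x 1000 scores over the vocabulary
predicted_ids = O.argmax(axis=-1)    # predicted next-word index for each position
```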