https://colab.research.google.com/drive/1WTo_TL4V34CujADrNNJnZTpy7mo8uj8P#scrollTo=EDarxEWIRMKq
Attention simply acts over a set of vectors that communicate with each other; by itself it has no notion of space. If you want a notion of space, you must add it yourself (for example, with positional embeddings).
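A rough sketch of adding that notion of space with a learned positional embedding, in the spirit of the notebook (the sizes vocab_size, n_embd, and block_size here are illustrative assumptions, not the notebook's exact values):

import torch
import torch.nn as nn

vocab_size, n_embd, block_size = 65, 32, 8          # illustrative sizes
token_embedding = nn.Embedding(vocab_size, n_embd)
position_embedding = nn.Embedding(block_size, n_embd)

idx = torch.randint(0, vocab_size, (4, block_size))     # (B, T) token indices
tok_emb = token_embedding(idx)                           # (B, T, n_embd) what the token is
pos_emb = position_embedding(torch.arange(block_size))   # (T, n_embd) where the token is
x = tok_emb + pos_emb                                    # positions are now encoded in x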
The Query, Key, Value formulation allows the attention weights to depend on the input sequence itself, rather than being fixed constants.
The variable wei represents the attention weights. These weights specify how much attention each token should pay to every other token in the sequence.
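A minimal sketch of a single self-attention head computing wei from queries and keys (assuming an input x of shape (B, T, C) and an illustrative head_size of 16):

import torch
import torch.nn as nn

B, T, C = 4, 8, 32
head_size = 16
x = torch.randn(B, T, C)

key   = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)     # (B, T, head_size)
q = query(x)   # (B, T, head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5   # (B, T, T) data-dependent attention scores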
The line:
tril = torch.tril(torch.ones(T, T))
creates a lower triangular matrix of size T by T (T being the number of time steps). In language modeling, this lower triangular matrix is used for attention masking: it ensures that the prediction for position i can depend only on the known outputs at positions 1 through i, not on the unknown outputs at positions i+1 through T.
This is done to ensure that the model respects the temporal order of the data when making predictions. The attention mechanism will only look at past and present tokens, but not future tokens, which makes sense when you're trying to predict the next token in a sequence.
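Continuing the sketch above, masking out future positions and normalizing each row is what produces a lower-triangular weight matrix like the example below:

import torch.nn.functional as F

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))   # future positions get -inf
wei = F.softmax(wei, dim=-1)                      # each row now sums to 1
out = wei @ value(x)                              # (B, T, head_size) weighted sum of value vectors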
Example:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
[0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
[0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
[0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
grad_fn=<SelectBackward0>)
The eighth row corresponds to the eighth token in the sequence: each entry tells us how much attention that token pays to the tokens at positions one through eight, and everything beyond the eighth column is zero because future tokens are masked out.
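Assuming the tensor above was printed as wei[0] from the sketch, the eighth row can be inspected directly; softmax guarantees its entries sum to 1:

row = wei[0, 7]     # attention weights of the eighth token (first batch element)
print(row)
print(row.sum())    # tensor(1.) up to floating-point error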