LLM (10): Infini Transformer

Infini-attention differs from conventional attention mechanisms, which recompute the attention weights over the whole sequence every time a new input is processed and therefore discard the past K and V states.

Infini-attention instead keeps the K and V states in a compressive memory, which has two advantages:

  • It helps when processing long (even effectively unbounded) texts and contexts.
  • It reduces computational cost: past segments never need to be recomputed, which improves efficiency and saves compute and memory.

1. Scaled Dot-Product Attention

A single head in the original multi-head attention computes its attention context $A_{dot} \in \mathbb{R}^{N \times d_{value}}$ from an input segment $X \in \mathbb{R}^{N \times d_{model}}$ as follows. First, it computes the attention query, key, and value states:

$$
K = XW_K, \quad V = XW_V, \quad Q = XW_Q
$$

The attention context is then computed as a weighted average of all the value states:

$$
A_{dot} = \text{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_{model}}}\right)V
$$
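
To make the shapes concrete, here is a minimal PyTorch sketch of this single-head computation; the sizes N, d_model, d_key and d_value are arbitrary examples, not values from the paper:

import torch
import torch.nn.functional as F

N, d_model, d_key, d_value = 8, 64, 64, 64   # example sizes only
X = torch.randn(N, d_model)                  # one input segment
W_K = torch.randn(d_model, d_key)            # trainable projections
W_V = torch.randn(d_model, d_value)
W_Q = torch.randn(d_model, d_key)

K, V, Q = X @ W_K, X @ W_V, X @ W_Q
A_dot = F.softmax(Q @ K.T / d_model ** 0.5, dim=-1) @ V   # (N, d_value)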

2. Compressive Memory

Memory retrieval. In Infini-attention, the new content $A_{mem} \in \mathbb{R}^{N \times d_{value}}$ is retrieved from the memory $M_{s-1} \in \mathbb{R}^{d_{key} \times d_{value}}$ using the query $Q \in \mathbb{R}^{N \times d_{key}}$:

$$
A_{mem} = \frac{\sigma(Q)M_{s-1}}{\sigma(Q)z_{s-1}}
$$

Here $\sigma$ and $z_{s-1} \in \mathbb{R}^{d_{key}}$ are the nonlinear activation function and the normalization term, respectively. The sum over all keys is used as the normalization term, and ELU + 1 is used as the activation function.
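
Continuing the sketch above, the retrieval step for one head might look like this (M and z stand for $M_{s-1}$ and $z_{s-1}$; the small epsilon only avoids dividing by zero when the memory is still empty):

sigma = lambda t: F.elu(t) + 1.0        # ELU + 1 nonlinearity

M = torch.zeros(d_key, d_value)         # compressive memory M_{s-1}
z = torch.zeros(d_key, 1)               # normalization term z_{s-1}

A_mem = (sigma(Q) @ M) / (sigma(Q) @ z + 1e-6)   # (N, d_value)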

Memory update. Once retrieval is done, the memory and the normalization term are updated with the new KV entries, yielding the next states:

$$
M_s \leftarrow M_{s-1} + \sigma(K)^{T}V, \qquad z_s \leftarrow z_{s-1} + \sum_{t=1}^{N}\sigma(K_t)
$$

The new memory states $M_s$ and $z_s$ are then passed on to the next segment $s+1$, building a recurrence within every attention layer. The update term $\sigma(K)^{T}V$ can be viewed as a transformation that writes a set of associative weights into the memory, which makes the whole read/write procedure look very similar to the attention mechanism itself.
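
Continuing the same sketch, the linear memory update and the normalization update for the current segment would be:

# Linear update: bind this segment's keys and values into the memory,
# then carry M and z over to the next segment (the per-layer recurrence).
M = M + sigma(K).T @ V
z = z + sigma(K).sum(dim=0, keepdim=True).T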

A side question: can the division by $\sigma(Q)z_{s-1}$ be seen as an approximation of the softmax normalization?

Inspired by the delta rule, the memory update can instead be designed as:

$$
M_s \leftarrow M_{s-1} + \sigma(K)^{T}\!\left(V - \frac{\sigma(K)M_{s-1}}{\sigma(K)z_{s-1}}\right)
$$

It first retrieves the values already associated with the current keys and subtracts them from the new values, so only the residual is written back into the memory.
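
The delta-rule variant, in the same toy setting, reads back what the memory already associates with the current keys and writes only the difference (in the reference code below the two rules are selected with the update argument):

retrieved = (sigma(K) @ M) / (sigma(K) @ z + 1e-6)   # what the memory would return
M = M + sigma(K).T @ (V - retrieved)                 # write only the residual
z = z + sigma(K).sum(dim=0, keepdim=True).T          # normalization update is unchanged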

Long-term context injection. Next, the local attention state $A_{dot}$ and the memory-retrieved content $A_{mem}$ are combined through a learnable gating scalar $\beta$:

$$
A = \text{sigmoid}(\beta)\odot A_{mem} + \left(1 - \text{sigmoid}(\beta)\right)\odot A_{dot}
$$
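
A one-line sketch of the gate, reusing A_mem and A_dot from the snippets above (in a real model beta would be an nn.Parameter, as in the reference code below):

beta = torch.zeros(1)                     # learnable gating scalar (toy stand-in)
gate = torch.sigmoid(beta)
A = gate * A_mem + (1.0 - gate) * A_dot   # (N, d_value)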

As in standard multi-head attention, $H$ attention context states are computed in parallel, then concatenated and projected to obtain the final attention output $O \in \mathbb{R}^{N \times d_{model}}$:

$$
O = \left[A^{1}; \dots; A^{H}\right]W_O
$$

where $W_O$ are trainable weights.
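
For the multi-head case, a sketch of the concatenation and output projection (H and the random per-head contexts are illustrative stand-ins, reusing the example sizes above):

H = 8
A_heads = [torch.randn(N, d_value) for _ in range(H)]   # stand-ins for A^1 ... A^H
W_O = torch.randn(H * d_value, d_model)                  # trainable output projection
O = torch.cat(A_heads, dim=-1) @ W_O                     # (N, d_model)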

3. Difference from Transformer-XL

Although Transformer-XL also stores K and V, in practice it throws part of the context away and only caches the K, V states of the last segment. Infini-attention, in contrast, keeps a (compressed) record of the complete context.

4. Reference Code

The implementation below comes from the code implementation listed in the references; note that it relies on an external RoPE helper module that provides compute_freq_cis and RoPE.

import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional

# RoPE (rotary positional embeddings) is an external helper module from the
# referenced implementation; it provides compute_freq_cis() and RoPE().


class InfiniAttention(nn.Module):
    def __init__(self, seq_len: int, emb_dim: int,
                 d_head: int, n_head: int, n_segments: int,
                 is_causal: Optional[bool] = True, update: Optional[str] = 'linear',
                 use_rope: Optional[bool] = True, device: Optional[str] = 'cpu'):
        """
        Args:
            seq_len: Sequence length of the inputs.
            n_segments: Number of segments (seq_len must be divisible by n_segments).
            n_head: Number of attention heads.
            emb_dim: Embedding dimension of the input.
            d_head: Embedding dimension of each head.
            is_causal: Whether the model is causal or not.
            update: Memory update rule, either 'linear' or 'delta'.
            use_rope: Use Rotary Positional Embeddings or not. Default: True.
            device: cuda or cpu.
        """
        super().__init__()

        if update not in ['linear', 'delta']:
            raise ValueError('Update takes only one of these parameters - linear, delta')

        assert seq_len % n_segments == 0, 'seq_len must be divisible by n_segments'
        assert emb_dim % n_head == 0, 'emb_dim must be divisible by n_head'

        self.seq_len = seq_len
        self.n_segments = n_segments
        self.n_head = n_head
        self.emb_dim = emb_dim
        self.d_head = d_head
        self.is_causal = is_causal
        self.use_rope = use_rope
        self.update = update
        self.device = device

        self.beta = nn.Parameter(torch.ones((1,), device=device))  # learnable gating scalar from the paper
        self.q = nn.Linear(emb_dim, emb_dim, device=device)
        self.k = nn.Linear(emb_dim, emb_dim, device=device)
        self.v = nn.Linear(emb_dim, emb_dim, device=device)
        self.o = nn.Linear(emb_dim, emb_dim, device=device)
        self.elu = nn.ELU()
        self.freq_cis = RoPE.compute_freq_cis(emb_dim, seq_len, 10000.0, device=device)
        self.register_buffer('causal', torch.tril(torch.ones(seq_len // n_segments, seq_len // n_segments, device=device)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, _, _ = x.size()
        seg_len = self.seq_len // self.n_segments

        # The paper gives no guidance on initialization, so the compressive memory
        # and the normalization term simply start from zero.
        memory = torch.zeros((self.n_head, self.d_head, self.d_head), device=self.device)
        z = torch.zeros((self.n_head, self.d_head, 1), device=self.device)

        query = self.q(x)
        key = self.k(x)
        value = self.v(x)

        if self.use_rope:
            query, key = RoPE.RoPE(self.freq_cis, query, key, self.device)

        # Split (batch, seq_len, emb_dim) into (batch, n_head, n_segments, seg_len, d_head),
        # keeping token and channel dimensions separate before moving heads to the front.
        query = query.view(batch_size, self.n_segments, seg_len, self.n_head, self.d_head).permute(0, 3, 1, 2, 4)
        key = key.view(batch_size, self.n_segments, seg_len, self.n_head, self.d_head).permute(0, 3, 1, 2, 4)
        value = value.view(batch_size, self.n_segments, seg_len, self.n_head, self.d_head).permute(0, 3, 1, 2, 4)

        output = []

        for idx in range(self.n_segments):
            # Memory retrieval: A_mem = sigma(Q) M / (sigma(Q) z), with sigma = ELU + 1.
            sigma_q = self.elu(query[:, :, idx, :, :]) + 1.0
            sigma_k = self.elu(key[:, :, idx, :, :]) + 1.0
            A_mem = (sigma_q @ memory) / ((sigma_q @ z) + 1e-6)  # 1e-6 prevents division by zero

            # Local scaled dot-product attention over the current segment.
            A_dot = query[:, :, idx, :, :] @ key[:, :, idx, :, :].transpose(-2, -1)

            if self.is_causal:
                A_dot.masked_fill_(self.causal == 0, float('-inf'))

            A_dot = F.softmax(A_dot / (self.d_head ** 0.5), dim=-1)
            A_dot = A_dot @ value[:, :, idx, :, :]

            # Long-term context injection through the learnable gate beta.
            attention = (torch.sigmoid(self.beta) * A_mem) + ((1 - torch.sigmoid(self.beta)) * A_dot)

            # Memory update (linear or delta rule), using the memory state before this update.
            if self.update == 'linear':
                memory = memory + (sigma_k.transpose(-2, -1) @ value[:, :, idx, :, :])
            else:
                delta = (sigma_k @ memory) / ((sigma_k @ z) + 1e-6)
                memory = memory + (sigma_k.transpose(-2, -1) @ (value[:, :, idx, :, :] - delta))

            # Keep z as a per-head (d_head, 1) column vector; the transpose matches the
            # shape of the summed keys so the accumulation does not silently broadcast.
            z = z + sigma_k.sum(dim=-2, keepdim=True).transpose(-2, -1)

            output.append(attention)

        # (batch, n_head, seq_len, d_head) -> (batch, seq_len, emb_dim)
        attention = torch.cat(output, dim=2).permute(0, 2, 1, 3).contiguous().view(batch_size, self.seq_len, self.emb_dim)
        return self.o(attention)
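
A minimal usage sketch, assuming the RoPE helper module used by the class is importable; the sizes are arbitrary examples:

attn = InfiniAttention(seq_len=512, emb_dim=256, d_head=32, n_head=8,
                       n_segments=4, is_causal=True, update='delta',
                       use_rope=True, device='cpu')
x = torch.randn(2, 512, 256)      # (batch, seq_len, emb_dim)
y = attn(x)
print(y.shape)                    # torch.Size([2, 512, 256])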

References

  1. Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention
  2. Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention(Medium Blog)
  3. Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention (Yannic Kilcher video, Chinese-dubbed)
  4. Code implementation
