DeepSeek Conditional Memory:Engram

一、简介

LLM Memory 在 Agent 领域常被视为一种对历史时序状态进行压缩的手段,能够有效减少上下文长度。而 DeepSeek 最新提出的 Conditional-Memory Engram 建模,并非旨在压缩长上下文,而是一种短距离时序特征学习方法。下图中, 相邻 3 个 token 被映射成一个记忆 embedding。

Engram 的目标实则是增强模型的长上下文处理能力。Attention 机制与Engram 算法直观的对比:

  • Attention: query 会与历史上所有 key 计算注意力分数。例如在股票交易中,当前时刻的交易决策可能会参考过去所有时刻(如上市十年间的全部数据)的股价信息。
  • Engram:可看作一种长短时特征学习机制。仍以股票交易为例,短期内的标志性事件(如当日曝出企业财务造假新闻)所产生的影响是可预测的——在新闻发布前后,存在一段可被观测的隐藏短期时序模式。Engram 正是对这种短时序特征进行建模的方法。

因此,Engram 所捕捉的短时序特征,即 Conditional Memory,并非通常直观理解中的长程记忆处理方式。Attention 本身可通过注意力权重的分布自然体现长短时偏好,而 Engram 则在此基础上,显式地引入了一种短时序建模机制。

Engram 并非用于取代 Attention。Engram 对标的模块应当是类似 MoE(FFN) 一类的特征学习,引入记忆学习 则可以看成是一种参数 Scaling 或增加模型容量的手段。

二、Engram 框架

Engram 架构

在基础的attention和moe前,额外添加了一个Engram模块,该模块通过检索静态-gram记忆并将其通过上下文感知门控与动态隐藏状态融合来增强主干网络。该模块仅应用于特定层,以实现记忆与计算的解耦,同时保留标准的输入嵌入和解嵌入模块。

2.1 Engram架构拆解

Engram 是一种条件记忆模块。我们用一种简化的架构来描述这一特性。对 N-gram 进行压缩:

  • N-gram:指的是 token-id 序列里的一个窗口连续子序列, word2vec 使用中心词表示来预测周围的词来学习特征表示。我们定义一种 “left-N-gram” 来达到描述窗口历史子序列。
  • Encoding:对 N-gram 的序列编码成一串向量,如我们基于最简单的N-gram的 Token Embeeding 进行求和,本文简称这种编码方法为 Ngram2vec

Engram 的取名即是: Encoding N-gram

1
2
3
4
5
6
7
8
9
10
11
def get_ngram_vec(h, N):
"""get left-N-gram sum"""
ngram_vec = h.clone()
for i in range(1, N):
ngram_vec[:, i:] += h[:, :-i]
return ngram_vec

h = torch.tensor([[0,1,2,3,4]])
ngram_vec = get_ngram_vec(h, N=3)
print(h) # [[0, 1, 2, 3, 4]]
print(ngram_vec) # [[0, 1, 3, 6, 9]]

对于原有 token-level 的特征则可以与 Ngram2Vec 进行融合,达到引入外部信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
prompt = ['Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models']

# model
embd = nn.Embedding(backbone_config.vocab_size, backbone_config.hidden_size)
proj = nn.Linear(backbone_config.hidden_size, backbone_config.hidden_size)

# 获取序列 hidden
x = tokenizer(prompt, return_tensors = 'pt').input_ids # torch.Size([1, 21])
hidden = embd(x) # torch.Size([1, 21, 1024])
hidden_proj = proj(hidden) # torch.Size([1, 21, 1024])

# 获取 token-level ngram-vec
ngram_vec = get_ngram_vec(hidden, N=3) # torch.Size([1, 21, 1024])

# 特征与 ngram-vec 相加
hidden_engram = hidden_proj + ngram_vec # torch.Size([1, 21, 1024])

将 Engram 模块串接到 Decoder-Block 前,将特征与 Engram 的特征进行相加,输出的结果传输到注意力层前。根据原 paper 的描述隔 block 引入 Engram

1
2
3
4
5
6
attn = nn.Linear(backbone_config.hidden_size, backbone_config.hidden_size)

# forward
out_engram = hidden_engram + hidden # shortcut
out_attn = attn(out_engram) + out_engram # shortcut
# -> ffn -> next decoder-block

2.2 Engram 框架所解决的问题

Engram 用 N-Gram 规则统一定义短序列,实际上原模型窗口比较小,Engram 是围绕左窗口来做短序列建模的。可以将 Engram 概括为两个方面的优化

  1. Ngram2Vec
  2. Conditional Mmeory:如何利用条件记忆序列建模?

2.2.1 如何做 Ngram2Vec?

在上述提出了一种朴素的 Ngram2Vec,当有新知识引入时,会影响原本的知识,那么

  1. 如何引入独立的记忆层来学习新的内容? 如以下代码中的 self.ngram_embd
  2. 如何高效的学习的 Ngram2Vec ?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class LanguageModelWithNgram2Vec(nn.Module):
def __init__(self, config, N = 3):
super().__init__()
self.N = N
self.embd = nn.Embedding(config.vocab_size, config.hidden_size)
self.proj = nn.Linear(config.hidden_size, config.hidden_size)

self.ngram_embd = nn.Embedding(config.vocab_size, config.hidden_size)

def forward(self, x):
# 主分支
h = self.embd(x)
h_proj = self.proj(h)

# 记忆分支
h_ngram = self.ngram_embd(x)
ngram_vec = get_ngram_vec(h_ngram, self.N)

# 融合
h_engram = h_proj + ngram_vec
return h_engram

model = LanguageModelWithNgram2Vec(backbone_config, engram_cfg.max_ngram_size)
h_engram = model(x)
print(h_engram.shape)

2.2.2 如何做条件记忆序列建模?

我们定义 Ngram2Vec 是 short-memory(短距离记忆)。在因果注意力计算形式中,查询历史 KV 来完成序列交互,计算各时刻的特征,本质上是 History2Vec。可以归纳两种方式来描述外部记忆如何增强原有的特征表示

  • 交叉注意力:将 memory 作为 KV 进行序列交互,完成 token-feature 与 memory 交互,实现上可以采用交叉注意力完成。但这种方式显著增加了计算量。
  • 序列依赖式:将 memory 与特征相加融合,并引入类似 Conv1D 的方式进行短距离序列建模。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class EngramCrossAttn(LanguageModelWithNgram2Vec):
def __init__(self, config, N = 3):
super().__init__(config, N)
self.dim = config.hidden_size
self.Wq = nn.Linear(config.hidden_size, config.hidden_size)
self.Wk = nn.Linear(config.hidden_size, config.hidden_size)
self.Wv = nn.Linear(config.hidden_size, config.hidden_size)

def forward_(self, x):
# 主分支
h = self.embd(x)
h_proj = self.proj(h)

# 记忆分支
h_ngram = self.ngram_embd(x)
ngram_vec = get_ngram_vec(h_ngram, self.N)

# Cross Attention 融合
Q = self.Wq(h_proj)
K, V = self.Wk(ngram_vec), self.Wv(ngram_vec)
S = Q @ K.tranpose(1,2) /math.sqrt(self.dim)
P = F.softmax(S, dim =-1)
h_engram = P @ V

return h_engram

model = EngramCrossAttn(backbone_config, engram_cfg.max_ngram_size)
h_engram = model(x)
print(h_engram.shape)

至此,可以得到一个核心的概念:Conditional Memory。 对于一个 next-token 预测来说,增加了一部分条件。

engram 模块是伴随解码块的, ngram_embd 独立串接在 block 内,不是输入嵌入层只放置于解码器前。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class EngramBlock(nn.Module):
""" Decoder Block with Engram """
def __init__(self, config, N = 3)
super().__init__()
self.dim = config.hidden_size
self.engram = EngramCrossAttn(config, N = 3)
self.attn = nn.Linear(self.dim, self.dim)
self.ffn = nn.Linear(self.dim, self.dim)

def forward(self, h):
h = h + self.engram(h)
h = h + self.attn(h)
h = h + self.ffn(h)
return h

三、Engram: Ngram2Vec

为了实现 Ngram2Vec, 有两种简单的学习方式:

  1. 将 N-gram 对应多个Embedding 进行统计,如求和或平均;
  2. Engram 的做法是将 N-Gram token-ID 进行映射成 1 个 ID,并在一个额外的 Embedding 来放置 Ngram 的特征向量,这个额外的 Embedding 则为 Memory Embedding。

以下代码中只要实现好 mapping 函数,即可确定性的达到 ID 值映射的目的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class NgramMapping(nn.Module):
def __init__(self, vocab_size=5, dim = 128, N = 3, max_ngram_vocab_size = 1000):
super().__init__()
self.N = N
self.dim = dim
self.embd = nn.Embedding(vocab_size, dim)
self.proj = nn.Linear(dim, dim)

# 用于 Ngram-ID 的 embedding
self.ngram_embd = nn.Embedding(max_ngram_vocab_size, dim)

def mapping(self, x, N):
"""恒等映射, 不处理 gram"""
return x

def forward(self, x):
# 主分支
h = self.embd(x)
h_proj = self.proj(h)

# 记忆分支
# 从原始 ID 映射为 新 ID
x_ = self.mapping(x, self.N) # id -> new_id
# 从原始 ID 取出一个新 embd
h_ngram = self.ngram_embd(x_)

# 融合
h_engram = h_proj + h_ngram
return h_engram

3.1 N-gram 直接映射

一种直接的映射是将 N-Gram Token-ID 进行相加或相乘,在新词表中创建出一个新词。这种方法映射后的词量是巨大的,以一个 128k 的词表来说,其词表需要新创建 个 embedding,用于存储 N-Gram 所有可能出现的有序组合。

例如,一个很小的词表 vocab_size=5, 当 N=3 时, 有组合 5^3=125

以下代码中将 N-gram 进行乘积映射, 输入token序列取值范围为,取值范围在 之间。这种映射规则产生的问题是:

  • 当 3-gram 中, 有数据为 0 时, 都可能冲突的映射到 new_id = 0
  • 当 3-gram 中, token 序列 都指向 new_id = 18
  • 另外需要申请一个大的 ngram_embd, 当原始 vocab_size = 100 时, max_ngram_vocab_size 应为 , 它的词嵌入矩阵的向量条目是千万量级的规模。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class NgramMappingMul(nn.Module):
def __init__(self, vocab_size=5, dim = 128, N = 3, max_ngram_vocab_size = 1000):
super().__init__()
self.N = N
self.dim = dim
self.embd = nn.Embedding(vocab_size, dim)
self.proj = nn.Linear(dim, dim)

# 用于 Ngram-ID 的 embedding
self.ngram_embd = nn.Embedding(max_ngram_vocab_size, dim)

def mul_mapping(self, x, N):
"""乘积映射"""
x_ = x.clone()
for i in range(N):
x_[:, i:] *= (x_[:, :-i]+1) # 避免id=0产生冲突
return x_

def forward(self, x):
# 主分支
h = self.embd(x)
h_proj = self.proj(h)

# 记忆分支
# 从原始 ID 映射为 新 ID
x_ = self.mul_mapping(x, self.N) # id -> new_id
# 从原始 ID 取出一个新 embd
h_ngram = self.ngram_embd(x_)

# 融合
h_engram = h_proj + h_ngram
return h_engram

3.2 N-gram 哈希映射

我们可以设计一种哈希映射,具体的新创建一个大小为的Embedding 矩阵,然后对于 N-Gram 相加的 token-id 乘积结果取 mod,便可以达到快速的查找的目的。连乘取模是一种基础的哈希函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class NgramMappingHash(nn.Module):
def __init__(self, vocab_size=5, dim = 128, N = 3, max_ngram_vocab_size = 1000):
super().__init__()
self.N = N
self.dim = dim
self.embd = nn.Embedding(vocab_size, dim)
self.proj = nn.Linear(dim, dim)

# 用于 Ngram-ID 的 embedding
self.ngram_embd = nn.Embedding(max_ngram_vocab_size, dim)

# 增加模值
self.mod = max_ngram_vocab_size

def hash_mapping(self, x, N):
"""哈希映射"""
x_ = x.clone()
for i in range(1, N):
x_[:, i:] *= (x_[:, :-i]+1)
print('mul id:', x_.tolist())
# 增加模值
x_ = x_ % self.mod
return x_

def forward(self, x):
# 主分支
h = self.embd(x)
h_proj = self.proj(h)

# 记忆分支
# 从原始 ID 映射为 新 ID
x_ = self.hash_mapping(x, self.N) # id -> new_id
# 从原始 ID 取出一个新 embd
h_ngram = self.ngram_embd(x_)

# 融合
h_engram = h_proj + h_ngram
return h_engram, x_

vocab_size = 100
x = torch.randint(vocab_size, (1, 5), dtype=torch.long)
print('original id:', x.tolist())
ngram2vec = NgramMappingHash(vocab_size=100,
dim = 128,
N = 3,
max_ngram_vocab_size=10007) # 模值为素数
h_engram, x_ = ngram2vec(x)
print('hash id', x_.tolist())

哈希映射相较之前的映射方法,存储量可控。 哈希映射存在冲突问题不同的 N-Gram 可能映射到统同一个 Embedding。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def hash_mapping(x, N, mod=37):
"""哈希映射"""
_, L = x.shape
x_ = x.clone()
for i in range(1, N):
x_[:, i:] *= (x_[:, :-i]+1)
# 增加模值
x_ = x_ % mod

hash_table = torch.zeros(mod)
for i in range(L):
hash_table[x_[0, i]] += 1
if hash_table[x_[0, i]] > 1:
# TODO: process conflit
continue
return x_, hash_table

打印结果如下,哈希表数字大于1,说明有两个 ID 哈希到同一个位置了

1
2
3
4
5
6
7
8
9
10
x = torch.arange(37).unsqueeze(dim=0)
new_id, hash_table = hash_mapping(x, N=3, mod=37)
print(new_id)
print(hash_table)

tensor([[ 0, 1, 4, 18, 6, 28, 20, 16, 0, 17, 25, 6, 3, 9, 4, 29, 1, 9,
18, 17, 19, 24, 19, 15, 10, 26, 35, 33, 3, 26, 22, 9]])
tensor([2., 2., 0., 2., 2., 0., 2., 0., 0., 3., 1., 0., 0., 0., 0., 1., 1., 2.,
2., 2., 1., 0., 1., 0., 1., 1., 2., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
0.])

对于冲突有两种处理方式:

  1. 不解决冲突:多个数据将可能映射到同一个Hash ID。可以想象成哈希表中每个 ID 对应一个桶,桶里是可以容纳多个原数据的。常见的有:链地址法
  2. 解决冲突:如果映射的 Hash ID 已有数据, 那么通过新的策略重新 hash。可以想象成哈希表中每个 ID 至多被一个数据映射到。例如常见的有:开放寻址法

如下面代码,遇到冲突后进行线性试探,直到找到一个未映射过的空位。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def hash_mapping(x, N, mod=37):
"""哈希映射"""
_, L = x.shape
x_ = x.clone()
for i in range(1, N):
x_[:, i:] *= (x_[:, :-i]+1)
# 增加模值
x_ = x_ % mod

hash_table = torch.zeros(mod)
for i in range(L):
idx = x_[0, i]
if hash_table[idx] == 0:
hash_table[idx] += 1
else:
# confict
while hash_table[idx] == 1:
idx = (idx+1) # 线性探测
if idx >= max_vocab_size:
idx = 0
hash_table[idx] += 1
return x_, hash_table

x = torch.arange(37).unsqueeze(dim=0)
new_id, hash_table = hash_mapping(x, N=3, mod=37)
print(new_id)
print(hash_table)

# tensor([[ 0, 1, 4, 18, 6, 28, 20, 16, 0, 17, 25, 6, 3, 9, 4, 29, 1, 9,
18, 17, 19, 24, 19, 15, 10, 26, 35, 33, 3, 26, 22, 9, 29, 0, 12, 31,
10]])
# tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1.])

至此,我们期望的 N-gram 压缩的哈希映射为:

  • 新 Embedding 最大存储可控
  • 冲突少
  • 查询次数少(频繁的处理冲突,影响查找效率)
  • 如果哈希表是不用解决冲突的,那么哈希表中的每个桶里被映射的数据量是均衡的。

3.3 Multiplicative-XOR Hash(乘异或哈希)

官方实现代码中, 类 class::NgramHashMapping 进行实现 XOR 哈希, 建议按照以下内容逐步理解源码。

3.3.1 XOR 哈希

定义哈希函数 hash(key) = (key ^ (key >> 4)) % table_size ,观察其二进制数值。

  • 给定数值 42, 二进制为 00101010
  • 操作key >> 4 :2, 二进制为 00000010, 将key的二进制右移4位,左边补0
  • 进行key ^ (key >> 4) , 其中 ^ 为异或(XOR)操作 00101010 ^ 00000010 得到 00101000
  • 最终哈希值: 40 % 100 = 40,其中 100 是哈希表大小。

代码实现如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class XORHash:
def __init__(self, size=100):
self.size = size
self.table = [[] for _ in range(size)] # 每个桶存储列表, 不解决冲突

def hash(self, key):
return (key ^ (key >> 4)) % self.size

def insert(self, key):
idx = self.hash(key)
for i, k in enumerate(self.table[idx]):
if k == key:
self.table[idx][i] = key
return
self.table[idx].append(key)

def get(self, key):
idx = self.hash(key)
for k in self.table[idx]:
if k == key:
return k
return None

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
h = XORHash(50)
data = [random.randint(1, 128) for _ in range(128)]

for i, num in enumerate(data):
h.insert(num)
print(h.table)

print(f"插入 {len(data)} 个元素")
print(f"查找示例:")
for key in [data[0], data[10], data[20], 999]:
k = h.get(key)
if k == None:
print(f"{key} 未找到")
else:
print(f"{key}")

结果如下,哈希表中的元素是个列表。哈希方法可以不解决冲突。 其哈希表中理想情况下,各个列表长度均等。

1
2
3
4
5
6
7
[[98, 49], [1, 99], [55], [54, 97, 3], [4], [5, 52, 111], [108, 59, 6], [109], [106, 8], [9, 107], [10, 63, 104], [105], [119, 12], [60], [14, 68], [116, 69], [70, 17], [71, 114], [113, 64], [112, 18, 65], [], [126], [125, 23], [124, 77], [25], [122, 24], [27], [73], [29, 74], [], [], [30, 84], [87, 34], [35, 86], [81, 32], [80, 33], [128, 83, 38], [], [93, 36], [92, 37], [42, 95], [94], [40, 89], [88, 41], [91, 46], [47, 90], [44, 102], [45], [51, 100], []]
插入 128 个元素
查找示例:
119
125
40
999 未找到

3.3.2 Multiplicative-XOR Hash 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import random

class MultiplicativeXORHashNgram3:
def __init__(self, size=200):
self.size = size # 哈希表大小
self.table = [[] for _ in range(size)]
# 魔法数字, 通常为较大的素数
self.multipliers = [11400714819323198485 % (2**32), 2654435761, 179424673]

def hash(self, x):
hash_val = 0
for i in range(3):
temp = x[i] * self.multipliers[i]
temp ^= temp >> 16
hash_val ^= temp
return hash_val % self.size

def insert(self, x):
idx = self.hash(x)
for i, k in enumerate(self.table[idx]):
if k == x:
self.table[idx][i] = x
return False
self.table[idx].append(x)
return len(self.table[idx]) > 1

def get(self, x):
idx = self.hash(x)
for k in self.table[idx]:
if k == x:
return idx
return None

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
h = MultiplicativeXORHashNgram3(50)
data = [[random.randint(1, 128) for _ in range(3)] for _ in range(128)]

conflicts = 0
for point in data:
if h.insert(point):
conflicts += 1
print(f"Conflicts Ratio: {conflicts/len(data)*100:.1f}%")

# 测试
test_points = [data[0], data[10], [999, 999, 999]]
for point in test_points:
result = h.get(point)
status = "find" if result else "not exist"
print(f"3-gram {point}: {status}, hash_id: {result}")

输出

1
2
3
4
Conflicts Ratio: 64.1%
3-gram [93, 128, 33]: find, hash_id: 10
3-gram [94, 93, 107]: find, hash_id: 16
3-gram [999, 999, 999]: not exist, hash_id: None

3.3.3 DeepSeek Multiplicative-XOR Hash 实现

用 Multiplicative-XOR Hash 的处理方法优势在于

  • 可以处理 N-gram 数据输入
  • 映射后的哈希表中的列表大小均衡

根据前置内容,容易理解官方的哈希实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class NgramHashMapping:
# ....
def _get_ngram_hashes(
self,
input_ids: np.ndarray,
layer_id: int,
) -> np.ndarray:
x = np.asarray(input_ids, dtype=np.int64)
B, T = x.shape

# 乘子
multipliers = self.layer_multipliers[layer_id]

# 将序列数据构成 N-gram 列表
def shift_k(k: int) -> np.ndarray:
if k == 0: return x
shifted = np.pad(x, ((0, 0), (k, 0)),
mode='constant', constant_values=self.pad_id)[:, :T]
return shifted
base_shifts = [shift_k(k) for k in range(self.max_ngram_size)]

all_hashes = []
# deepseek 获取 2-gram 和 3-gram
for n in range(2, self.max_ngram_size + 1):
n_gram_index = n - 2
tokens = base_shifts[:n]

# 乘子
mix = (tokens[0] * multipliers[0])
# n-gram 逐元素计算
for k in range(1, n):
mix = np.bitwise_xor(mix, tokens[k] * multipliers[k]) # XOR
num_heads_for_this_ngram = self.n_head_per_ngram
head_vocab_sizes = self.vocab_size_across_layers[layer_id][n_gram_index]

for j in range(num_heads_for_this_ngram):
mod = int(head_vocab_sizes[j])
head_hash = mix % mod # 取模
all_hashes.append(head_hash.astype(np.int64, copy=False))

return np.stack(all_hashes, axis=2)

3.4 N-gram: 多哈希映射

为了丰富外部获取的 embedding,可以进行多哈希映射,示例一个简单的映射规则,不同的哈希函数取不同的模,可以达到多哈希映射,丰富特征。

1
2
3
4
5
x = torch.arange(20).unsqueeze(dim=0) 
new_id, hash_table_1 = hash_mapping(x, N=3, mod=37, max_vocab_size=40)
print(new_id)
new_id, hash_table_2 = hash_mapping(x, N=3, mod=33, max_vocab_size=40)
print(new_id)

同一份序列有不同的数据映射

1
2
3
4
tensor([[ 0,  1,  4, 18,  6, 28, 20, 16,  0, 17, 25,  6,  3,  9,  4, 29,  1,  9,
18, 17]])
tensor([[ 0, 1, 4, 18, 14, 19, 18, 20, 25, 24, 32, 22, 24, 26, 7, 3, 8, 7,
9, 14]])

创建多分支对应的 embd。在单哈希映射中我们可以设定嵌入维度为,而个哈希映射中,可以创建多个嵌入矩阵,每个嵌入矩阵维度为

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# original single-hash
ngram_embd = nn.Embedding(100, 256)

# multi-hash, multi embedding
ngram_embd_1 = nn.Embedding(100, 128)
ngram_embd_2 = nn.Embedding(100, 128)

x = torch.arange(20).unsqueeze(dim=0)
new_id_1, hash_table_1 = hash_mapping(x, N=3, mod=37, max_vocab_size=40)
print(new_id_1)
new_id_2, hash_table_2 = hash_mapping(x, N=3, mod=33, max_vocab_size=40)
print(new_id_2)

h_branch_1 = ngram_embd_1(new_id_1)
h_branch_2 = ngram_embd_2(new_id_2)
print(h_branch_1[0, 2, :5]) # id:4
print(h_branch_2[0, 2, :5]) # id:4

输出如下,统一个映射 ID 对应的 embedding 向量不同。

1
2
3
4
5
6
tensor([[ 0,  1,  4, 18,  6, 28, 20, 16,  0, 17, 25,  6,  3,  9,  4, 29,  1,  9,
18, 17]])
tensor([[ 0, 1, 4, 18, 14, 19, 18, 20, 25, 24, 32, 22, 24, 26, 7, 3, 8, 7,
9, 14]])
tensor([ 0.7217, -1.0902, 0.7087, 1.0136, 0.6399], grad_fn=<SliceBackward0>)
tensor([-1.0131, 0.0814, 0.0663, -0.8880, 0.0722], grad_fn=<SliceBackward0>)

对于多头哈希得到的 embedding 也可以进一步拼接。 如维向量拼接后得到原维度。 多哈希可以较好缓解不冲突哈希的特征单一来源的问题。

3.5 Tokenizer 词表压缩

上述 N-gram 的映射复杂度与原本词表大小正相关。

压缩词表大小能够减轻映射目标词嵌入矩阵的大小。具体的 Tokenizer 词表可以进行压缩,对于每个 token,采用一定的规范化函数,如将单词统一改成小写或取消标点,示例如下

Apple vs. ␣apple,

新的词表能减少同义重复词,此过程也可能产生新词。 代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class CompressedTokenizer:
def __init__(
self,
tokenizer_name_or_path,
):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)

SENTINEL = "\uE000"
self.normalizer = normalizers.Sequence([
normalizers.NFKC(),
normalizers.NFD(),
normalizers.StripAccents(),
normalizers.Lowercase(),
normalizers.Replace(Regex(r"[ \t\r\n]+"), " "),
normalizers.Replace(Regex(r"^ $"), SENTINEL),
normalizers.Strip(),
normalizers.Replace(SENTINEL, " "),
])

self.lookup_table, self.num_new_token, self.overlap_num = self._build_lookup_table()

def __len__(self):
return self.num_new_token

def _build_lookup_table(self):
old2new = {}
key2new = {}
new_tokens = []

vocab_size = len(self.tokenizer)
count = 0
for tid in range(vocab_size):
text = self.tokenizer.decode([tid], skip_special_tokens=False)

if "�" in text:
key = self.tokenizer.convert_ids_to_tokens(tid)
else:
norm = self.normalizer.normalize_str(text)
key = norm if norm else text

nid = key2new.get(key)
if nid is None:
# print(key)
nid = len(new_tokens)
key2new[key] = nid
new_tokens.append(key)
else:
count+=1 # 记录重复token
old2new[tid] = nid # 映射

lookup = np.empty(vocab_size, dtype=np.int64)
for tid in range(vocab_size):
lookup[tid] = old2new[tid] # 多个token,可能变换到唯一 token ID

return lookup, len(new_tokens), count

def _compress(self, input_ids):
arr = np.asarray(input_ids, dtype=np.int64)
pos_mask = arr >= 0
out = arr.copy()
valid_ids = arr[pos_mask]
out[pos_mask] = self.lookup_table[valid_ids]
return out

def __call__(self, input_ids):
return self._compress(input_ids)

代入 DeepSeek-V3 的 tokenizer, 得到如下词表:

1
2
3
tokenizer = AutoTokenizer.from_pretrained("./v3", trust_remote_code=True)
tokenizer_c = CompressedTokenizer("./v3")
print(self.overlap_num) # 30188

即是原词表有 30188 个规范化前后一致的 token。 新 tokenizer 词表量不一定比原来 tokenizer 词表量小。

测试得到两个分词器有不同的 token_id 输出

1
2
3
4
5
6
7
prompt = 'i Love llm, I love xiaodonggua'
prompt_ids = tokenizer(prompt)['input_ids']
c_prompt_ids = tokenizer_c(prompt_ids)
print(len(prompt_ids)) # 13
print(len(c_prompt_ids)) # 13
print(prompt_ids[:10]) # [0, 75, 14920, 18200, 79, 14, 342, 3518, 1527, 601]
print(c_prompt_ids[:10]) # [ 0 43 2978 718 47 14 43 2978 58 494]

3.6 小结

  • Ngram2Vec 可以用多种方法获取,本文采用 N-gram 哈希映射成新 ID 来取新 embedding 作为 N-gram-wise 的 token-level memory
  • Ngram 输入是 N-维长度序列,DeepSeek 官方精心设计了 Multiplicative-XOR Hash ,在工程上要保证哈希后冲突率尽可能小

四、Engram: Conditional Memory

网络的数据流中有两类数据源,其 tensor 维度一致

  • Hidden State:token-level 级别的特征。如 attention 层输出为注意力特征,ffn 层输出为任务特征。自回归语言模型,每个token最终的 hidden state 都称之为 “next-token-prediction” 特征。一个网络数据流中,特征都是在输入 token-id 的 context 序列维度或特征维度进行交互而得。
  • Memory Embedding:N-gram 映射到外部取得的 embedding。

要实现两组不同数据之间的交互,一种自然的方式是 Cross Attention。Engram 实际上并未采用 Cross Attention 方案。而是在 token-level 级别上的 Hidden State 和 Memory Embedding 进行融合。

标准的Next-token-prediction 任务为:

Engram 引入了外部的 Memory Embedding,预测概率为:

其中,为 context 下的所有 hidden state,表示 Ngram2Vec 的映射函数,每个时刻获取到一个记忆向量。 对比以上两个预测,Memory 是以条件概率式的改变输出分布的。速记:

  • Memory:N-Gram 映射成记忆嵌入向量
  • Conditional: 预测概率引入外部条件信息

4.1 Engram 架构图

Engram 模块中具体分为三个部分:

  • 输入:input-hidden 和 memory embedding。 较为复杂的是 multi-heads hash memory 获取
  • Scaled Dot Product :通常会将该算子与 Attention 等同。但事实上,Engram 只是用这个算子计算, 例如,将时刻的特征为作为 query,在一个巨大的图书馆内,找到 1 本书(memory) 并投影为 key, Value。 通过 Scale Dot Product 算子,得出特征向量。完整输出为每个时刻的特征与memory特征的结合输出。
  • conv:每个时刻都有 conditional memory 加持的特征为,对于 context 长度为的序列则有矩阵, 此时通过在时序上的 short-Conv1D 则可以达到短距离建模的目的。

实现一个简易模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class EngramSimple(nn.Module):
def __init__(self,
memory_vocab_size,
dim,
kernel_size=2):
super().__init__()
self.memory_vocab_size = memory_vocab_size
self.memory_embd = nn.Embedding(memory_vocab_size, dim)
self.Wk = nn.Linear(dim, dim)
self.Wv = nn.Linear(dim, dim)

self.kernel_size = kernel_size
self.w_conv1d = nn.Parameter(torch.randn(kernel_size))

def forward(self, x, h):
"""
h: bsz, seq_len, dim
x: bsz, seq_len
"""
B, T, D = h.shape

# 1. Multi-heads-hash memory
hash_id = self.multi_head_hash(x)
h_memory = self.get_memory(hash_id) # B, T, D

# 2. Scale-Dot-Product Fusion
q, k, v = h, self.Wk(h_memory), self.Wv(h_memory)
gate = (q * k).sum(dim=2, keepdim=True)
v_ = gate * v # B, T, D

# 3. short-conv1d
out = self.short_conv1d(v_)
return out

相应的操作实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class EngramSimple(nn.Modules):
# ...

def multi_head_hash(self, x):
B, T = x.shape
hash_id = torch.randint(self.memory_vocab_size, (B,T)) # 随机hash
return hash_id

def get_memory(self, hash_id):
h_memory = self.memory_embd(hash_id)
return h_memory

def short_conv1d(self, v):
"""简化conv1d, 相邻时刻相加"""
v0 = v * self.w_conv1d[0]
v1 = v * self.w_conv1d[1]
v0[:, :, 1:] += v1[:,:,:-1]
return v0

官方论文公式和代码不容易理解,原因是 Engram 集成了 Hyper-Connection(HC)非 HC 版本可以称之为是单流 Engram,集成了HC得版本称之为多流Engram。

为了使得更易理解,将从单流版本 Engram 进行讲解。然后再进一步讨论 Hyper Connection 与 Engram 的集成。

1
2
3
4
5
6
7
8
9
10
11
12
13
def forward(self, x, h):
"""
h: bsz, seq_len, dim
x: bsz, seq_len
"""
# do ...

def forward_hc(self, x, h):
"""
h: bsz, seq_len, n_hc, dim
x: bsz, seq_len
"""
# do ...

4.2 Engram Without Hyper-Connection

Engram 详细讲解为:

  • Multi-heads-hash memory
  • Scale-Dot-Product Fusion
  • Short Conv1D

4.2.1 Multi-heads-hash memory

N-Gram 数据根据个 Hash 函数映射得到 个 hash id。并从完整记忆嵌入矩阵为记忆数量, 为单头维度)中得到个向量,第个哈希 ID 记为

Engram 取 3-Gram 和 2-Gram 的记忆,记为

所有数据进行拼接成一个向量得到:

此时一个 token 得到的多头哈希记忆变为为向量

多哈希类实现为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class MultiHeadsHash:
def __init__(self, max_memory_vocab_size, layer_id):
self.mods = torch.tensor([12582917, 25165843,
50331653, 100663319,
201326611, 402653189,
805306457, 1610612741]) # 素数
self.mods *= (layer_id+1) # 每层哈希模不同
self.max_memory_vocab_size = max_memory_vocab_size
self.layer_id = layer_id

def hash(self, x, mod, n_gram):
x_ = x.clone()
for i in range(1, n_gram):
x_[:, i:] *= x[:, :-i]
hash_id = x_ % mod
hash_id = hash_id % self.max_memory_vocab_size
return hash_id

def multi_head_hash(self, x, mods, n_gram):
hash_ids = []
for mod in mods:
hash_id = self.hash(x, mod, n_gram)
hash_ids.append(hash_id)
hash_ids = torch.stack(hash_ids, dim=-1) # bsz, seq_len, hash_head
return hash_ids

def get_all_hash_ids(self, x, max_n_gram):
ngram_hash_ids = []
for N in range(1, max_n_gram): # 2-gram, 3-gram
hash_ids = self.multi_head_hash(x, self.mods, N)
ngram_hash_ids.append(hash_ids)
return ngram_hash_ids # [ [bsz, seq_len, hash_head], [bsz, seq_len, hash_head] ]

memory 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class ConditionalMemory(nn.Module):
"""
每层Memory不同, Engram memory 伴随 DecoderBlock
"""
def __init__(self,
config
):
super().__init__()

self.head_dim = config.head_dim
max_memory_vocab_size=config.max_memory_vocab_size
self.head_hash = config.head_hash

# self.memory_embds = [ nn.Embedding(max_memory_vocab_size,
# head_dim) for i in range(head_hash)]
self.memory_embds = nn.Embedding(max_memory_vocab_size * self.head_hash,
self.head_dim)
self.offset = torch.arange(self.head_hash) * max_memory_vocab_size
self.offset = self.offset[None, None, 1]

def forward(self, x, ngram_hash_ids):
bsz, seq_len = x.shape
n = len(ngram_hash_ids)

x += self.offset
ngram_memory = []
for hash_ids in ngram_hash_ids:
memory = self.memory_embds(hash_ids)
ngram_memory.append(memory)
h_memory = torch.cat(ngram_memory, dim = -1)

# flat
h_memory = h_memory.reshape(bsz, seq_len, n*self.head_hash*self.head_dim)
return h_memory # bsz, seq_len, 2 * dim

4.2.2 Scale-Dot-Product Fusion

此时每个时刻有, 将记忆向量维度对齐到 特征维度

其中 ,是可学习的投影矩阵 , 是键向量与值向量 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class EngramWithoutHC(nn.Module):
def __init__(self,
config,
layer_id=1,
):
super().__init__()

D = config.dim
self.max_n_gram = config.max_n_gram

# proj
memory_dim = config.head_dim * config.head_hash * (config.max_n_gram-1)
self.Wk = nn.Linear(memory_dim, D)
self.Wv = nn.Linear(memory_dim, D)
self.norm1 = nn.RMSNorm(D)
self.norm2 = nn.RMSNorm(D)

# op
self.hash = MultiHeadsHash(max_memory_vocab_size=config.max_memory_vocab_size,
layer_id=layer_id)
self.memory = ConditionalMemory(config)
# self.conv = ShortConv1D(config.dim,
# config.n_hc,
# config.kernel_size)

作为 query, 计算与记忆向量之间的标量门控值,用于衡量哈希到的记忆相关性。采用 Scale-Dot-Product算子

其中,是Sigmoid 激活函数, 如果当前内容检索到的内容不相关,其激活分数趋近于零,有效抑制其噪声,为标量门控值。

然后对标量分数缩放:

完整 context 变为。 Engram 实现为

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class EngramWithoutHC(nn.Module):
# ...

def forward(self, h, x):
"""
h: bsz, seq_len, dim
hidden states

x: bsz, seq_len
input ids
"""

_, _, D = h.shape

ngram_hash_id = self.hash.get_all_hash_ids(x, self.max_n_gram)
h_memory = self.memory(x, ngram_hash_id)

# proj
q = self.norm1(h)
k = self.norm2(self.Wk(h_memory))
# score
gate = (q * k).sum(dim=-1, keepdim=True) / math.sqrt(D) # bsz, seq_len, 1
gate = torch.sigmoid(gate)

# value
v_ = gate * self.Wv(h_memory)

# Ignore Conv1D
# out = self.conv(v_) + v_

return out

值得注意的是,单个 query 与单个 key、value 的交互可以视作是注意力计算的一个特例。如果是 所有的记忆嵌入成为多个 key、value,那么就回到标准的注意力计算模式:单 query 和 多 key, Value。

论文标题的 Sparsity 即是在一个巨量存储的记忆里通过哈希方式,稀疏的选出特定的 memory 。给定

  • Memory Embedding 矩阵共有 600,000 条
  • 2-gram/3-gram记为 2 组数据, 设多头哈希为 8, 所选记忆条目为 16 条。

稀疏度:

4.2.3 Short Conv-1D

上述混合hidden与memory 产生了完整 context 变为。最终引入 Conv1D 卷积进行短距离序列建模。依次介绍 Engram 中的 Short Conv1D 细节。

4.2.3.1 Conv1D 时序数据处理

Conv1D 是一种时序维度的卷积方式,对于一维时序数据 , 卷积核大小 kernel_size = 3 记为。将时序数据增加 left-padding, 相邻三个元素进行加权运算,

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def fun_conv1d(x, w):
kernel_size = len(w)
x_len = len(x)
x_padding = [0] * (kernel_size-1) + x

x_conv1d = []
for i in range(x_len):
x_tmp = 0
for j in range(kernel_size):
x_tmp += x_padding[i+j] * w[j]
print(x_padding[i: i+kernel_size], '*' ,w, '->', x_tmp)
x_conv1d.append(x_tmp)

return x_conv1d

x = [1, 2, 3, 4]
w = [1, 10, 100]

print('x:', x, 'w:', w)
y = fun_conv1d(x, w)
print(y)

输出为

1
2
3
4
5
6
x: [1, 2, 3, 4] w: [1, 10, 100]
[0, 0, 1] * [1, 10, 100] -> 100
[0, 1, 2] * [1, 10, 100] -> 210
[1, 2, 3] * [1, 10, 100] -> 321
[2, 3, 4] * [1, 10, 100] -> 432
[100, 210, 321, 432]

上述代码的特点在于,时刻数据的输出仅与其前序数据相关,因此 Conv1D 也被称为深度因果卷积。可以直接使用 pytorch 卷积

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
kernel_size = 3
conv = nn.Conv1d(
in_channels=1,
out_channels=1,
kernel_size=kernel_size,
groups=1,
bias=False,
padding=(kernel_size - 1) * 1,
dilation=1,
)
x_len = len(x)
x_tensor = torch.tensor([x], dtype=torch.float32).unsqueeze(dim = 1)
print(x_tensor.shape) # B, C, T
conv.weight.data = torch.tensor([[[1,10,100]]],dtype=torch.float32)
y = conv(x_tensor)
print(y[0,0,:x_len])

打印结果与手动运算一致

1
2
torch.Size([1, 1, 4])
tensor([100., 210., 321., 432.], grad_fn=<SliceBackward0>)
4.2.3.2 Conv1D 时序特征处理

给定序列长度为,特征通道数为时序数据, 卷积核 ,每个特征维度上都有卷积权重;在 pytorch 中用 group 参数可以设定。

实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
B = 1
T = 10
C = 128 # channel, dim(D)
X = torch.randn(B, T, C)
print(X.shape)
X = X.transpose(1,2) # B, C, T

conv_3C = nn.Conv1d(in_channels=C, out_channels=C, kernel_size=kernel_size,
groups=C, padding=(kernel_size - 1) * 1, dilation=1, bias=False,)
print('shape_groupC:', conv_3C.weight.data.shape)
Y = conv_3C(X)
Y = Y.transpose(1,2)[:, :T]
print(Y.shape)

打印

1
2
3
torch.Size([1, 10, 128])
shape_groupC: torch.Size([128, 1, 3])
torch.Size([1, 10, 128])
4.2.3.3 Engram Short Conv-1D

回到 Engram,对融合的进行如下操作。

其中,是深度因果卷积,输出维度不变 ,为激活函数,维度不变 , 残差连接与卷积输出相加, :最终输出序列。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class ShortConv1DWithoutHC(nn.Module):
def __init__(self,
dim,
n_hc,
kernel_size,):
super().__init__()

dilation=1
self.total_dim = dim
self.conv = nn.Conv1d(
in_channels=self.total_dim,
out_channels=self.total_dim,
kernel_size=kernel_size,
groups=self.total_dim,
bias=False,
padding=(kernel_size - 1) * dilation, # 3
dilation=dilation,
)

self.norm = nn.RMSNorm( dim )
self.act_fn = nn.SiLU()

def forward(self, x):
B, T, C = x.shape

x_norm = self.norm(x)
x_norm = x_norm.transpose(1, 2) # B, C, T
y = self.conv(x_norm)
y = y[..., :T]

y = self.act_fn(y) # swiglu
y = y.transpose(1,2)

return y

ShortConv1D 通过卷积核范围内学习短距离特征表示,类似的滑窗注意力也是一种可控学习“短”距离特征表示的方法。

在距离上 ShortConv1D 是短的,思考 Engram 前一个 DecoderBlock 有注意力层,那么 DecoderBlock 输出的 hidden state 中每个时刻都有对前序数据进行加工。ShortConv1D 对多个包含历史的特征进行更复杂的加工处理。

至此,Engram 所有模块都得到了实现。

4.3 Engram With Hyper-Connection

Hyper Connection 是一种新的残差链接范式,DeepSeek 提出 mHC 采用流形约束提升 HC 训练的稳定性。

在 HC 类的网络流中,hidden state 由原本的变换为, 其中为 HC 扩展率。在网络中有 1 条残差分支和 N 条变换分支。

对比单流和多流 Scale-Dot-Product 对比:

  • HC 版本有多个投影矩阵
  • HC 版本中的是针对第计算的
  • HC 版本中的 RMSNorm 有 N*2 个

得到门控缩放 value。

带 HC 版本的 ShortConv1D 实现细节:

  • 在每个输入分支上进行 RMSNorm()
  • 向量拼接成一条向量
  • 标准的做 Conv1D
  • 将 Conv1D 的结果再切分成 HC 分支。

4.4 小结

  • Engram 的输出结果混合 Hidden State 与 Ngram 映射的记忆
  • Scale-Dot-Product 并非实现类似注意力特征的序列建模, ShortConv1D 才是增加短距离建模能力的直接操作
  • Conditional Memory 为模型的规模扩展增加新维度。

五、Engram pytorch实现

代码参考

去除了 tokenizer 压缩和 XOR 哈希,更深入代码本质。 摘选带 HC 的 Engram

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class Engram(nn.Module):
def __init__(self,
config,
layer_id=1,
):
super().__init__()

self.n_hc = config.n_hc
D = config.dim

memory_dim = config.head_dim * config.head_hash * (config.max_n_gram-1)

self.Wks = nn.ModuleList([ nn.Linear(memory_dim, D) for i in range(self.n_hc) ])
self.Wv = nn.Linear(memory_dim, D)
self.norm1 = [ nn.RMSNorm(D) for i in range(self.n_hc) ]
self.norm2 = [ nn.RMSNorm(D) for i in range(self.n_hc) ]

self.max_n_gram = config.max_n_gram

self.memory = ConditionalMemory(config)
self.hash = MultiHeadsHash(max_memory_vocab_size=config.max_memory_vocab_size,
layer_id=layer_id)

self.conv = ShortConv1D(config.dim,
config.n_hc,
config.kernel_size) # HC

class Engram(nn.Module):
def forward(self, h, x):
"""
h: bsz, seq_len, n_hc, dim
hidden states

x: bsz, seq_len
input ids
"""

_,_, _, D = h.shape

ngram_hash_id = self.hash.get_all_hash_ids(x, self.max_n_gram)
h_memory = self.memory(x, ngram_hash_id)

gates = []
for hc_idx in range(self.n_hc):
# proj
q = self.norm1[hc_idx](h[:,:, hc_idx, :])
k = self.norm2[hc_idx](self.Wks[hc_idx](h_memory))
# score
gate = torch.sum(q*k, dim=1, keepdim=True) / math.sqrt(D)
gate = torch.sigmoid(gate)
gates.append(gate)

# value
gates = torch.stack(gates, dim = 2) # bsz, seq_len, n_hc, 1
v = self.Wv(h_memory).unsqueeze(2) # bsz, seq_len, 1, dim
v_ = gates * v # bsz, seq_len, n_hc, dim

# Conv1D
out = self.conv(v_) + v_

return out
  • 实现易理解的multi-head hash
  • 实现Memory
  • 实现short-Conv1D
  • 实现 Engram 前向、模型前向

六、Engram与Attention

Engram 负责将学习的内容存在记忆中,从反向视角,梯度流向两部分权重

  • 模型常规权重
  • Memory

与 Conditonal Memory 相应的模型如 Transformer,可以称之为是 Conditional Encoding

Engram 负责短距离建模,Attention 负责长距离建模。 Attention 在语言模型中仍是核心序列表征学习机制。

七、Engram Infra

  • 训练:额外创建的 Embedding 可以采用词表并行训练,可以类似 MoE 实现 Dispatch-Combine All-to-All 模式。由于训练时,memory 参数分布在多层网络中, 需要频繁的做前向和梯度反向计算。

  • 推理:可以分层静态管理 Engram, paper 中提到推理时可以做预取,推理服务需要配套开发 Memory 管理

  • 硬件:上述infra分析建立在 计算-通信 维度,假如每个计算设备,都可以通过低廉的存储组成,那么可以直接从磁盘中预取操作,掩盖掉 memory-IO开销,消除 all-to-all 通信开销。

八、讨论

短序列一定要 N-gram 吗?为什么是哈希映射?

不一定。有研究实现字符级别分词并自适应从 Embedding 空间上合并。也可以通过一些规则来映射,如字节近期工作将序列 token 转化为 concept。

N-Gram 哈希特点是内容无感的,正如 BPE 算法,其通用性是建立在统计规则的。

Engram 学到了什么?

  • Engram 学到了短序列建模,类似的在 Yarn中提到,长文本的扩展一个重要处理是对特征信号高频绝对外推。
  • 对于一个基模,可以将额外习得的知识填入到 Embeeding 上,那么与 LoRA 区别在于,LoRA 是参数扩展, Engram 是输入扩展。
  • Engram 学会了记忆吗?准确来说他学会了检索

Engram 如何扩展模型容量?

LLM 模型的核心容量通常是在 FFN 中体现的,MoE 是一种高效的容量扩展方式。

Engram 扩展模型的容量范式则为,增加更多的 memory-embedding 对计算没影响,对 IO 有影响,相较廉价的闪存和磁盘存储突破模型容量限制。

另外可以扩展 n-gram 数量,引入新的 scaling 维度

Engram 的记忆可扩展性讨论

考虑 agent 类场景,则可以引入额外的类似的。任务/模态等 Conditional-memory

而 rag 知识库也可以作为 一种 Conditional-memory,其所查询的 embedding 则不是rag类 chunk2embeeding,它的 memory 是跟随查询上下文长度的,数据量更大。

外置 Embeeding 学习范式

围绕着 LLM 结构

  • Prefix-Finetuning, 将任务前缀用 embedding 描述,并投影到 key 特征空间, 进行 conditional 推理
  • NvEmbd:将 Decoder LLM 训练为 Encoder 模型用于做文本的向量化,其引入 latent embedding 学习向量相似度学习任务的对

九、总结

  • Engram 提供了一种新的外置检索记忆的学习模式。提高模型的适用性和扩展性;
  • Engram 建模就近的 N-gram 记忆,拓展长文本之路,反直觉的是要去增强短距离建模;
  • Engram 并未颠覆传统的 Attention 架构,从预测任务角度引入了 conditional memroy。

Reference

  1. Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models
  2. deepseek-ai/Engram
  3. 【手撕Engram】DeepSeek 的 Conditional Memory 能取代 Attention 吗?(超长文、附代码)
  4. Deepseek Conditional Memory浅读
  5. Conditional Memory:DeepSeek 如何为大模型引入“第二条稀疏轴”——从 MoE 走向“计算 × 记忆”双稀疏的大模型新范式
  6. 流形约束超连接(mHC):Manifold-Constrained Hyper-Connections

DeepSeek Conditional Memory:Engram
https://mztchaoqun.com.cn/posts/D109_DeepSeek_Engram/
作者
mztchaoqun
发布于
2026年2月11日
许可协议