LLM(九)——Mixture-of-Depths Transformers

一、Mixture-of-Depths Transformers

MoD(Mixture-of-Depths)采用的技术类似于混合专家(Mixture of Experts,MoE) transformer,其中动态token级路由决策是在整个网络深度上做出的。然而,与 MoE 的想法不同,MoD要么将计算应用于像标准transformer那样的token,要么通过残差连接(Residual Connection)传递它,保持不变并节省计算资源。与MoE相反,这种路由策略应用于前向多层感知机(Multi-Layer Perceptron, MLP)和多头注意力(Multi-head Attention)。由于这也影响了处理的key和query,路由不仅决定了更新哪些token,还决定了哪些token可以被关注。

MoD技术还允许在性能和速度之间进行权衡。一方面,可以训练一个MoD transformer在最终的对数概率训练目标上比标准transformer提高1.5%,并且训练所需的时间相当。另一方面,可以将一个MoD transformer训练达到与isoFLOP最佳的传统transformer相同的训练损失,但每次前向传递使用的FLOPs少得多(高达50%)。这些结果表明,MoD transformer学会了智能路由策略(即跳过不必要的计算),因为它们可以在每次前向传递的FLOPs足迹较小的情况下,实现相同或更好的序列对数概率。
/posts/

1.1 MoD实现

1.1.1 定义计算预算

  • 通过限制序列中可以参与块计算(例如自注意力和 MLP)的token数量来定义计算预算,该预算将小于等效的标准transformer。
  • 为了定义计算预算,还必须理解容量的概念,它定义了构成给定计算输入的token总数,并且还确定使用条件计算的转换器的总 FLOP 数,而不是任何结果路由决策。
  • 并不是所有token都同样重要,某些token可能不需要进行自注意力和 MLP的计算,可以通过学习来识别这些token,因此与标准transformer相比,可以通过降低计算容量来定义每个前向传递更小的计算预算。

例如,在每个标准transformer模块中,自注意力和MLP的处理容量为,即序列与批处理中所有token的总数。相比之下,MoE transformers对每个专家的多层感知机(MLP)设定的容量小于 ,这样做是为了更平均地分配各个专家的计算总量。但是,因为它们在每个模块中使用了多个专家,所以它们的总容量大致相当于一个标准transformer。

1.1.2 对transformer模块进行路由

  • token的路由可以通过自注意力和 MLP 块或残差连接这两条路径之一来完成,后者的计算成本较低,从而导致块输出完全由其输入值决定。
  • 路由存在两个极端,一个极端是将每个token路由到每个块,就像在标准transformer中一样,而另一个极端是将所有token路由到每个块周围,这提供了一个更快的模型,但下游性能较差。更理想的方法假设为介于这两个极端之间,以获得与标准transformer相比具有更好性能和更快速度的最佳模型。

1.1.3 路由机制

与dropout相比,学习路由更可取,而dropout相对于路由token来说表现明显较差。作为学习路由的一部分,网络可以了解哪些token比其他token需要更多或更少的处理。模型使用 路由机制选择要处理的token。该模型使用学习路由器为每个token生成权重。选择权重最高的前 k 个token进行完整处理,而其余token则绕过自注意力和MLP块进行路由。

  • 基于token的选择:路由器在计算路径上生成每个token的概率分布,然后将token汇集到其选择的计算路径(可以是概率最高的路径)。该方案可能存在负载平衡问题,因为无法保证token在可能的路径之间适当划分。

  • 基于专家的选择:每个路径不是根据token的偏好选择个token,而不是选择其首选路径。它确保了完美的负载平衡,因为个token保证被传送到每条路径。此外,由于操作取决于路由器权重的大小,因此该路由方案允许相对路由权重来帮助确定哪些token最需要块的计算。路由器还可以通过适当设置权重来确保最关键的token位于中。因此,对于 MoD 方法来说,由于上述相对于token选择路由的优点,该方案是更优选的方法。

1.1.4 路由实现

每个 token 都由路由器处理以产生标量权重,然后使用前 个权重来选择将通过转换器块路由的 token 身份,该块包括自注意力和后续的 MLP。

假设在给定层 的长度为 的序列中拥有一组 token 嵌入;即 。给定 token 嵌入的路由器权重是线性投影产生的标量,

目标是使用这些路由器权重来确定块对每个 token 的计算的输出。假设 是路由器权重集 的第 个百分位数,其中 是每个批处理元素的用户定义容量(一个整数 ,定义给定函数将处理的序列中的标记数)。给定标记的块输出为:

这里, 是路由器值 (即“前 k 个”token)的token集, 包括自注意力和后续的 MLP。请注意,由于自注意力操作,给定token 的输出可能取决于其他token 的基数是 (或 ):用户定义的容量。因此,混合深度转换器相对于基线节省了计算资源,因为块计算 的输入包含的token比平常少(),从而使自注意力和 MLP 更便宜。

将函数 的输出乘以路由器权重。这会将路由器权重放在“梯度路径”上,从而使它们在语言建模任务的过程中受到梯度下降的力量(论文作者尝试了不同的版本,其中路由器权重也包含在绕过块计算的标记的计算路径上,但似乎足够了 - 并且在实现上更简单 - 仅在计算路径上包含那些不绕过块计算的标记的路由器权重)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

class MoD(nn.Module):
"""
Paper: https://arxiv.org/abs/2404.02258
"""

def __init__(self, cfg: Config) -> None:
super().__init__()
self.seq_len = cfg.seq_len
self.capacity_factor = cfg.capacity_factor
self.dim = cfg.d_model

self.transformer_decoder_block = Block(cfg)
self.router = nn.Linear(self.dim, 1, bias=False)
self.aux_router = nn.Sequential(
nn.Linear(self.dim,self.dim//2),
nn.SiLU(),
nn.Linear(self.dim//2,1),
)

def forward(
self, x: Tensor, mask, freqs_cis, mode="train", auxiliary_loss=False, *args, **kwargs
):
batch_size, seq_len, dim = x.shape

if mode == "inference":
return self.inference(x, *args, **kwargs)
# S = seq_len, C = capacity , C = int(seq_length * capacity_factor)
# page 6 above eq 1 | ( C<S ) | here top_k = beta
top_k = int(seq_len * self.capacity_factor)

# eq1 page 6
# scaler weights for each token
router_logits = self.router(x) # (x) batch,seq_len,dim -> r batch,seq_len,1

# 𝑟𝑙> 𝑃𝛽 (R) ... eqution 1
token_weights, token_index = torch.topk(router_logits, top_k, dim=1, sorted=False)

# now we have idx, we can copy this weights to another tensor and pass them to attn+mlp

# since its auto regressive model we need to keep casual nature of it
# that why we need sort the tokens by idx before we pass it to attn
selected_tokens, index = torch.sort(token_index, dim=1)

# select idx for copying for original tensor
indices_expanded = selected_tokens.expand(-1, -1, dim)

# This are fillted topk tokens with capactiy C
filtered_x = torch.gather(input=x, dim=1, index=indices_expanded) # -> batch, capacity, dim

x_out, _ = self.transformer_decoder_block(filtered_x, mask, freqs_cis)

# softmax router weights, aaah
token_weights = F.softmax(token_weights, dim=1)

# selecting router wight by idx ( in sorted maner)
r_weights = torch.gather(token_weights, dim=1, index=index)

# muliply by router weights, this add router in gradient stream
xw_out = r_weights * x_out

# batch_indices = torch.arange(batch_size).unsqueeze(-1).expand(-1, top_k)
# # # https://discuss.pytorch.org/t/when-inplace-operation-are-allowed-and-when-not/169583/2
# out = x.clone()
# # add back to resuidal strean
# out[batch_indices, selected_tokens.squeeze(-1),: ] += xw_out
# # ^ this can be done with torch.scatter_add
out = torch.scatter_add(input=x, dim=1, index=indices_expanded, src=xw_out)

if auxiliary_loss:
aux_loss = self.aux_loss( , router_logits, selected_tokens)
return out, aux_loss
return out, _

def aux_loss(self, x: Tensor, router_logits: Tensor, selected_tokens: Tensor):
batch_size, seq_len, dim = x.shape
# Page 7, Section 3.5 sampling
router_targets = torch.zeros_like(router_logits).view(
-1
) # i think torch to scatter will work here TODO
router_targets[selected_tokens.view(-1)] = 1.0
aux_router_logits = self.aux_router(x.detach().view(batch_size * seq_len, -1))
# aux_router_logits = F.sigmoid(aux_router_logits) # keep output in range [0,1)
# RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
# so binary_cross_entropy_with_logits == sigmoid + bce_loss
return F.binary_cross_entropy_with_logits(aux_router_logits.view(-1), router_targets)

1.1.5 成果

虽然专家选择路由具有许多优点,但它有一个明显的问题:top-k 操作是非因果的。这意味着给定 token 的路由权重是否属于序列的 top-k 取决于其后面的 token 的路由权重值,而在自回归采样时无法访问这些值。为了解决推理期间 top-k 路由的非因果性质,训练辅助预测器网络来预测token是否会出现在 top-k 中,从而实现高效的自回归采样。MoD 模型使用标准语言建模目标进行训练,并添加了用于自回归采样的辅助预测器损失

MoD hyperparameter tuning

最佳的MoD transformer在达到更低的损失值的同时,其参数数量也更多。存在一些参数规模较小的MoD模型,在其超参数设定下虽然不是isoFLOP最优的,但它们的性能与最优基准模型相当或更好,同时训练速度更快。每隔一个块进行路由对于实现强性能至关重要,将容量减少到总序列的12.5%,即有87.5%的token绕过块时,可以带来渐进的性能提升,但减少到12.5%这个比例以下性能开始下降。

isoFLOP分析

12.5% 容量的 MoD 变体用于对 6e18、2e19 和 1e20 FLOP 执行 isoFLOP 分析,训练模型大小从 60M 到 3B 参数不等。下图表示每次前向传递的相对 FLOP(标准化为 isoFLOP 最佳基线)。存在着既能与等效浮点操作最优基线保持一致的性能又能提高训练速度的MoD transformer,这既因为它们每参数使用的浮点操作数更少,也因为它们使用的参数总数更少。只需调整MoD配置(即容量和路由频率)下的模型规模,使其每次前向传递的浮点运算数与等效浮点操作最优基准相等,就可以获得该配置下表现最佳的MoD变体。

路由分析

下图将表示路由决策,并观察序列末尾的深蓝色垂直带。在下面的直方图中,可以观察到路由器权重的分布如辅助损失所示:大约 12.5% 的权重高于 0.5% 和 87.5%。采取了大胆绕过部分块的路由策略,transformer模型相对于基准模型仍能取得性能上的提升。

自回归性能评估

在自回归采样期间从训练中的非因果 top- 路由方案切换到基于因果预测器的方法,性能略有下降。和训练场景相似,有些MoD变种在性能上超越了等效浮点操作最优的基准模型,同时它们每次前向传递所需的浮点计算量更少。这些结果表明,MoD transformers提供的计算效率提升不仅限于训练场景。

1.1.6 Mixture-of-Depths-and-Experts (MoDE)

Mixture of Depths(MoD)技术与Mixture of Experts(MoE)技术的融合。

Staged MoDE:在自注意力步骤之前将token绕向区块或向区块路由。
Integrated MoDE:通过在传统MLP专家中集成“无操作”专家来实现MoD路由

在这两种变体中,Staged MoDE允许token跳过自注意力步骤,而Integrated MoDE简化了路由机制。Integrated MoDE 是更可取的,因为token明确地学习选择围绕专家的剩余路径,而不是首选专家但在实现容量减少时被丢弃。

参考

  1. Mixture-of-Depths: Dynamically allocating compute in transformer-based language models
  2. 《Mixture-of-Depths: Dynamically allocating compute in transformer-based language models》精华摘译
  3. Mixture of Depth is Vibe
  4. Mixture-of-Depths: A new approach to efficiently allocate compute in Transformer Language Models
  5. Mixture of Depths: Dynamic Compute Allocation for Language Models
  6. Mixture-of-Depths: Dynamically allocating compute in transformer-based language models

LLM(九)——Mixture-of-Depths Transformers
https://mztchaoqun.com.cn/posts/D47_MoD/
作者
mztchaoqun
发布于
2024年11月15日
许可协议