深度学习新突破:哈希层与阶梯注意力模型,分开考量参数与计算量提升性能
深度学习模型新突破参数与计算量分开考量两种新方法助力性能提升在讨论深度学习模型能力时通常只关注由参数数量衡量的模型大小这一指标而运行模型所需的计算量常被忽视因其与模型大小相关从业者常将二者视为一体。多数情况下每个参数每次输入仅参与一次计算如 100 万个参数的模型处理一个输入约需 100 万次浮点运算这适用于前馈、循环及 Transformer 模型。我们宣布发表两种新方法助力进一步研究此重要问题表明应将模型计算量与大小分开考虑。其一可在不增加计算量的情况下增大模型大小以提升性能[第一篇论文](https://arxiv.org/abs/2106.04426)提出通过引入哈希层实现其二[第二篇论文](https://arxiv.org/abs/2106.04279)表明能在不增加新参数的情况下增加模型计算量来显著提升性能其提出新的阶梯注意力模型达成此目标。综合来看这些结果为深度学习模型带来新思考方式要求区分参数和计算量概念如此能设计出更适配可用资源的强大模型。哈希层近年来为在语言任务取得出色成果Transformer 模型规模不断增大参数数量扩展到数十亿甚至数万亿。但更大模型需更多计算量实用性降低。让大型模型减少计算量的方法之一是采用[稀疏专家混合](https://arxiv.org/abs/1701.06538)MoE方法每个专家有自己参数仅用于输入的一小部分输入仅被路由到部分专家减少计算量。[近期研究](https://arxiv.org/abs/2006.16668)表明可借此高效增大 Transformer 模型规模。MoE 关键是路由机制[我们的论文](https://arxiv.org/abs/2106.04426)提出基于输入令牌哈希的路由机制与以往研究不同哈希 MoE 更简单无需学习过程或改变目标函数字典中每个单词被分配给固定专家可随机选择或按分布均衡分配。虽方法简单但在一些具有挑战性的语言和对话任务中表现出色。在 pushshift.io Reddit 语言建模任务中我们的哈希机制优于基于学习的 [Switch](https://arxiv.org/abs/2101.03961) 基线模型尤其在专家数量较多时。最大的 12.8 亿参数模型特定输入仅使用 17%参数。我们进一步在更大数据集上训练 45 亿参数模型发现哈希机制表现优于有竞争力的稀疏 MoE 模型 [BASE](https://arxiv.org/abs/2103.16716)。专家分配的自然均衡性使训练过程在集群上更高效且可扩展实验中与 BASE 相比每秒更新次数提高约 11%且随专家层数量增加差异更明显。阶梯注意力为提升性能给 Transformer 添加更多参数是热门研究话题但增加其计算量的研究较少原因是标准 Transformer 架构使计算量和参数紧密关联增加计算量困难。[我们的论文](https://arxiv.org/abs/2106.04279)引入替代架构分离计算量和参数概念表明增加计算量是提升性能的途径。具体提出对 Transformer 循环应用的阶梯Staircase和梯子Ladder模型。梯子模型简单多次堆叠相同的 Transformer使参数多次参与计算在保持模型大小不变时增加计算量此简单修改在语言建模和对话等实际任务中带来显著性能提升表明增加计算量是有吸引力的研究方向。阶梯模型和梯子模型一样堆叠 Transformer但将每个 Transformer 向前移动多个时间步只要有输入就能继续堆叠形成类似阶梯模型。与 Transformer 不同其连续性使阶梯模型在时间上有循环性对维持内部状态跟踪变化至关重要。在一些简单构造任务中前馈模型会遇到困难阶梯模型可轻松解决。且因每个参数计算量增加阶梯模型在语言建模任务中也能像梯子模型一样提升性能。为何不同时采用引入这两种方法后自然会问能否结合答案是肯定的。这两种方法带来的改进似乎相互独立哈希层 梯子模型的组合相比单独使用任何一种方法都有显著性能提升。综合来看这两种方法能对参数大小和计算量精细控制带来改进。总之我们的研究探讨了计算量与参数大小问题表明思考新方法时应区别对待二者而非像许多标准机器学习模型那样捆绑。具体提出两种新架构分别探索增加参数大小或计算量的权衡并展示如何结合其思想。我们相信这种思考方式及新方法应用将为机器学习研究开辟富有成效的道路。若想了解更多细节请阅读[哈希层](https://arxiv.org/abs/2106.04426)和[阶梯注意力](https://arxiv.org/abs/2106.04279)的论文。代码可在[此处](https://github.com/facebookresearch/ParlAI/tree/main/projects/params_vs_compute/hash_ladder)获取。