Transformer
RNN 的主要问题:
- 梯度消失/爆炸:长距离依赖难以学习
- 顺序计算:无法并行处理序列
- 信息瓶颈:最后时刻隐藏状态需承载全部信息
Transformer 的改进:
- 并行计算:同时处理整个序列
- 自注意力机制:直接建立任意位置间的联系
- 位置编码:显式注入位置信息
结构
图:Transformer 单元
图:Transformer 的详细结构
- 输入
- 编码器输入
- 解码器输入
- 输出
- 线性层
- Softmax 层
- 编码器
- 由 N 个编码器层堆叠而成
- 每个编码器层由两个子层连接结构组成
- 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
- 第二个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接
- 解码器
- 由 N 个解码器层堆叠而成
- 每个解码器层由三个子层连接结构组成
- 第一个子层连接结构包括一个带掩码的-多头自注意力子层和规范化层以及一个残差连接
- 第二个子层连接结构包括一个多头注意力子层(编码器到解码器)和规范化层以及一个残差连接
- 第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接
核心组件
自注意力机制(Self-Attention)
在传统的神经网络处理序列时,模型只能一步步按顺序处理,难以捕捉长距离依赖关系。自注意力机制就是为了让序列中的每个元素都能直接与序列中所有其他元素进行交互,无论它们直接的距离多远。
定义:
| 符号 | 维度 | 含义 |
|---|---|---|
| 输入矩阵(n=序列长度,d=特征维度) | ||
| Query 矩阵(查询向量) | ||
| Key 矩阵(键向量) | ||
| Value 矩阵(值向量) | ||
| 可学习参数矩阵 |
TIP
表示当前需要关注的信息或问题,用于确定输入序列中哪些部分与当前任务相关 用于匹配查询,通过计算相似度判断输入序列中哪些元素与查询匹配 存储实际信息
推导过程
将输入转换为 Query、Key、Value:
计算注意力分数:
生成注意力权重矩阵:
得到最终注意力输出:
带掩码自注意力层(Masked Multi-head attention)
编码时,对于
解码时,对于
多头注意力(Multi-Head Attention)
其中:
位置编码(Positional Encoding)
与 RNN 和 LSTM 等顺序算法不同,Transformer 没有内置机制来捕获句子中单词的相对位置,所以在 Transformer 的 encoder 和 decoder 的输入层中,使用了 Positional Encoding,使得最终的输入满足:
原始正弦编码公式:
前馈网络(Feed Forward Network)
包括两个线性变换+ReLU 激活:
计算复杂度
当输入批次大小为
Self-Attention 层
- 计算
、 、
输入输出
计算量为:
TIP
矩阵加法运算考虑偏差 bias计算量就是
- 计算
输入输出
- Softmax 与加权求和
Softmax 计算量较小,通常忽略。
输入输出
- 输出投影
线性变换将结果映射回
输入输出
MLP 层
- 线性层(扩展层)
输入输出
- 线性层(压缩层)
输入输出
logits
Logits 层是将最终的 Transformer 隐藏层输出(维度
输入输出
总的计算复杂度
空间复杂度
大模型在训练过程中通常采用混合精度训练,中间激活值一般是 float16 或者 bfloat16 数据类型的。在分析中间激活的显存占用时,假设中间激活值是以 float16 或 bfloat16 数据格式来保存的,每个元素占了 2 个 bytes,dropout 操作的 mask 矩阵,每个元素只占 1 个 bytes。需要保存的中间激活占用显存大小计算如下:
Self-Attention 层
、 、 共享一个输入 ,则显存占用为 - 对于
,两个张量形状都是 ,显存占用为 - 对于
,函数输入 形状为 ,显存占用为 - 计算完
,会进行 dropout,需要保存一个 mask 矩阵,其形状与 相同,显存占用为 - 计算
,二者占用显存大小为 - 计算输出映射和一个 dropout 操作,二者占用显存大小为
综上,Self-Attention 层的显存占用为
MLP 层
- 第一个线性层的输入占用显存
- 激活函数的输入占用显存
- 第二个线性层的输入占用显存
- 最后的 dropout 操作需要保存的 mask 矩阵占用显存
综上,MLP 层的显存占用为
LN
Self-Attention 层和 MLP 层分别对应了一个 LN,其输入占用显存为
总的空间复杂度
问题
Transformer 的计算复杂度为:
可以注意到,