Datewhale组队学习
1. 编码器(Encoder)
1.1 Encoder 工作流程
输入阶段
初始输入:
- 输入数据(如文本中的单词)首先通过嵌入层(Input Embedding)转换为向量表示。
- 为了捕捉序列中元素的位置信息,位置编码(Position Embedding)被添加到嵌入向量中。位置编码可以是固定的(如正弦余弦函数)或可学习的。
- 第一个Encoder子模块接收嵌入和位置编码的组合输入。
后续 Encoder 输入:
- 除了第一个Encoder子模块外,其他Encoder子模块从前一个Encoder接收输入,形成链式结构。
核心处理阶段
多头自注意力层(Multi-Head Self-Attention):
- 输入序列通过线性变换生成查询(Q)、键(K)、值(V)向量。
- 通过缩放点积注意力(Scaled Dot-Product Attention)计算序列中不同位置之间的关联关系。
- 注意力分数通过softmax归一化,加权求和后得到自注意力输出。
前馈层(Feedforward Layer):
- 自注意力层的输出经过一个全连接网络(通常包含两层线性变换和ReLU激活函数),提取更高层次的特征。
- 前馈层的输出传递给下一个Encoder子模块(如果不是最后一个Encoder)。
残差与归一化阶段
残差连接(Residual Connection):
- 自注意力层和前馈层的输入与输出通过残差连接相加(如输入为 $x$,输出为 $y$,则结果为 $x + y$)。
- 残差连接缓解了深层网络中的梯度消失问题,使模型更容易训练。
层归一化(Layer Normalization):
- 在残差连接后,对输出进行层归一化,加速模型收敛并提高泛化能力。
- 归一化操作对每一层的神经元输入进行标准化处理,使其均值为0,方差为1。
1.2 Encoder 组成成分
每个Encoder子模块包含以下部分:
- 多头自注意力层(Multi-Head Self-Attention):捕捉序列内部关系。
- 残差连接(Residual Connection):缓解梯度消失问题。
- 层归一化(Layer Normalization):加速训练并提高稳定性。
- 前馈全连接网络(Position-wise Feed-Forward Networks):对特征进行非线性变换。
2. 多头自注意力(Multi-Head Self-Attention)
2.1 缩放点积注意力(Scaled Dot-Product Attention)
输入处理:
输入序列通过线性变换生成查询(Q)、键(K)、值(V)向量:
$$ Q = W_Q \cdot x, \quad K = W_K \cdot x, \quad V = W_V \cdot x $$
- 其中,$W_Q$、$W_K$、$W_V$ 是可学习的权重矩阵。
注意力分数计算:
通过点积计算查询和键之间的相似度:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
- 缩放因子 $\sqrt{d_k}$ 防止点积值过大,避免softmax进入饱和区。
Mask机制:
- 在训练时,mask用于遮蔽未来信息(如解码器中);在推理时,mask用于控制输出序列的生成。
2.2 多头注意力机制(Multi-Head Attention)
多头计算:
输入序列被映射到多个子空间(头),每个头独立计算注意力:
$$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$
所有头的输出拼接后通过线性变换得到最终输出:
$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O $$
优点:
- 并行化:多个头可以并行计算,提升效率。
- 多样性:每个头关注输入序列的不同方面,增强模型表示能力。
- 泛化性:捕捉局部和全局依赖关系,提高模型泛化能力。
2.3 自注意力机制(Self-Attention)
工作原理:
- 查询(Q)、键(K)、值(V)均来自同一输入序列。
- 通过自注意力机制,模型能够捕捉序列内部元素之间的关系(如句子中单词的语义关联)。
优点:
- 参数效率:参数数量较少,计算复杂度为 $O(n^2d)$。
- 全局信息:一步计算即可捕捉序列的全局依赖关系。
- 并行化:不依赖序列顺序,适合并行计算。
缺点:
- 计算量大:序列较长时,计算复杂度较高。
- 位置信息缺失:自注意力机制本身无法捕捉序列顺序,需通过位置编码补充。
3. 交叉注意力(Cross Attention)
3.1 简述
概念:
- 交叉注意力允许一个序列(查询序列)关注另一个序列(键-值序列),建立两者之间的联系。
序列维度要求:
- 查询序列和键-值序列的维度必须相同。
应用场景:
- 多模态任务(如文本与图像对齐)、序列到序列模型(如机器翻译中的编码器-解码器交互)。
3.2 操作
- 查询序列(Q):定义输出的序列长度。
- 键-值序列(K, V):提供输入信息,用于计算与查询序列的相似度。
4. Cross Attention 和 Self Attention 的区别
Self Attention:
- 查询、键、值来自同一序列。
- 用于捕捉序列内部关系(如句子中单词的语义关联)。
- 适用于序列内部依赖建模。
Cross Attention:
- 查询来自一个序列,键和值来自另一个序列。
- 用于跨序列信息交互(如编码器-解码器之间的信息传递)。
- 适用于多模态任务和序列到序列模型。
5. 前馈全连接网络(Position-wise Feed-Forward Networks)
结构:
包含两个全连接层和一个ReLU激活函数:
$$ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 $$
- 第一层将输入维度从 $d_{model}$ 扩展到 $d_{ff}$(如2048),第二层将其还原为 $d_{model}$。
作用:
- 增强模型的非线性表达能力。
- 对自注意力层输出的特征进行进一步变换。
6. Add & Norm
残差连接(Add):
- 输入与子层输出相加,缓解梯度消失问题。
层归一化(Norm):
- 对残差连接的输出进行归一化,加速训练并提高稳定性。
总结
- Encoder:通过多头自注意力和前馈层处理输入序列,残差连接和层归一化提高训练稳定性。
- 多头自注意力:通过多个注意力头并行计算,增强模型表示能力。
- 自注意力:捕捉序列内部关系,适用于序列内部依赖建模。
- 交叉注意力:用于不同序列之间的信息交互,适用于多模态任务。
- 前馈网络:增强模型的非线性表达能力。
- Add & Norm:提高训练稳定性和收敛速度。
欢迎各位大佬在评论区交流&批评指正!
Datewhale组队学习
参考链接
评论 (0)