【2025.1 Datawhale AI+X 共学活动】Fun-Transformer —— Task3:Encoder

【2025.1 Datawhale AI+X 共学活动】Fun-Transformer —— Task3:Encoder

Simuoss
2025-01-21 / 0 评论 / 6 阅读 / 正在检测是否收录...

Datewhale组队学习

1. 编码器(Encoder)

1.1 Encoder 工作流程

  1. 输入阶段

    • 初始输入

      • 输入数据(如文本中的单词)首先通过嵌入层(Input Embedding)转换为向量表示。
      • 为了捕捉序列中元素的位置信息,位置编码(Position Embedding)被添加到嵌入向量中。位置编码可以是固定的(如正弦余弦函数)或可学习的。
      • 第一个Encoder子模块接收嵌入和位置编码的组合输入。
    • 后续 Encoder 输入

      • 除了第一个Encoder子模块外,其他Encoder子模块从前一个Encoder接收输入,形成链式结构。
  2. 核心处理阶段

    • 多头自注意力层(Multi-Head Self-Attention)

      • 输入序列通过线性变换生成查询(Q)、键(K)、值(V)向量。
      • 通过缩放点积注意力(Scaled Dot-Product Attention)计算序列中不同位置之间的关联关系。
      • 注意力分数通过softmax归一化,加权求和后得到自注意力输出。
    • 前馈层(Feedforward Layer)

      • 自注意力层的输出经过一个全连接网络(通常包含两层线性变换和ReLU激活函数),提取更高层次的特征。
      • 前馈层的输出传递给下一个Encoder子模块(如果不是最后一个Encoder)。
  3. 残差与归一化阶段

    • 残差连接(Residual Connection)

      • 自注意力层和前馈层的输入与输出通过残差连接相加(如输入为 $x$,输出为 $y$,则结果为 $x + y$)。
      • 残差连接缓解了深层网络中的梯度消失问题,使模型更容易训练。
    • 层归一化(Layer Normalization)

      • 在残差连接后,对输出进行层归一化,加速模型收敛并提高泛化能力。
      • 归一化操作对每一层的神经元输入进行标准化处理,使其均值为0,方差为1。

1.2 Encoder 组成成分

每个Encoder子模块包含以下部分:

  • 多头自注意力层(Multi-Head Self-Attention):捕捉序列内部关系。
  • 残差连接(Residual Connection):缓解梯度消失问题。
  • 层归一化(Layer Normalization):加速训练并提高稳定性。
  • 前馈全连接网络(Position-wise Feed-Forward Networks):对特征进行非线性变换。

2. 多头自注意力(Multi-Head Self-Attention)

2.1 缩放点积注意力(Scaled Dot-Product Attention)

  1. 输入处理

    • 输入序列通过线性变换生成查询(Q)、键(K)、值(V)向量:

      $$ Q = W_Q \cdot x, \quad K = W_K \cdot x, \quad V = W_V \cdot x $$

    • 其中,$W_Q$、$W_K$、$W_V$ 是可学习的权重矩阵。
  2. 注意力分数计算

    • 通过点积计算查询和键之间的相似度:

      $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

    • 缩放因子 $\sqrt{d_k}$ 防止点积值过大,避免softmax进入饱和区。
  3. Mask机制

    • 在训练时,mask用于遮蔽未来信息(如解码器中);在推理时,mask用于控制输出序列的生成。

2.2 多头注意力机制(Multi-Head Attention)

  1. 多头计算

    • 输入序列被映射到多个子空间(头),每个头独立计算注意力:

      $$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$

    • 所有头的输出拼接后通过线性变换得到最终输出:

      $$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W_O $$

  2. 优点

    • 并行化:多个头可以并行计算,提升效率。
    • 多样性:每个头关注输入序列的不同方面,增强模型表示能力。
    • 泛化性:捕捉局部和全局依赖关系,提高模型泛化能力。

2.3 自注意力机制(Self-Attention)

  1. 工作原理

    • 查询(Q)、键(K)、值(V)均来自同一输入序列。
    • 通过自注意力机制,模型能够捕捉序列内部元素之间的关系(如句子中单词的语义关联)。
  2. 优点

    • 参数效率:参数数量较少,计算复杂度为 $O(n^2d)$。
    • 全局信息:一步计算即可捕捉序列的全局依赖关系。
    • 并行化:不依赖序列顺序,适合并行计算。
  3. 缺点

    • 计算量大:序列较长时,计算复杂度较高。
    • 位置信息缺失:自注意力机制本身无法捕捉序列顺序,需通过位置编码补充。

3. 交叉注意力(Cross Attention)

3.1 简述

  • 概念

    • 交叉注意力允许一个序列(查询序列)关注另一个序列(键-值序列),建立两者之间的联系。
  • 序列维度要求

    • 查询序列和键-值序列的维度必须相同。
  • 应用场景

    • 多模态任务(如文本与图像对齐)、序列到序列模型(如机器翻译中的编码器-解码器交互)。

3.2 操作

  1. 查询序列(Q):定义输出的序列长度。
  2. 键-值序列(K, V):提供输入信息,用于计算与查询序列的相似度。

4. Cross Attention 和 Self Attention 的区别

  1. Self Attention

    • 查询、键、值来自同一序列。
    • 用于捕捉序列内部关系(如句子中单词的语义关联)。
    • 适用于序列内部依赖建模。
  2. Cross Attention

    • 查询来自一个序列,键和值来自另一个序列。
    • 用于跨序列信息交互(如编码器-解码器之间的信息传递)。
    • 适用于多模态任务和序列到序列模型。

5. 前馈全连接网络(Position-wise Feed-Forward Networks)

  1. 结构

    • 包含两个全连接层和一个ReLU激活函数:

      $$ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 $$

    • 第一层将输入维度从 $d_{model}$ 扩展到 $d_{ff}$(如2048),第二层将其还原为 $d_{model}$。
  2. 作用

    • 增强模型的非线性表达能力。
    • 对自注意力层输出的特征进行进一步变换。

6. Add & Norm

  1. 残差连接(Add)

    • 输入与子层输出相加,缓解梯度消失问题。
  2. 层归一化(Norm)

    • 对残差连接的输出进行归一化,加速训练并提高稳定性。

总结

  • Encoder:通过多头自注意力和前馈层处理输入序列,残差连接和层归一化提高训练稳定性。
  • 多头自注意力:通过多个注意力头并行计算,增强模型表示能力。
  • 自注意力:捕捉序列内部关系,适用于序列内部依赖建模。
  • 交叉注意力:用于不同序列之间的信息交互,适用于多模态任务。
  • 前馈网络:增强模型的非线性表达能力。
  • Add & Norm:提高训练稳定性和收敛速度。

欢迎各位大佬在评论区交流&批评指正!

Datewhale组队学习


参考链接

  1. Datawhale AI+X 共学活动 Fun-Transformer —— Task3:Encoder
  2. Attention Is All You Need
0

评论 (0)

取消