Transformer 中 Mask、Attention 與前向/前饋 易混淆點(diǎn)整理


1?? create_padding_mask 的實(shí)現(xiàn)與作用

Q: create_padding_mask 在項(xiàng)目中實(shí)現(xiàn)在哪里?它的輸入是什么?有什么作用?
A: 實(shí)現(xiàn)在 transformer/create_mask.py(或文檔 docs/code?dasm.md 中的示例)。

def create_padding_mask(src, tgt, pad_idx=0):
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)   # (B,1,1,src_len)
    tgt_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(3)   # (B,1,tgt_len,1)

    tgt_len = tgt.size(1)
    look_ahead_mask = torch.ones(tgt_len, tgt_len).tril().bool() \
                      .unsqueeze(0).unsqueeze(0)          # (1,1,tgt_len,tgt_len)

    tgt_mask = tgt_mask & look_ahead_mask.to(tgt.device)
    return src_mask, tgt_mask
  • 輸入src(源序列張量,形狀 (batch, src_len)),tgt(目標(biāo)序列張量,形狀 (batch, tgt_len)),以及可選的 pad_idx(默認(rèn) 0)。
  • 輸出src_mask(僅 padding mask,形狀 (batch,1,1,src_len))和 tgt_mask(padding?+?look?ahead mask,形狀 (batch,1,tgt_len,tgt_len))。
  • 作用:在注意力計(jì)算前屏蔽填充位置(防止無意義的 PAD 影響注意力),并在解碼器自注意力中加入因果遮罩,防止模型看到未來 token。

2?? src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2) 的意義

Q: 這行代碼具體做了什么?
A:

  1. (src != pad_idx) 生成布爾張量,True 表示該位置不是 PAD,False 表示是 PAD。
  2. .unsqueeze(1) 在第 1 維插入大小為 1 的維度,使形狀變?yōu)?(batch,1,src_len)。
  3. .unsqueeze(2) 再插入第 2 維,得到 (batch,1,1,src_len)
    這樣得到的張量可以在注意力得分矩陣 (batch, n_heads, seq_len, seq_len) 上廣播,masked_fill(mask == 0, -1e9) 將 PAD 位置的分?jǐn)?shù)設(shè)為極小值,確保 softmax 后權(quán)重幾乎為 0。

3?? Masked Multi?Head Attention 與 Encoder 的 Padding Mask

Q: 論文圖中 Encoder 的 Multi?Head Attention 沒有 “masked”,但仍需要 padding mask,這兩者是否沖突?
A: 不沖突。

  • Encoder 只需要 padding mask 來屏蔽填充位置。
  • Decoder 自注意力 需要 masked Multi?Head Attention,即 padding mask + look?ahead mask,因?yàn)樗仨毞乐箍吹轿磥?token。
    論文圖中只在 Decoder 的自注意力旁標(biāo)記 “masked”,是因?yàn)橹挥心抢镄枰蚬谡郑籈ncoder 的 padding mask 被視為實(shí)現(xiàn)細(xì)節(jié),未在圖中顯式標(biāo)出。

4?? Masked Language Modeling (MLM) 與 Masked Multi?Head Attention 的關(guān)系

Q: MLM 與 Masked Multi?Head Attention 有何聯(lián)系與區(qū)別?
A:

維度 Masked Language Modeling Masked Multi?Head Attention
層級(jí) 訓(xùn)練任務(wù)(數(shù)據(jù)預(yù)處理 & 損失) 模型層(注意力計(jì)算)
遮罩對(duì)象 輸入 token 被替換為 [MASK](真實(shí)詞被隱藏) 注意力分?jǐn)?shù)矩陣的 位置 被屏蔽(padding?+?future)
目的 學(xué)習(xí)雙向上下文表示,預(yù)測(cè)被遮蓋的詞 保證解碼器自回歸(因果)并屏蔽填充
實(shí)現(xiàn) 在數(shù)據(jù) pipeline 中隨機(jī) mask token,計(jì)算交叉熵僅在 mask 位置 create_padding_mask 中生成布爾 mask,scaled_dot_product_attentionmasked_fill 使用

兩者都使用 “mask” 這個(gè)概念,但 MLM任務(wù)層面的 目標(biāo),MMHA模型層面的 機(jī)制。


5?? create_padding_mask 與 MLM 的關(guān)聯(lián)

Q: create_padding_mask 與 MLM 有直接關(guān)聯(lián)嗎?
A: 沒有直接關(guān)聯(lián)。create_padding_mask 只處理 padding(填充)位置的遮罩,用于注意力層;MLM 處理 隨機(jī)遮蓋 token,用于預(yù)訓(xùn)練目標(biāo)。它們?cè)谕荒P椭锌梢怨泊妫ɡ?BERT 使用 create_padding_mask 防止注意力關(guān)注 PAD,同時(shí)在預(yù)訓(xùn)練階段使用 MLM),但實(shí)現(xiàn)上是獨(dú)立的。


6?? Transformer 前向傳播的完整步驟(對(duì)應(yīng)論文 Figure?2)

  1. Embedding + Positional Encoding(源 & 目標(biāo))
  2. Mask 生成:調(diào)用 create_padding_mask(src, tgt)src_mask, tgt_mask。
  3. Encoder 堆疊(N 層)
    • Multi?Head Self?Attention(使用 src_mask
    • Add & LayerNorm
    • Feed?Forward Network (FFN)
    • Add & LayerNorm
  4. Decoder 堆疊(N 層)
    • Masked Multi?Head Self?Attention(使用 tgt_mask
    • Add & LayerNorm
    • Encoder?Decoder Attention(使用 src_mask
    • Add & LayerNorm
    • FFN
    • Add & LayerNorm
  5. 線性投影 + Softmax → 輸出 logits(詞表概率)。

7?? 前向傳播 vs. 前饋網(wǎng)絡(luò)(FFN)

  • 前向傳播:模型整體的計(jì)算流程,從輸入到輸出 logits,包含嵌入、注意力、FFN、殘差、歸一化等所有子模塊。
  • 前饋網(wǎng)絡(luò):在每個(gè) Encoder/Decoder 層內(nèi)部的子模塊,僅對(duì)每個(gè)位置的向量進(jìn)行兩層全連接變換(Linear → ReLU → Linear),隨后加殘差和 LayerNorm。它是前向傳播中的一個(gè)組成部分,負(fù)責(zé)提供位置級(jí)的非線性表達(dá)能力。

8?? 前向傳播是否包含反向傳播?

Q: 前向傳播的過程中會(huì)包括反向傳播嗎?
A: 不會(huì)。前向傳播只負(fù)責(zé)計(jì)算模型輸出并緩存中間激活;反向傳播在計(jì)算完損失后通過 loss.backward() 觸發(fā),用于梯度計(jì)算和參數(shù)更新。訓(xùn)練循環(huán)典型順序?yàn)椋?/p>

optimizer.zero_grad()
output = model(src, tgt)          # 前向傳播
loss = criterion(output, target)  # 計(jì)算損失
loss.backward()                  # 反向傳播
optimizer.step()                 # 參數(shù)更新

推理時(shí)只執(zhí)行前向傳播,不會(huì)進(jìn)行反向傳播。


9?? 代碼位置概覽

功能 文件
create_padding_mask transformer/create_mask.py
Encoder Layer transformer/Encoder.py
Decoder Layer transformer/Decoder.py
Multi?Head Attention transformer/MHA.py
Feed?Forward Network transformer/FFN.py
整體模型(forward) transformer/model.py
示例訓(xùn)練/測(cè)試 transformer/test.pyTransformer.ipynb

結(jié)語
以上即為本項(xiàng)目中關(guān)于 mask、attention、前向/前饋、MLM 與 MMHA 的完整問答匯總

附錄
Paper: https://arxiv.org/abs/1706.03762
Transformer?from?Scratch: https://github.com/Breeze648/Transformer-from-Scratch

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容