1?? create_padding_mask 的實(shí)現(xiàn)與作用
Q: create_padding_mask 在項(xiàng)目中實(shí)現(xiàn)在哪里?它的輸入是什么?有什么作用?
A: 實(shí)現(xiàn)在 transformer/create_mask.py(或文檔 docs/code?dasm.md 中的示例)。
def create_padding_mask(src, tgt, pad_idx=0):
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2) # (B,1,1,src_len)
tgt_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(3) # (B,1,tgt_len,1)
tgt_len = tgt.size(1)
look_ahead_mask = torch.ones(tgt_len, tgt_len).tril().bool() \
.unsqueeze(0).unsqueeze(0) # (1,1,tgt_len,tgt_len)
tgt_mask = tgt_mask & look_ahead_mask.to(tgt.device)
return src_mask, tgt_mask
-
輸入:
src(源序列張量,形狀(batch, src_len)),tgt(目標(biāo)序列張量,形狀(batch, tgt_len)),以及可選的pad_idx(默認(rèn) 0)。 -
輸出:
src_mask(僅 padding mask,形狀(batch,1,1,src_len))和tgt_mask(padding?+?look?ahead mask,形狀(batch,1,tgt_len,tgt_len))。 - 作用:在注意力計(jì)算前屏蔽填充位置(防止無意義的 PAD 影響注意力),并在解碼器自注意力中加入因果遮罩,防止模型看到未來 token。
2?? src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2) 的意義
Q: 這行代碼具體做了什么?
A:
-
(src != pad_idx)生成布爾張量,True表示該位置不是 PAD,False表示是 PAD。 -
.unsqueeze(1)在第 1 維插入大小為 1 的維度,使形狀變?yōu)?(batch,1,src_len)。 -
.unsqueeze(2)再插入第 2 維,得到(batch,1,1,src_len)。
這樣得到的張量可以在注意力得分矩陣(batch, n_heads, seq_len, seq_len)上廣播,masked_fill(mask == 0, -1e9)將 PAD 位置的分?jǐn)?shù)設(shè)為極小值,確保 softmax 后權(quán)重幾乎為 0。
3?? Masked Multi?Head Attention 與 Encoder 的 Padding Mask
Q: 論文圖中 Encoder 的 Multi?Head Attention 沒有 “masked”,但仍需要 padding mask,這兩者是否沖突?
A: 不沖突。
- Encoder 只需要 padding mask 來屏蔽填充位置。
-
Decoder 自注意力 需要 masked Multi?Head Attention,即 padding mask + look?ahead mask,因?yàn)樗仨毞乐箍吹轿磥?token。
論文圖中只在 Decoder 的自注意力旁標(biāo)記 “masked”,是因?yàn)橹挥心抢镄枰蚬谡郑籈ncoder 的 padding mask 被視為實(shí)現(xiàn)細(xì)節(jié),未在圖中顯式標(biāo)出。
4?? Masked Language Modeling (MLM) 與 Masked Multi?Head Attention 的關(guān)系
Q: MLM 與 Masked Multi?Head Attention 有何聯(lián)系與區(qū)別?
A:
| 維度 | Masked Language Modeling | Masked Multi?Head Attention |
|---|---|---|
| 層級(jí) | 訓(xùn)練任務(wù)(數(shù)據(jù)預(yù)處理 & 損失) | 模型層(注意力計(jì)算) |
| 遮罩對(duì)象 | 輸入 token 被替換為 [MASK](真實(shí)詞被隱藏) |
注意力分?jǐn)?shù)矩陣的 位置 被屏蔽(padding?+?future) |
| 目的 | 學(xué)習(xí)雙向上下文表示,預(yù)測(cè)被遮蓋的詞 | 保證解碼器自回歸(因果)并屏蔽填充 |
| 實(shí)現(xiàn) | 在數(shù)據(jù) pipeline 中隨機(jī) mask token,計(jì)算交叉熵僅在 mask 位置 | 在 create_padding_mask 中生成布爾 mask,scaled_dot_product_attention 中 masked_fill 使用 |
兩者都使用 “mask” 這個(gè)概念,但 MLM 是 任務(wù)層面的 目標(biāo),MMHA 是 模型層面的 機(jī)制。
5?? create_padding_mask 與 MLM 的關(guān)聯(lián)
Q: create_padding_mask 與 MLM 有直接關(guān)聯(lián)嗎?
A: 沒有直接關(guān)聯(lián)。create_padding_mask 只處理 padding(填充)位置的遮罩,用于注意力層;MLM 處理 隨機(jī)遮蓋 token,用于預(yù)訓(xùn)練目標(biāo)。它們?cè)谕荒P椭锌梢怨泊妫ɡ?BERT 使用 create_padding_mask 防止注意力關(guān)注 PAD,同時(shí)在預(yù)訓(xùn)練階段使用 MLM),但實(shí)現(xiàn)上是獨(dú)立的。
6?? Transformer 前向傳播的完整步驟(對(duì)應(yīng)論文 Figure?2)
- Embedding + Positional Encoding(源 & 目標(biāo))
-
Mask 生成:調(diào)用
create_padding_mask(src, tgt)→src_mask,tgt_mask。 -
Encoder 堆疊(N 層)
- Multi?Head Self?Attention(使用
src_mask) - Add & LayerNorm
- Feed?Forward Network (FFN)
- Add & LayerNorm
- Multi?Head Self?Attention(使用
-
Decoder 堆疊(N 層)
-
Masked Multi?Head Self?Attention(使用
tgt_mask) - Add & LayerNorm
- Encoder?Decoder Attention(使用
src_mask) - Add & LayerNorm
- FFN
- Add & LayerNorm
-
Masked Multi?Head Self?Attention(使用
- 線性投影 + Softmax → 輸出 logits(詞表概率)。
7?? 前向傳播 vs. 前饋網(wǎng)絡(luò)(FFN)
- 前向傳播:模型整體的計(jì)算流程,從輸入到輸出 logits,包含嵌入、注意力、FFN、殘差、歸一化等所有子模塊。
-
前饋網(wǎng)絡(luò):在每個(gè) Encoder/Decoder 層內(nèi)部的子模塊,僅對(duì)每個(gè)位置的向量進(jìn)行兩層全連接變換(
Linear → ReLU → Linear),隨后加殘差和 LayerNorm。它是前向傳播中的一個(gè)組成部分,負(fù)責(zé)提供位置級(jí)的非線性表達(dá)能力。
8?? 前向傳播是否包含反向傳播?
Q: 前向傳播的過程中會(huì)包括反向傳播嗎?
A: 不會(huì)。前向傳播只負(fù)責(zé)計(jì)算模型輸出并緩存中間激活;反向傳播在計(jì)算完損失后通過 loss.backward() 觸發(fā),用于梯度計(jì)算和參數(shù)更新。訓(xùn)練循環(huán)典型順序?yàn)椋?/p>
optimizer.zero_grad()
output = model(src, tgt) # 前向傳播
loss = criterion(output, target) # 計(jì)算損失
loss.backward() # 反向傳播
optimizer.step() # 參數(shù)更新
推理時(shí)只執(zhí)行前向傳播,不會(huì)進(jìn)行反向傳播。
9?? 代碼位置概覽
| 功能 | 文件 |
|---|---|
create_padding_mask |
transformer/create_mask.py |
| Encoder Layer | transformer/Encoder.py |
| Decoder Layer | transformer/Decoder.py |
| Multi?Head Attention | transformer/MHA.py |
| Feed?Forward Network | transformer/FFN.py |
| 整體模型(forward) | transformer/model.py |
| 示例訓(xùn)練/測(cè)試 |
transformer/test.py、Transformer.ipynb
|
結(jié)語
以上即為本項(xiàng)目中關(guān)于 mask、attention、前向/前饋、MLM 與 MMHA 的完整問答匯總
附錄
Paper: https://arxiv.org/abs/1706.03762
Transformer?from?Scratch: https://github.com/Breeze648/Transformer-from-Scratch