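This example trains a Transformer model to translate Portuguese into English. It is an advanced example that assumes knowledge of text generation and attention.
The core idea behind the Transformer model is self-attention: the ability to attend to different positions of the input sequence to compute a representation of that sequence. The Transformer creates stacks of self-attention layers, which are explained below in the sections on scaled dot product attention and multi-head attention.
A Transformer model handles variable-sized input using stacks of self-attention layers instead of RNNs or CNNs. This general architecture has a number of advantages:
- It makes no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, StarCraft units).
- Layer outputs can be calculated in parallel, instead of in a series like an RNN.
- Distant items can affect each other's output without passing through many RNN steps or convolution layers (see, for example, the Scene Memory Transformer).
- It can learn long-range dependencies, which is a challenge in many sequence tasks.
The downsides of this architecture are:
- For a time series, the output for a time step is calculated from the entire history instead of only the inputs and the current hidden state. This may be less efficient.
- If the input does have a temporal/spatial relationship, like text, some positional encoding must be added, or the model will effectively see a bag of words.
After training the model in this notebook, you will be able to input a Portuguese sentence and get back its English translation.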
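You can check the "bag of words" point numerically. The NumPy sketch below is our own illustration and is not part of the model built in this tutorial. It assumes a single attention head with identity projections (Q = K = V = the input embeddings) and no positional encoding. Under those assumptions, shuffling the input words only shuffles the output rows in the same way, so the layer has no notion of word order.

```python
import numpy as np

def self_attention(x):
    """Single-head self-attention with identity projections (Q = K = V = x).

    A simplification for illustration only: the real layers use learned Dense projections.
    """
    logits = x @ x.T / np.sqrt(x.shape[-1])          # (seq_len, seq_len)
    logits -= logits.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the keys
    return weights @ x                               # (seq_len, depth)

rng = np.random.default_rng(0)
tokens = rng.normal(size=(5, 8))   # 5 toy "word embeddings" of depth 8, no position information
perm = rng.permutation(5)          # a shuffled word order

out = self_attention(tokens)
out_shuffled = self_attention(tokens[perm])

# Permuting the input permutes the output identically: word order is invisible to the layer.
print(np.allclose(out[perm], out_shuffled))  # True
```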
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
# Use TFDS to load the Portuguese-English translation dataset from the TED Talks Open Translation Project.
print("start")
# This dataset contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples.
examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True,
                               as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']
# Create a custom subwords tokenizer from the training dataset.
tokenizer_en = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus((en.numpy() for pt, en in train_examples), target_vocab_size=2**13)
tokenizer_pt = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus((pt.numpy() for pt, en in train_examples), target_vocab_size=2**13)
sample_string = 'Transformer is awesome.'
tokenized_string = tokenizer_en.encode(sample_string)
print ('Tokenized string is {}'.format(tokenized_string))
original_string = tokenizer_en.decode(tokenized_string)
print ('The original string: {}'.format(original_string))
assert original_string == sample_string
# If a word is not in the tokenizer's dictionary, the tokenizer encodes the string by breaking it into subwords.
for ts in tokenized_string:
    print('{} ----> {}'.format(ts, tokenizer_en.decode([ts])))
BUFFER_SIZE = 20000
BATCH_SIZE = 64
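# Add a start and an end token to the input and the target.
# The start token gets ID vocab_size and the end token gets vocab_size + 1,
# i.e. the two IDs directly after the tokenizer's own vocabulary.
def encode(lang1, lang2):
    lang1 = [tokenizer_pt.vocab_size] + tokenizer_pt.encode(
        lang1.numpy()) + [tokenizer_pt.vocab_size + 1]
    lang2 = [tokenizer_en.vocab_size] + tokenizer_en.encode(
        lang2.numpy()) + [tokenizer_en.vocab_size + 1]
    return lang1, lang2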
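Here is a small illustration of the ID scheme in `encode`. The start and end tokens take the two IDs right after the tokenizer's vocabulary, which is why the model later uses `vocab_size + 2` as its vocabulary size. The sketch uses a hypothetical vocabulary size of 100 and made-up subword IDs instead of a real tokenizer.

```python
def add_start_end(token_ids, vocab_size):
    """Wrap subword IDs with start/end markers the same way encode() does."""
    start_id = vocab_size      # first ID not used by the tokenizer
    end_id = vocab_size + 1    # second unused ID
    return [start_id] + list(token_ids) + [end_id]

VOCAB_SIZE = 100               # hypothetical; the real value comes from tokenizer.vocab_size
ids = add_start_end([3, 17, 42], VOCAB_SIZE)
print(ids)                         # [100, 3, 17, 42, 101]
print(max(ids) < VOCAB_SIZE + 2)   # True: an embedding table needs vocab_size + 2 rows
```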
# Note: To keep this example small and relatively fast, drop examples
# with a length of over 40 tokens.
MAX_LENGTH = 40
def filter_max_length(x, y, max_length=MAX_LENGTH):
    return tf.logical_and(tf.size(x) <= max_length,
                          tf.size(y) <= max_length)
# Operations inside .map() run in graph mode and receive a graph tensor that
# does not have a numpy attribute. The tokenizer expects a string or Unicode
# symbol to encode into integers. Hence, you need to run the encoding inside a
# tf.py_function, which receives an eager tensor with a numpy attribute that
# contains the string value.
def tf_encode(pt, en):
    result_pt, result_en = tf.py_function(encode, [pt, en], [tf.int64, tf.int64])
    # py_function drops static shape information; restore it as 1-D sequences.
    result_pt.set_shape([None])
    result_en.set_shape([None])
    return result_pt, result_en
train_dataset = train_examples.map(tf_encode)
train_dataset = train_dataset.filter(filter_max_length)
# Cache the dataset in memory to speed up reading from it.
train_dataset = train_dataset.cache()
train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
val_dataset = val_examples.map(tf_encode)
val_dataset = val_dataset.filter(filter_max_length).padded_batch(BATCH_SIZE)
pt_batch, en_batch = next(iter(val_dataset))
pt_batch, en_batch
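# Since this model doesn't contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence.
# The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space, where tokens with similar meanings are closer to each other.
# But the embeddings do not encode the relative position of words in a sentence.
# So after adding the positional encoding, words will be closer to each other in the d-dimensional space based on both the similarity of their meaning and their position in the sentence.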
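The `get_angles` and `positional_encoding` functions below compute the standard sinusoidal encoding. In the formula, pos is the position and i indexes pairs of dimensions (the code's `i // 2`). Even dimensions get the sine and odd dimensions get the cosine of the same angle:

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)
```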
def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates
def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    # Apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # Apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)
pos_encoding = positional_encoding(50, 512)
print (pos_encoding.shape)
plt.pcolormesh(pos_encoding[0], cmap='RdBu')
plt.xlabel('Depth')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()
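These encodings have a useful property. Each sin/cos pair shares one frequency, and sin(a)sin(b) + cos(a)cos(b) = cos(a - b), so the dot product of two positional encodings depends only on how far apart the positions are. The NumPy sketch below is our own restatement of `positional_encoding` (no TensorFlow, no leading batch axis) and checks this numerically.

```python
import numpy as np

def sinusoidal_pe(position, d_model):
    """NumPy-only version of positional_encoding(), without the batch axis."""
    pos = np.arange(position)[:, np.newaxis]                  # (position, 1)
    i = np.arange(d_model)[np.newaxis, :]                     # (1, d_model)
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    angles = pos * angle_rates                                # (position, d_model)
    angles[:, 0::2] = np.sin(angles[:, 0::2])                 # even dimensions: sin
    angles[:, 1::2] = np.cos(angles[:, 1::2])                 # odd dimensions: cos
    return angles

pe = sinusoidal_pe(50, 512)
print(pe.shape)  # (50, 512)

# The dot product depends only on the offset k, not on the absolute positions.
k = 3
print(np.allclose(pe[0] @ pe[k], pe[20] @ pe[20 + k]))  # True
```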
# Masking
# Mask all the pad tokens in the batch of sequences.
# This ensures that the model does not treat padding as input. The mask indicates where the pad value 0 is present: it outputs a 1 at those locations, and a 0 otherwise.
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # Add extra dimensions so that the padding can be added
    # to the attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
create_padding_mask(x)
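# The look-ahead mask is used to mask the future tokens in a sequence. In other words, the mask indicates which entries should not be used.
# This means that to predict the third word, only the first and second words are used. Similarly, to predict the fourth word, only the first, second and third words are used, and so on.
def create_look_ahead_mask(size):
    # band_part(..., -1, 0) keeps the lower triangle (including the diagonal);
    # 1 minus it leaves 1s strictly above the diagonal, i.e. at the future positions.
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)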
x = tf.random.uniform((1, 3))
temp = create_look_ahead_mask(x.shape[1])
temp
Output:
start
Tokenized string is [7915, 1248, 7946, 7194, 13, 2799, 7877]
The original string: Transformer is awesome.
7915 ----> T
1248 ----> ran
7946 ----> s
7194 ----> former
13 ----> is
2799 ----> awesome
7877 ----> .
(1, 50, 512)

Output:
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
[0., 0., 1.],
[0., 0., 0.]], dtype=float32)>
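You can reproduce the same two masks in plain NumPy, which helps when checking shapes by hand. This is our own equivalent, not part of the tutorial code. Here `np.triu(..., k=1)` plays the role of `1 - tf.linalg.band_part(..., -1, 0)`, and the padding mask is computed for the example `x` used with `create_padding_mask` above.

```python
import numpy as np

def padding_mask_np(seq):
    """1.0 where the token ID is 0 (padding), else 0.0; shape (batch, 1, 1, seq_len)."""
    mask = (np.asarray(seq) == 0).astype(np.float32)
    return mask[:, np.newaxis, np.newaxis, :]

def look_ahead_mask_np(size):
    """1.0 strictly above the diagonal: position i may not attend to positions j > i."""
    return np.triu(np.ones((size, size), dtype=np.float32), k=1)

x = np.array([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
print(padding_mask_np(x)[:, 0, 0, :])
# [[0. 0. 1. 1. 0.]
#  [0. 0. 0. 1. 1.]
#  [1. 1. 1. 0. 0.]]
print(look_ahead_mask_np(3))
# [[0. 1. 1.]
#  [0. 0. 1.]
#  [0. 0. 0.]]
```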

The attention function used by the Transformer takes three inputs: Q (query), K (key), and V (value).
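Written out, `scaled_dot_product_attention` below computes the following. Here d_k is the key depth (`tf.shape(k)[-1]` in the code) and M is the optional mask, which the code multiplies by -10^9 and adds to the logits before the softmax:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + (-10^{9})\,M\right)V
```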
def scaled_dot_product_attention(q, k, v, mask):
    """Calculate the attention weights.

    q, k, v must have matching leading dimensions.
    k, v must have a matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type (padding or look ahead),
    but it must be broadcastable for addition.

    Args:
      q: query shape == (..., seq_len_q, depth)
      k: key shape == (..., seq_len_k, depth)
      v: value shape == (..., seq_len_v, depth_v)
      mask: Float tensor with shape broadcastable
            to (..., seq_len_q, seq_len_k), or None for no mask.

    Returns:
      output, attention_weights
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
    # Scale matmul_qk.
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    # Add the mask to the scaled tensor.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
    return output, attention_weights
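# Because the softmax normalization is done over K, its values decide how much importance each key (and hence its value) receives for a given query.
# The output is the attention weights multiplied by the V (value) vectors. This ensures that the words you want to focus on are kept as-is and the irrelevant words are flushed out.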
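The NumPy sketch below shows what the mask does. It is our own restatement of `scaled_dot_product_attention` for 2-D inputs, not the tutorial's code, run on the same `temp_k`/`temp_v` values as the example that follows. Masking a key adds -1e9 to its logit, so after the softmax that key receives essentially zero weight, and the weight is renormalized over the unmasked keys.

```python
import numpy as np

def sdpa_np(q, k, v, mask=None):
    """NumPy version of scaled_dot_product_attention() for 2-D inputs."""
    logits = q @ k.T / np.sqrt(k.shape[-1])          # (seq_len_q, seq_len_k)
    if mask is not None:
        logits = logits + mask * -1e9                # masked keys get a huge negative logit
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

temp_k = np.array([[10, 0, 0], [0, 10, 0], [0, 0, 10], [0, 0, 10]], dtype=np.float32)
temp_v = np.array([[1, 0], [10, 0], [100, 5], [1000, 6]], dtype=np.float32)
temp_q = np.array([[0, 0, 10]], dtype=np.float32)

# No mask: the weight is split between the 3rd and 4th keys, output is about [550, 5.5].
out, w = sdpa_np(temp_q, temp_k, temp_v)
print(np.round(w, 3), np.round(out, 2))

# Mask the 4th key: all the weight moves to the 3rd key, output is about [100, 5].
mask = np.array([[0, 0, 0, 1]], dtype=np.float32)
out_m, w_m = sdpa_np(temp_q, temp_k, temp_v, mask)
print(np.round(w_m, 3), np.round(out_m, 2))
```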
def print_out(q, k, v):
    temp_out, temp_attn = scaled_dot_product_attention(q, k, v, None)
    print('Attention weights are:')
    print(temp_attn)
    print('Output is:')
    print(temp_out)
np.set_printoptions(suppress=True)
temp_k = tf.constant([[10, 0, 0],
                      [0, 10, 0],
                      [0, 0, 10],
                      [0, 0, 10]], dtype=tf.float32)  # (4, 3)
temp_v = tf.constant([[   1, 0],
                      [  10, 0],
                      [ 100, 5],
                      [1000, 6]], dtype=tf.float32)  # (4, 2)
# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
# This query aligns with a repeated key (the third and fourth),
# so all the associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32) # (3, 3)
print_out(temp_q, temp_k, temp_v)
Output:
Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10. 0.]], shape=(1, 2), dtype=float32)
Attention weights are:
tf.Tensor([[0. 0. 0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550. 5.5]], shape=(1, 2), dtype=float32)
Attention weights are:
tf.Tensor(
[[0. 0. 0.5 0.5]
[0. 1. 0. 0. ]
[0.5 0.5 0. 0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550. 5.5]
[ 10. 0. ]
[ 5.5 0. ]], shape=(3, 2), dtype=float32)

Multi-head attention consists of four parts:
1. Linear layers and split into heads.
2. Scaled dot-product attention.
3. Concatenation of heads.
4. Final linear layer.
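Each multi-head attention block gets three inputs: Q (query), K (key), and V (value). These are put through linear (Dense) layers and split up into multiple heads.
The scaled_dot_product_attention function defined above is applied to each head (broadcast for efficiency). An appropriate mask must be used in the attention step. The attention output for each head is then concatenated (using tf.transpose and tf.reshape) and put through a final Dense layer.
Instead of one single attention head, Q, K, and V are split into multiple heads because this allows the model to jointly attend to information from different representation subspaces at different positions. After the split, each head has a reduced dimensionality, so the total computation cost is the same as that of a single attention head with full dimensionality.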
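The reshape/transpose used by `split_heads` and by the concatenation step in `call` can be sketched in NumPy as follows. The sizes are arbitrary. The sketch shows that splitting into heads is lossless and that each head simply sees a contiguous slice of `depth` features per position.

```python
import numpy as np

batch_size, seq_len, num_heads, depth = 2, 5, 8, 16
d_model = num_heads * depth   # 128; must divide evenly, as asserted in the layer

x = np.random.default_rng(0).normal(size=(batch_size, seq_len, d_model))

# split_heads: (batch, seq, d_model) -> (batch, seq, heads, depth) -> (batch, heads, seq, depth)
heads = x.reshape(batch_size, -1, num_heads, depth).transpose(0, 2, 1, 3)
print(heads.shape)  # (2, 8, 5, 16)

# Concatenation after attention: undo the transpose, then merge the last two axes.
merged = heads.transpose(0, 2, 1, 3).reshape(batch_size, -1, d_model)
print(np.array_equal(merged, x))  # True: no information is lost

# Head 3 sees features 3*depth .. 4*depth-1 of every position.
print(np.array_equal(heads[:, 3], x[..., 3 * depth:4 * depth]))  # True
```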
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result so that the shape is (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    # Note the argument order: v, k, q.
    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
# Create a MultiHeadAttention layer to try out.
# At each location y in the sequence, MultiHeadAttention runs all 8 attention heads across all other locations in the sequence,
# and returns a new vector of the same length at each location.
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512)) # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
# Point wise feed forward network
# The point wise feed forward network consists of two fully-connected layers with a ReLU activation in between.
def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
        tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    ])
sample_ffn = point_wise_feed_forward_network(512, 2048)
sample_ffn(tf.random.uniform((64, 50, 512))).shape
Output:
TensorShape([64, 50, 512])
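"Point wise" means that the same two Dense layers are applied to every position independently. The NumPy sketch below uses random stand-in weights, not the trained layers. It shows that running the network on the whole sequence at once gives the same result as running it on each position separately.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, dff, seq_len = 8, 32, 6

# Hypothetical random weights standing in for the two Dense layers.
W1, b1 = rng.normal(size=(d_model, dff)), np.zeros(dff)
W2, b2 = rng.normal(size=(dff, d_model)), np.zeros(d_model)

def ffn(x):
    """Dense(dff, relu) followed by Dense(d_model), applied over the last axis."""
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2

x = rng.normal(size=(seq_len, d_model))
whole = ffn(x)                                   # all positions at once
one_by_one = np.stack([ffn(row) for row in x])   # each position on its own
print(np.allclose(whole, one_by_one))            # True: positions do not interact here
```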

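The Transformer model follows the same general pattern as a standard sequence to sequence with attention model.
- The input sentence is passed through N encoder layers that generate an output for each word/token in the sequence.
- The decoder attends to the encoder's output and to its own input (self-attention) to predict the next word.
Encoder layer
Each encoder layer consists of the following sublayers:
- Multi-head attention (with padding mask)
- Point wise feed forward networks
Each of these sublayers has a residual connection around it, followed by a layer normalization. Residual connections help avoid the vanishing gradient problem in deep networks.
The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis. There are N encoder layers in the Transformer.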
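The residual-plus-normalization pattern LayerNorm(x + Sublayer(x)) can be sketched in NumPy as follows. This is a simplified illustration with two assumptions. First, the learned scale and offset of `tf.keras.layers.LayerNormalization` are left at their initial values (1 and 0) and therefore omitted. Second, the sublayer is a toy function rather than attention or the feed forward network.

```python
import numpy as np

def layer_norm(x, epsilon=1e-6):
    """Normalize over the last (d_model) axis; learned scale/offset omitted."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

def residual_block(x, sublayer):
    """Output of one sublayer: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer(x))

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 16))               # (batch, seq_len, d_model)
out = residual_block(x, lambda t: 0.5 * t)    # toy sublayer, for illustration only
print(out.shape)                              # (2, 4, 16)
# Every position's d_model vector now has mean 0 and standard deviation close to 1.
print(np.allclose(out.mean(axis=-1), 0.0), np.allclose(out.std(axis=-1), 1.0, atol=1e-4))  # True True
```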
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)
        return out2
sample_encoder_layer = EncoderLayer(512, 8, 2048)
sample_encoder_layer_output = sample_encoder_layer(tf.random.uniform((64, 43, 512)), False, None)
sample_encoder_layer_output.shape # (batch_size, input_seq_len, d_model)
Output:
TensorShape([64, 43, 512])
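Decoder layer
Each decoder layer consists of the following sublayers:
- Masked multi-head attention (with look-ahead mask and padding mask)
- Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output as inputs. Q (query) receives the output of the masked multi-head attention sublayer.
- Point wise feed forward networks
Each of these sublayers has a residual connection around it, followed by a layer normalization. The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis.
There are N decoder layers in the Transformer.
Because Q receives the output of the decoder's first attention block and K receives the encoder output, the attention weights represent the importance given to the decoder's input based on the encoder's output. In other words, the decoder predicts the next word by looking at the encoder output and self-attending to its own output. See the demonstration in the scaled dot product attention section above.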
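The shapes involved in this encoder-decoder attention can be checked with a stripped-down NumPy sketch. It uses a single head with no learned projections and no mask; these are simplifications for illustration only. The lengths match the sample layers in this section: 43 encoder positions and 50 decoder positions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
dec_x = rng.normal(size=(50, d_model))     # decoder-side queries (target_seq_len = 50)
enc_out = rng.normal(size=(43, d_model))   # encoder output used as keys and values (input_seq_len = 43)

logits = dec_x @ enc_out.T / np.sqrt(d_model)        # (50, 43): one row per target position
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)       # each row sums to 1 over source positions
context = weights @ enc_out                          # (50, 16): same length as the decoder input

print(weights.shape, context.shape)  # (50, 43) (50, 16)
```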
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training,
             look_ahead_mask, padding_mask):
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        # Masked self-attention over the decoder input.
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)
        # Encoder-decoder attention: V and K come from the encoder output, Q from out1.
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)
        ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)
        return out3, attn_weights_block1, attn_weights_block2
sample_decoder_layer = DecoderLayer(512, 8, 2048)
sample_decoder_layer_output, _, _ = sample_decoder_layer(tf.random.uniform((64, 50, 512)), sample_encoder_layer_output, False, None, None)
sample_decoder_layer_output.shape # (batch_size, target_seq_len, d_model)
Output:
TensorShape([64, 50, 512])
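# Encoder
# The encoder consists of:
#   1. Input Embedding
#   2. Positional Encoding
#   3. N encoder layers
# The input is put through an embedding, which is summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.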
class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding,
                                                self.d_model)
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]
        # Add the embedding and the positional encoding.
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)
        return x  # (batch_size, input_seq_len, d_model)
sample_encoder = Encoder(num_layers=2, d_model=512, num_heads=8,
dff=2048, input_vocab_size=8500,
maximum_position_encoding=10000)
sample_encoder_output = sample_encoder(tf.random.uniform((64, 62)),
training=False, mask=None)
print (sample_encoder_output.shape) # (batch_size, input_seq_len, d_model)
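# Decoder
# The decoder consists of:
#   1. Output Embedding
#   2. Positional Encoding
#   3. N decoder layers
# The target is put through an embedding, which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.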
class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        attention_weights = {}
        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                                   look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(i + 1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i + 1)] = block2
        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights
sample_decoder = Decoder(num_layers=2, d_model=512, num_heads=8,
dff=2048, target_vocab_size=8000,
maximum_position_encoding=5000)
output, attn = sample_decoder(tf.random.uniform((64, 26)),
enc_output=sample_encoder_output,
training=False, look_ahead_mask=None,
padding_mask=None)
output.shape, attn['decoder_layer2_block2'].shape
Output:
(64, 62, 512)
(TensorShape([64, 26, 512]), TensorShape([64, 8, 26, 62]))
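Create the Transformer
The Transformer consists of the encoder, the decoder, and a final linear layer. The output of the decoder is the input to the linear layer, and the output of the linear layer is returned.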
class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, pe_input, pe_target, rate=0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, pe_target, rate)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)
        # dec_output.shape == (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.decoder(
            tar, enc_output, training, look_ahead_mask, dec_padding_mask)
        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
        return final_output, attention_weights
sample_transformer = Transformer(
num_layers=2, d_model=512, num_heads=8, dff=2048,
input_vocab_size=8500, target_vocab_size=8000,
pe_input=10000, pe_target=6000)
temp_input = tf.random.uniform((64, 62))
temp_target = tf.random.uniform((64, 26))
fn_out, _ = sample_transformer(temp_input, temp_target, training=False,
enc_padding_mask=None,
look_ahead_mask=None,
dec_padding_mask=None)
fn_out.shape # (batch_size, tar_seq_len, target_vocab_size)
# Set hyperparameters
# To keep this example small and relatively fast, the values of num_layers, d_model, and dff have been reduced.
# The base Transformer model uses num_layers=6, d_model=512, dff=2048. See the paper for all the other versions of the Transformer.
# Note: By changing the values below, you can get the model that achieved state of the art on many tasks.
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
input_vocab_size = tokenizer_pt.vocab_size + 2
target_vocab_size = tokenizer_en.vocab_size + 2
dropout_rate = 0.1
Optimizer
Use the Adam optimizer with a custom learning rate scheduler, following the formula in the paper.
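For reference, `CustomSchedule` below implements lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5). The plain-Python sketch below evaluates that formula with this tutorial's d_model = 128 and the default warmup_steps = 4000. The learning rate rises linearly until warmup_steps and then decays in proportion to 1/sqrt(step).

```python
import math

def transformer_lr(step, d_model=128, warmup_steps=4000):
    """lrate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5), as in CustomSchedule."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

peak = transformer_lr(4000)                 # the two terms meet at step == warmup_steps
print(round(peak, 6))                       # 0.001398
print(transformer_lr(2000) < peak, transformer_lr(16000) < peak)   # True True
# After the warm-up, 4x the steps gives half the learning rate (1/sqrt decay).
print(math.isclose(transformer_lr(16000), peak / 2))               # True
```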
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Depending on the TF version, the optimizer may pass `step` as an
        # integer tensor; rsqrt requires a floating-point input.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
epsilon=1e-9)
temp_learning_rate_schedule = CustomSchedule(d_model)
plt.plot(temp_learning_rate_schedule(tf.range(40000, dtype=tf.float32)))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")
Output:
Text(0.5, 0, 'Train Step')

Loss and metrics
Since the target sequences are padded, it is important to apply a padding mask when calculating the loss.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def loss_function(real, pred):
    # Padded positions (token ID 0) are zeroed out so they add nothing to the loss.
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)
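The `loss_function` above zeroes out padded positions but then averages with `tf.reduce_mean`, which divides by every position, padding included. For the same real-token losses, the returned value therefore shrinks as the share of padding grows. A common alternative divides the masked sum by the number of real tokens instead; which one to use is a design choice. The toy NumPy comparison below uses made-up losses, not values from the training run.

```python
import numpy as np

real = np.array([[5, 9, 2, 0, 0]])               # made-up target IDs; 0 is padding
loss_ = np.array([[1.5, 0.5, 1.0, 3.0, 3.0]])    # made-up per-token losses

mask = (real != 0).astype(loss_.dtype)
loss_ = loss_ * mask                             # padded positions now contribute 0

mean_all = loss_.mean()                # what tf.reduce_mean(loss_) computes: 3.0 / 5
mean_real = loss_.sum() / mask.sum()   # average over the 3 real tokens only: 3.0 / 3
print(mean_all, mean_real)             # 0.6 1.0
```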
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
Training and checkpointing
transformer = Transformer(num_layers, d_model, num_heads, dff,
input_vocab_size, target_vocab_size,
pe_input=input_vocab_size,
pe_target=target_vocab_size,
rate=dropout_rate)
def create_masks(inp, tar):
    # Encoder padding mask
    enc_padding_mask = create_padding_mask(inp)
    # Used in the 2nd attention block in the decoder.
    # This padding mask is used to mask the encoder outputs.
    dec_padding_mask = create_padding_mask(inp)
    # Used in the 1st attention block in the decoder.
    # It masks both the padding and the future tokens in the input received by the decoder.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask
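The `tf.maximum` in `create_masks` relies on broadcasting. The target padding mask has shape (batch_size, 1, 1, seq_len) and the look-ahead mask has shape (seq_len, seq_len), so their element-wise maximum has shape (batch_size, 1, seq_len, seq_len). A position is blocked if it is either in the future or padding. Here is a NumPy sketch with a made-up target sequence:

```python
import numpy as np

tar = np.array([[1, 5, 7, 0]])   # one made-up target sequence; the last token is padding
seq_len = tar.shape[1]

look_ahead = np.triu(np.ones((seq_len, seq_len), dtype=np.float32), k=1)   # (4, 4)
tar_padding = (tar == 0).astype(np.float32)[:, np.newaxis, np.newaxis, :]  # (1, 1, 1, 4)

# (1, 1, 1, 4) broadcast against (4, 4) -> (1, 1, 4, 4)
combined = np.maximum(tar_padding, look_ahead)
print(combined.shape)   # (1, 1, 4, 4)
print(combined[0, 0])
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 1.]]
```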
# Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every n epochs.
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# If a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')
The target is divided into tar_inp and tar_real. tar_inp is passed as the input to the decoder. tar_real is the same input shifted by 1: at each location in tar_inp, tar_real contains the next token that should be predicted.
For example, sentence = "SOS A lion in the jungle is sleeping EOS"
tar_inp = "SOS A lion in the jungle is sleeping"
tar_real = "A lion in the jungle is sleeping EOS"
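The shift can be sketched with plain Python lists. Token strings stand in for token IDs here, for readability. The slices mirror `tar[:, :-1]` and `tar[:, 1:]` in `train_step` below.

```python
sentence = "SOS A lion in the jungle is sleeping EOS".split()

tar_inp = sentence[:-1]   # like tar[:, :-1]: drop the last token
tar_real = sentence[1:]   # like tar[:, 1:]: drop the first token

# At position t the decoder may only see tar_inp[:t+1] (look-ahead mask)
# and is trained to predict tar_real[t], the next token.
for t in range(3):
    print(tar_inp[: t + 1], "->", tar_real[t])
# ['SOS'] -> A
# ['SOS', 'A'] -> lion
# ['SOS', 'A', 'lion'] -> in
```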
The Transformer is an auto-regressive model: it makes predictions one part at a time, and uses its output so far to decide what to do next.
During training, this example uses teacher forcing (as in the text generation tutorial). Teacher forcing passes the true output to the next time step, regardless of what the model predicts at the current time step.
As the Transformer predicts each word, self-attention allows it to look at the previous words in the input sequence to better predict the next word.
To prevent the model from peeking at the expected output, the model uses a look-ahead mask.
EPOCHS = 20
# The @tf.function decorator trace-compiles train_step into a TF graph for
# faster execution. The function specializes to the precise shapes of the
# argument tensors. To avoid re-tracing due to variable sequence lengths or
# variable batch sizes (the last batch is smaller), use input_signature to
# specify more generic shapes.
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    with tf.GradientTape() as tape:
        predictions, _ = transformer(inp, tar_inp,
                                     True,
                                     enc_padding_mask,
                                     combined_mask,
                                     dec_padding_mask)
        loss = loss_function(tar_real, predictions)
    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
    train_loss(loss)
    train_accuracy(tar_real, predictions)
Portuguese is used as the input language and English as the target language.
for epoch in range(EPOCHS):
    start = time.time()
    train_loss.reset_states()
    train_accuracy.reset_states()
    # inp -> portuguese, tar -> english
    for (batch, (inp, tar)) in enumerate(train_dataset):
        train_step(inp, tar)
        if batch % 50 == 0:
            print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                epoch + 1, batch, train_loss.result(), train_accuracy.result()))
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
    print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1,
                                                        train_loss.result(),
                                                        train_accuracy.result()))
    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
Output:
Epoch 1 Batch 0 Loss 0.2771 Accuracy 0.4265
Epoch 1 Batch 50 Loss 0.2988 Accuracy 0.3804
Epoch 1 Batch 100 Loss 0.3013 Accuracy 0.3820
Epoch 1 Batch 150 Loss 0.3064 Accuracy 0.3844
Epoch 1 Batch 200 Loss 0.3108 Accuracy 0.3844
Epoch 1 Batch 250 Loss 0.3136 Accuracy 0.3835
Epoch 1 Batch 300 Loss 0.3158 Accuracy 0.3831
Epoch 1 Batch 350 Loss 0.3184 Accuracy 0.3833
Epoch 1 Batch 400 Loss 0.3217 Accuracy 0.3833
Epoch 1 Batch 450 Loss 0.3245 Accuracy 0.3835
Epoch 1 Batch 500 Loss 0.3267 Accuracy 0.3835
Epoch 1 Batch 550 Loss 0.3288 Accuracy 0.3829
Epoch 1 Batch 600 Loss 0.3317 Accuracy 0.3830
Epoch 1 Batch 650 Loss 0.3348 Accuracy 0.3831
Epoch 1 Batch 700 Loss 0.3368 Accuracy 0.3831
Epoch 1 Loss 0.3369 Accuracy 0.3830
Time taken for 1 epoch: 837.5033097267151 secs
Epoch 2 Batch 0 Loss 0.3254 Accuracy 0.4145
Epoch 2 Batch 50 Loss 0.2883 Accuracy 0.3898
Epoch 2 Batch 100 Loss 0.2977 Accuracy 0.3925
Epoch 2 Batch 150 Loss 0.3007 Accuracy 0.3890
Epoch 2 Batch 200 Loss 0.3046 Accuracy 0.3876
Epoch 2 Batch 250 Loss 0.3085 Accuracy 0.3863
Epoch 2 Batch 300 Loss 0.3115 Accuracy 0.3865
Epoch 2 Batch 350 Loss 0.3148 Accuracy 0.3861
Epoch 2 Batch 400 Loss 0.3175 Accuracy 0.3865
Epoch 2 Batch 450 Loss 0.3204 Accuracy 0.3862
Epoch 2 Batch 500 Loss 0.3235 Accuracy 0.3859
Epoch 2 Batch 550 Loss 0.3256 Accuracy 0.3854
Epoch 2 Batch 600 Loss 0.3281 Accuracy 0.3847
Epoch 2 Batch 650 Loss 0.3300 Accuracy 0.3843
Epoch 2 Batch 700 Loss 0.3321 Accuracy 0.3842
Epoch 2 Loss 0.3322 Accuracy 0.3841
Time taken for 1 epoch: 837.3874278068542 secs
Epoch 3 Batch 0 Loss 0.3313 Accuracy 0.4124
Epoch 3 Batch 50 Loss 0.2948 Accuracy 0.3916
Epoch 3 Batch 100 Loss 0.2949 Accuracy 0.3898
Epoch 3 Batch 150 Loss 0.3016 Accuracy 0.3910
Epoch 3 Batch 200 Loss 0.3029 Accuracy 0.3912
Epoch 3 Batch 250 Loss 0.3038 Accuracy 0.3890
Epoch 3 Batch 300 Loss 0.3069 Accuracy 0.3891
Epoch 3 Batch 350 Loss 0.3097 Accuracy 0.3886
Epoch 3 Batch 400 Loss 0.3116 Accuracy 0.3878
Epoch 3 Batch 450 Loss 0.3142 Accuracy 0.3874
Epoch 3 Batch 500 Loss 0.3165 Accuracy 0.3870
Epoch 3 Batch 550 Loss 0.3185 Accuracy 0.3867
Epoch 3 Batch 600 Loss 0.3208 Accuracy 0.3863
Epoch 3 Batch 650 Loss 0.3230 Accuracy 0.3857
Epoch 3 Batch 700 Loss 0.3259 Accuracy 0.3857
Epoch 3 Loss 0.3259 Accuracy 0.3857
Time taken for 1 epoch: 840.7752561569214 secs
Epoch 4 Batch 0 Loss 0.2586 Accuracy 0.4231
Epoch 4 Batch 50 Loss 0.2878 Accuracy 0.3933
Epoch 4 Batch 100 Loss 0.2916 Accuracy 0.3928
Epoch 4 Batch 150 Loss 0.2972 Accuracy 0.3931
Epoch 4 Batch 200 Loss 0.2995 Accuracy 0.3905
Epoch 4 Batch 250 Loss 0.3018 Accuracy 0.3910
Epoch 4 Batch 300 Loss 0.3050 Accuracy 0.3907
Epoch 4 Batch 350 Loss 0.3073 Accuracy 0.3907
Epoch 4 Batch 400 Loss 0.3092 Accuracy 0.3907
Epoch 4 Batch 450 Loss 0.3103 Accuracy 0.3897
Epoch 4 Batch 500 Loss 0.3114 Accuracy 0.3889
Epoch 4 Batch 550 Loss 0.3145 Accuracy 0.3882
Epoch 4 Batch 600 Loss 0.3173 Accuracy 0.3882
Epoch 4 Batch 650 Loss 0.3187 Accuracy 0.3872
Epoch 4 Batch 700 Loss 0.3209 Accuracy 0.3865
Epoch 4 Loss 0.3209 Accuracy 0.3865
Time taken for 1 epoch: 841.4032762050629 secs
Epoch 5 Batch 0 Loss 0.2532 Accuracy 0.4337
Epoch 5 Batch 50 Loss 0.2822 Accuracy 0.3921
Epoch 5 Batch 100 Loss 0.2859 Accuracy 0.3926
Epoch 5 Batch 150 Loss 0.2915 Accuracy 0.3910
Epoch 5 Batch 200 Loss 0.2936 Accuracy 0.3907
Epoch 5 Batch 250 Loss 0.2952 Accuracy 0.3902
Epoch 5 Batch 300 Loss 0.2991 Accuracy 0.3903
Epoch 5 Batch 350 Loss 0.3015 Accuracy 0.3902
Epoch 5 Batch 400 Loss 0.3040 Accuracy 0.3897
Epoch 5 Batch 450 Loss 0.3066 Accuracy 0.3898
Epoch 5 Batch 500 Loss 0.3089 Accuracy 0.3898
Epoch 5 Batch 550 Loss 0.3113 Accuracy 0.3897
Epoch 5 Batch 600 Loss 0.3137 Accuracy 0.3898
Epoch 5 Batch 650 Loss 0.3154 Accuracy 0.3892
Epoch 5 Batch 700 Loss 0.3172 Accuracy 0.3883
Saving checkpoint for epoch 5 at ./checkpoints/train/ckpt-9
Epoch 5 Loss 0.3172 Accuracy 0.3883
Time taken for 1 epoch: 836.7346789836884 secs
Epoch 6 Batch 0 Loss 0.2884 Accuracy 0.4005
Epoch 6 Batch 50 Loss 0.2729 Accuracy 0.3872
Epoch 6 Batch 100 Loss 0.2780 Accuracy 0.3909
Epoch 6 Batch 150 Loss 0.2800 Accuracy 0.3895
Epoch 6 Batch 200 Loss 0.2827 Accuracy 0.3895
Epoch 6 Batch 250 Loss 0.2860 Accuracy 0.3883
Epoch 6 Batch 300 Loss 0.2906 Accuracy 0.3886
Epoch 6 Batch 350 Loss 0.2938 Accuracy 0.3887
Epoch 6 Batch 400 Loss 0.2964 Accuracy 0.3881
Epoch 6 Batch 450 Loss 0.2989 Accuracy 0.3884
Epoch 6 Batch 500 Loss 0.3017 Accuracy 0.3888
Epoch 6 Batch 550 Loss 0.3044 Accuracy 0.3887
Epoch 6 Batch 600 Loss 0.3066 Accuracy 0.3881
Epoch 6 Batch 650 Loss 0.3090 Accuracy 0.3876
Epoch 6 Batch 700 Loss 0.3115 Accuracy 0.3873
Epoch 6 Loss 0.3116 Accuracy 0.3873
Time taken for 1 epoch: 838.8425750732422 secs
Epoch 7 Batch 0 Loss 0.2667 Accuracy 0.4269
Epoch 7 Batch 50 Loss 0.2697 Accuracy 0.3944
Epoch 7 Batch 100 Loss 0.2749 Accuracy 0.3937
Epoch 7 Batch 150 Loss 0.2792 Accuracy 0.3933
Epoch 7 Batch 200 Loss 0.2845 Accuracy 0.3947
Epoch 7 Batch 250 Loss 0.2876 Accuracy 0.3944
Epoch 7 Batch 300 Loss 0.2894 Accuracy 0.3942
Epoch 7 Batch 350 Loss 0.2917 Accuracy 0.3928
Epoch 7 Batch 400 Loss 0.2940 Accuracy 0.3922
Epoch 7 Batch 450 Loss 0.2963 Accuracy 0.3914
Epoch 7 Batch 500 Loss 0.2994 Accuracy 0.3914
Epoch 7 Batch 550 Loss 0.3008 Accuracy 0.3902
Epoch 7 Batch 600 Loss 0.3033 Accuracy 0.3901
Epoch 7 Batch 650 Loss 0.3055 Accuracy 0.3899
Epoch 7 Batch 700 Loss 0.3076 Accuracy 0.3893
Epoch 7 Loss 0.3077 Accuracy 0.3893
Time taken for 1 epoch: 837.0950720310211 secs
Epoch 8 Batch 0 Loss 0.2953 Accuracy 0.4106
Epoch 8 Batch 50 Loss 0.2771 Accuracy 0.3970
Epoch 8 Batch 100 Loss 0.2785 Accuracy 0.3982
Epoch 8 Batch 150 Loss 0.2803 Accuracy 0.3957
Epoch 8 Batch 200 Loss 0.2821 Accuracy 0.3946
Epoch 8 Batch 250 Loss 0.2834 Accuracy 0.3939
Epoch 8 Batch 300 Loss 0.2872 Accuracy 0.3944
Epoch 8 Batch 350 Loss 0.2902 Accuracy 0.3943
Epoch 8 Batch 400 Loss 0.2913 Accuracy 0.3933
Epoch 8 Batch 450 Loss 0.2928 Accuracy 0.3931
Epoch 8 Batch 500 Loss 0.2956 Accuracy 0.3934
Epoch 8 Batch 550 Loss 0.2975 Accuracy 0.3928
Epoch 8 Batch 600 Loss 0.2992 Accuracy 0.3920
Epoch 8 Batch 650 Loss 0.3011 Accuracy 0.3917
Epoch 8 Batch 700 Loss 0.3032 Accuracy 0.3908
Epoch 8 Loss 0.3034 Accuracy 0.3908
Time taken for 1 epoch: 836.5397372245789 secs
Epoch 9 Batch 0 Loss 0.2577 Accuracy 0.4107
Epoch 9 Batch 50 Loss 0.2686 Accuracy 0.3992
Epoch 9 Batch 100 Loss 0.2720 Accuracy 0.3956
Epoch 9 Batch 150 Loss 0.2751 Accuracy 0.3961
Epoch 9 Batch 200 Loss 0.2776 Accuracy 0.3971
Epoch 9 Batch 250 Loss 0.2796 Accuracy 0.3960
Epoch 9 Batch 300 Loss 0.2819 Accuracy 0.3955
Epoch 9 Batch 350 Loss 0.2842 Accuracy 0.3950
Epoch 9 Batch 400 Loss 0.2871 Accuracy 0.3950
Epoch 9 Batch 450 Loss 0.2889 Accuracy 0.3947
Epoch 9 Batch 500 Loss 0.2914 Accuracy 0.3948
Epoch 9 Batch 550 Loss 0.2935 Accuracy 0.3945
Epoch 9 Batch 600 Loss 0.2952 Accuracy 0.3937
Epoch 9 Batch 650 Loss 0.2972 Accuracy 0.3928
Epoch 9 Batch 700 Loss 0.2992 Accuracy 0.3920
Epoch 9 Loss 0.2991 Accuracy 0.3919
Time taken for 1 epoch: 836.3204340934753 secs
Epoch 10 Batch 0 Loss 0.2756 Accuracy 0.4359
Epoch 10 Batch 50 Loss 0.2647 Accuracy 0.4020
Epoch 10 Batch 100 Loss 0.2683 Accuracy 0.3995
Epoch 10 Batch 150 Loss 0.2706 Accuracy 0.3976
Epoch 10 Batch 200 Loss 0.2720 Accuracy 0.3970
Epoch 10 Batch 250 Loss 0.2737 Accuracy 0.3956
Epoch 10 Batch 300 Loss 0.2764 Accuracy 0.3950
Epoch 10 Batch 350 Loss 0.2782 Accuracy 0.3947
Epoch 10 Batch 400 Loss 0.2808 Accuracy 0.3947
Epoch 10 Batch 450 Loss 0.2828 Accuracy 0.3943
Epoch 10 Batch 500 Loss 0.2855 Accuracy 0.3941
Epoch 10 Batch 550 Loss 0.2883 Accuracy 0.3941
Epoch 10 Batch 600 Loss 0.2905 Accuracy 0.3935
Epoch 10 Batch 650 Loss 0.2921 Accuracy 0.3925
Epoch 10 Batch 700 Loss 0.2947 Accuracy 0.3924
Saving checkpoint for epoch 10 at ./checkpoints/train/ckpt-10
Epoch 10 Loss 0.2948 Accuracy 0.3924
Time taken for 1 epoch: 838.1140463352203 secs
Epoch 11 Batch 0 Loss 0.2054 Accuracy 0.3631
Epoch 11 Batch 50 Loss 0.2661 Accuracy 0.4077
Epoch 11 Batch 100 Loss 0.2647 Accuracy 0.4024
Epoch 11 Batch 150 Loss 0.2671 Accuracy 0.3988
Epoch 11 Batch 200 Loss 0.2703 Accuracy 0.3996
Epoch 11 Batch 250 Loss 0.2722 Accuracy 0.3981
Epoch 11 Batch 300 Loss 0.2742 Accuracy 0.3978
Epoch 11 Batch 350 Loss 0.2762 Accuracy 0.3964
Epoch 11 Batch 400 Loss 0.2793 Accuracy 0.3962
Epoch 11 Batch 450 Loss 0.2804 Accuracy 0.3962
Epoch 11 Batch 500 Loss 0.2820 Accuracy 0.3953
Epoch 11 Batch 550 Loss 0.2850 Accuracy 0.3954
Epoch 11 Batch 600 Loss 0.2871 Accuracy 0.3946
Epoch 11 Batch 650 Loss 0.2892 Accuracy 0.3939
Epoch 11 Batch 700 Loss 0.2910 Accuracy 0.3935
Epoch 11 Loss 0.2911 Accuracy 0.3934
Time taken for 1 epoch: 836.5090510845184 secs
Epoch 12 Batch 0 Loss 0.2596 Accuracy 0.3818
Epoch 12 Batch 50 Loss 0.2550 Accuracy 0.3932
Epoch 12 Batch 100 Loss 0.2594 Accuracy 0.3966
Epoch 12 Batch 150 Loss 0.2616 Accuracy 0.3954
Epoch 12 Batch 200 Loss 0.2649 Accuracy 0.3951
Epoch 12 Batch 250 Loss 0.2688 Accuracy 0.3957
Epoch 12 Batch 300 Loss 0.2704 Accuracy 0.3960
Epoch 12 Batch 350 Loss 0.2734 Accuracy 0.3961
Epoch 12 Batch 400 Loss 0.2754 Accuracy 0.3957
Epoch 12 Batch 450 Loss 0.2779 Accuracy 0.3956
Epoch 12 Batch 500 Loss 0.2799 Accuracy 0.3952
Epoch 12 Batch 550 Loss 0.2813 Accuracy 0.3945
Epoch 12 Batch 600 Loss 0.2836 Accuracy 0.3944
Epoch 12 Batch 650 Loss 0.2863 Accuracy 0.3945
Epoch 12 Batch 700 Loss 0.2880 Accuracy 0.3940
Epoch 12 Loss 0.2880 Accuracy 0.3940
Time taken for 1 epoch: 835.4420788288116 secs
Epoch 13 Batch 0 Loss 0.2850 Accuracy 0.4202
Epoch 13 Batch 50 Loss 0.2511 Accuracy 0.4039
Epoch 13 Batch 100 Loss 0.2531 Accuracy 0.3988
Epoch 13 Batch 150 Loss 0.2586 Accuracy 0.3981
Epoch 13 Batch 200 Loss 0.2610 Accuracy 0.3965
Epoch 13 Batch 250 Loss 0.2639 Accuracy 0.3975
Epoch 13 Batch 300 Loss 0.2671 Accuracy 0.3971
Epoch 13 Batch 350 Loss 0.2692 Accuracy 0.3974
Epoch 13 Batch 400 Loss 0.2715 Accuracy 0.3968
Epoch 13 Batch 450 Loss 0.2733 Accuracy 0.3962
Epoch 13 Batch 500 Loss 0.2757 Accuracy 0.3963
Epoch 13 Batch 550 Loss 0.2778 Accuracy 0.3962
Epoch 13 Batch 600 Loss 0.2798 Accuracy 0.3960
Epoch 13 Batch 650 Loss 0.2820 Accuracy 0.3957
Epoch 13 Batch 700 Loss 0.2838 Accuracy 0.3951
Epoch 13 Loss 0.2838 Accuracy 0.3950
Time taken for 1 epoch: 834.2117850780487 secs
Epoch 14 Batch 0 Loss 0.2444 Accuracy 0.4170
Epoch 14 Batch 50 Loss 0.2541 Accuracy 0.4027
Epoch 14 Batch 100 Loss 0.2541 Accuracy 0.3988
Epoch 14 Batch 150 Loss 0.2576 Accuracy 0.4016
Epoch 14 Batch 200 Loss 0.2616 Accuracy 0.4000
Epoch 14 Batch 250 Loss 0.2624 Accuracy 0.3992
Epoch 14 Batch 300 Loss 0.2647 Accuracy 0.3986
Epoch 14 Batch 350 Loss 0.2675 Accuracy 0.3983
Epoch 14 Batch 400 Loss 0.2688 Accuracy 0.3980
Epoch 14 Batch 450 Loss 0.2705 Accuracy 0.3976
Epoch 14 Batch 500 Loss 0.2725 Accuracy 0.3973
Epoch 14 Batch 550 Loss 0.2745 Accuracy 0.3970
Epoch 14 Batch 600 Loss 0.2767 Accuracy 0.3968
Epoch 14 Batch 650 Loss 0.2789 Accuracy 0.3961
Epoch 14 Batch 700 Loss 0.2811 Accuracy 0.3960
Epoch 14 Loss 0.2811 Accuracy 0.3960
Time taken for 1 epoch: 840.4115641117096 secs
Epoch 15 Batch 0 Loss 0.2785 Accuracy 0.3988
Epoch 15 Batch 50 Loss 0.2486 Accuracy 0.4043
Epoch 15 Batch 100 Loss 0.2527 Accuracy 0.4026
Epoch 15 Batch 150 Loss 0.2571 Accuracy 0.4022
Epoch 15 Batch 200 Loss 0.2567 Accuracy 0.3995
Epoch 15 Batch 250 Loss 0.2592 Accuracy 0.3999
Epoch 15 Batch 300 Loss 0.2615 Accuracy 0.4000
Epoch 15 Batch 350 Loss 0.2637 Accuracy 0.3998
Epoch 15 Batch 400 Loss 0.2643 Accuracy 0.3989
Epoch 15 Batch 450 Loss 0.2661 Accuracy 0.3982
Epoch 15 Batch 500 Loss 0.2682 Accuracy 0.3971
Epoch 15 Batch 550 Loss 0.2701 Accuracy 0.3969
Epoch 15 Batch 600 Loss 0.2725 Accuracy 0.3968
Epoch 15 Batch 650 Loss 0.2751 Accuracy 0.3962
Epoch 15 Batch 700 Loss 0.2777 Accuracy 0.3964
Saving checkpoint for epoch 15 at ./checkpoints/train/ckpt-11
Epoch 15 Loss 0.2777 Accuracy 0.3964
Time taken for 1 epoch: 836.3041331768036 secs
Epoch 16 Batch 0 Loss 0.2658 Accuracy 0.4402
Epoch 16 Batch 50 Loss 0.2486 Accuracy 0.4023
Epoch 16 Batch 100 Loss 0.2487 Accuracy 0.4023
Epoch 16 Batch 150 Loss 0.2513 Accuracy 0.4028
Epoch 16 Batch 200 Loss 0.2527 Accuracy 0.4017
Epoch 16 Batch 250 Loss 0.2547 Accuracy 0.4019
Epoch 16 Batch 300 Loss 0.2576 Accuracy 0.4020
Epoch 16 Batch 350 Loss 0.2606 Accuracy 0.4020
Epoch 16 Batch 400 Loss 0.2622 Accuracy 0.4006
Epoch 16 Batch 450 Loss 0.2637 Accuracy 0.3992
Epoch 16 Batch 500 Loss 0.2654 Accuracy 0.3987
Epoch 16 Batch 550 Loss 0.2675 Accuracy 0.3987
Epoch 16 Batch 600 Loss 0.2697 Accuracy 0.3989
Epoch 16 Batch 650 Loss 0.2718 Accuracy 0.3988
Epoch 16 Batch 700 Loss 0.2738 Accuracy 0.3983
Epoch 16 Loss 0.2741 Accuracy 0.3983
Time taken for 1 epoch: 834.1861200332642 secs
Epoch 17 Batch 0 Loss 0.1989 Accuracy 0.3878
Epoch 17 Batch 50 Loss 0.2413 Accuracy 0.4070
Epoch 17 Batch 100 Loss 0.2439 Accuracy 0.4038
Epoch 17 Batch 150 Loss 0.2469 Accuracy 0.4022
Epoch 17 Batch 200 Loss 0.2501 Accuracy 0.4024
Epoch 17 Batch 250 Loss 0.2516 Accuracy 0.4007
Epoch 17 Batch 300 Loss 0.2536 Accuracy 0.4007
Epoch 17 Batch 350 Loss 0.2562 Accuracy 0.4003
Epoch 17 Batch 400 Loss 0.2588 Accuracy 0.4001
Epoch 17 Batch 450 Loss 0.2606 Accuracy 0.3994
Epoch 17 Batch 500 Loss 0.2620 Accuracy 0.3993
Epoch 17 Batch 550 Loss 0.2641 Accuracy 0.3988
Epoch 17 Batch 600 Loss 0.2657 Accuracy 0.3982
Epoch 17 Batch 650 Loss 0.2679 Accuracy 0.3980
Epoch 17 Batch 700 Loss 0.2700 Accuracy 0.3977
Epoch 17 Loss 0.2702 Accuracy 0.3978
Time taken for 1 epoch: 835.0433349609375 secs
Epoch 18 Batch 0 Loss 0.2166 Accuracy 0.4062
Epoch 18 Batch 50 Loss 0.2412 Accuracy 0.4051
Epoch 18 Batch 100 Loss 0.2425 Accuracy 0.4041
Epoch 18 Batch 150 Loss 0.2451 Accuracy 0.4022
Epoch 18 Batch 200 Loss 0.2479 Accuracy 0.4030
Epoch 18 Batch 250 Loss 0.2504 Accuracy 0.4023
Epoch 18 Batch 300 Loss 0.2532 Accuracy 0.4019
Epoch 18 Batch 350 Loss 0.2544 Accuracy 0.4009
Epoch 18 Batch 400 Loss 0.2555 Accuracy 0.4009
Epoch 18 Batch 450 Loss 0.2565 Accuracy 0.4000
Epoch 18 Batch 500 Loss 0.2590 Accuracy 0.3996
Epoch 18 Batch 550 Loss 0.2612 Accuracy 0.3996
Epoch 18 Batch 600 Loss 0.2636 Accuracy 0.3991
Epoch 18 Batch 650 Loss 0.2653 Accuracy 0.3986
Epoch 18 Batch 700 Loss 0.2671 Accuracy 0.3982
Epoch 18 Loss 0.2672 Accuracy 0.3981
Time taken for 1 epoch: 836.7380259037018 secs
Epoch 19 Batch 0 Loss 0.1827 Accuracy 0.3199
Epoch 19 Batch 50 Loss 0.2371 Accuracy 0.4088
Epoch 19 Batch 100 Loss 0.2367 Accuracy 0.4053
Epoch 19 Batch 150 Loss 0.2392 Accuracy 0.4017
Epoch 19 Batch 200 Loss 0.2416 Accuracy 0.4018
Epoch 19 Batch 250 Loss 0.2446 Accuracy 0.4023
Epoch 19 Batch 300 Loss 0.2470 Accuracy 0.4023
Epoch 19 Batch 350 Loss 0.2496 Accuracy 0.4013
Epoch 19 Batch 400 Loss 0.2515 Accuracy 0.4014
Epoch 19 Batch 450 Loss 0.2535 Accuracy 0.4010
Epoch 19 Batch 500 Loss 0.2560 Accuracy 0.4007
Epoch 19 Batch 550 Loss 0.2584 Accuracy 0.4008
Epoch 19 Batch 600 Loss 0.2598 Accuracy 0.4000
Epoch 19 Batch 650 Loss 0.2621 Accuracy 0.3993
Epoch 19 Batch 700 Loss 0.2640 Accuracy 0.3991
Epoch 19 Loss 0.2641 Accuracy 0.3991
Time taken for 1 epoch: 838.4264571666718 secs
Epoch 20 Batch 0 Loss 0.2607 Accuracy 0.3957
Epoch 20 Batch 50 Loss 0.2378 Accuracy 0.4036
Epoch 20 Batch 100 Loss 0.2369 Accuracy 0.4022
Epoch 20 Batch 150 Loss 0.2393 Accuracy 0.4031
Epoch 20 Batch 200 Loss 0.2416 Accuracy 0.4030
Epoch 20 Batch 250 Loss 0.2420 Accuracy 0.4019
Epoch 20 Batch 300 Loss 0.2442 Accuracy 0.4018
Epoch 20 Batch 350 Loss 0.2467 Accuracy 0.4029
Epoch 20 Batch 400 Loss 0.2497 Accuracy 0.4030
Epoch 20 Batch 450 Loss 0.2506 Accuracy 0.4020
Epoch 20 Batch 500 Loss 0.2525 Accuracy 0.4014
Epoch 20 Batch 550 Loss 0.2545 Accuracy 0.4008
Epoch 20 Batch 600 Loss 0.2575 Accuracy 0.4007
Epoch 20 Batch 650 Loss 0.2598 Accuracy 0.4005
Epoch 20 Batch 700 Loss 0.2615 Accuracy 0.3997
Saving checkpoint for epoch 20 at ./checkpoints/train/ckpt-12
Epoch 20 Loss 0.2615 Accuracy 0.3996
Time taken for 1 epoch: 841.6034660339355 secs
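The log above is long. To see the trend at a glance, a small helper can keep only the per-epoch summary lines. This is a sketch that assumes the exact line format printed above. The pattern and the helper name `epoch_summaries` are illustrative and are not part of the tutorial.

```python
import re

# Matches per-epoch summary lines such as "Epoch 20 Loss 0.2615 Accuracy 0.3996".
# Per-batch lines ("Epoch 6 Batch 550 Loss ...") contain "Batch" and do not match.
EPOCH_LINE = re.compile(r"^Epoch (\d+) Loss ([\d.]+) Accuracy ([\d.]+)$")

def epoch_summaries(log_text: str) -> dict:
    """Map epoch number -> (loss, accuracy) taken from the epoch summary lines."""
    summaries = {}
    for line in log_text.splitlines():
        match = EPOCH_LINE.match(line.strip())
        if match:
            summaries[int(match.group(1))] = (float(match.group(2)), float(match.group(3)))
    return summaries

excerpt = """Epoch 6 Batch 700 Loss 0.3115 Accuracy 0.3873
Epoch 6 Loss 0.3116 Accuracy 0.3873
Time taken for 1 epoch: 838.8425750732422 secs
Epoch 20 Loss 0.2615 Accuracy 0.3996"""
print(epoch_summaries(excerpt))  # {6: (0.3116, 0.3873), 20: (0.2615, 0.3996)}
```

Applied to this log, the epoch-level loss falls from 0.3116 (epoch 6) to 0.2615 (epoch 20), accuracy rises from 0.3873 to 0.3996, and each epoch takes roughly 834 to 842 seconds.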
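Evaluate
The following steps are used for evaluation:
Encode the input sentence with the Portuguese tokenizer (tokenizer_pt), and add the start and end tokens so the input matches what the model was trained on. This is the encoder input.
The decoder input is the start token, whose id is tokenizer_en.vocab_size.
Compute the padding masks and the look-ahead mask.
The decoder then outputs predictions by attending to the encoder output and to its own output (self-attention).
Select the last word and compute its argmax.
Concatenate the predicted word to the decoder input and pass it to the decoder again. Repeat until the end token is predicted or MAX_LENGTH steps have run.
In this approach, the decoder predicts the next word based on the words it has predicted so far.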
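The steps above amount to greedy autoregressive decoding. The following is a minimal pure-Python sketch of just that control flow, with no tensors and no masks. It assumes the same special-token convention as the tutorial: the start id is `vocab_size` and the end id is `vocab_size + 1`. `next_token_scores` is a hypothetical stand-in for the trained transformer. It returns one score per token id for the last decoder position. `toy_scores` and the `TOY_*` names are made up for illustration.

```python
from typing import Callable, List

def wrap_with_start_end(token_ids: List[int], vocab_size: int) -> List[int]:
    """Encoder input: start token + subword ids + end token."""
    return [vocab_size] + token_ids + [vocab_size + 1]

def greedy_decode(next_token_scores: Callable[[List[int]], List[float]],
                  vocab_size: int, max_length: int) -> List[int]:
    """Greedy decoding; like evaluate(), the returned ids exclude the end token."""
    start_id, end_id = vocab_size, vocab_size + 1
    output = [start_id]                     # decoder input starts with the start token
    for _ in range(max_length):
        scores = next_token_scores(output)  # scores for the last position only
        predicted_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        if predicted_id == end_id:          # stop once the end token is predicted
            break
        output.append(predicted_id)         # feed the prediction back into the decoder
    return output

# Hypothetical toy "model": emits the ids in TOY_TARGET one by one, then the end token.
TOY_VOCAB_SIZE = 4
TOY_TARGET = [3, 1, 2]

def toy_scores(prefix: List[int]) -> List[float]:
    step = len(prefix) - 1                  # tokens generated so far
    next_id = TOY_TARGET[step] if step < len(TOY_TARGET) else TOY_VOCAB_SIZE + 1
    scores = [0.0] * (TOY_VOCAB_SIZE + 2)   # subword ids plus the start and end ids
    scores[next_id] = 1.0
    return scores

print(wrap_with_start_end([7, 2], vocab_size=10))                           # [10, 7, 2, 11]
print(greedy_decode(toy_scores, vocab_size=TOY_VOCAB_SIZE, max_length=10))  # [4, 3, 1, 2]
```

Greedy argmax is the simplest choice. It commits to one word per step, so it can miss better overall translations that beam search would find.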
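Note: the model used here has relatively small capacity so that the example runs reasonably fast, so its predictions may be less accurate. To reproduce the results in the paper, use the entire dataset and switch to the base Transformer model or Transformer XL by changing the hyperparameters above.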
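For reference, the "base" configuration in the paper "Attention Is All You Need" uses 6 encoder and 6 decoder layers, d_model = 512, a feed-forward size of 2048, 8 attention heads and dropout 0.1. The sketch below writes those values as a dictionary. It assumes the hyperparameter names `num_layers`, `d_model`, `dff`, `num_heads` and `dropout_rate` that the tutorial defines earlier; those names are not shown in this section. `BASE_HPARAMS` and `head_depth` are illustrative names.

```python
# Base Transformer hyperparameters from "Attention Is All You Need".
# Key names are assumed to mirror the tutorial's own hyperparameter variables.
BASE_HPARAMS = {
    "num_layers": 6,
    "d_model": 512,
    "dff": 2048,
    "num_heads": 8,
    "dropout_rate": 0.1,
}

def head_depth(d_model: int, num_heads: int) -> int:
    """Per-head dimension; multi-head attention splits d_model evenly across heads."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    return d_model // num_heads

print(head_depth(BASE_HPARAMS["d_model"], BASE_HPARAMS["num_heads"]))  # 64
```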
def evaluate(inp_sentence):
    start_token = [tokenizer_pt.vocab_size]
    end_token = [tokenizer_pt.vocab_size + 1]

    # The input sentence is Portuguese: add the start and end tokens
    inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
    encoder_input = tf.expand_dims(inp_sentence, 0)

    # As the target is English, the first word fed to the transformer
    # should be the English start token.
    decoder_input = [tokenizer_en.vocab_size]
    output = tf.expand_dims(decoder_input, 0)

    for i in range(MAX_LENGTH):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
            encoder_input, output)

        # predictions.shape == (batch_size, seq_len, vocab_size)
        predictions, attention_weights = transformer(encoder_input,
                                                     output,
                                                     False,
                                                     enc_padding_mask,
                                                     combined_mask,
                                                     dec_padding_mask)

        # Select the last word from the seq_len dimension
        predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # Return the result if predicted_id equals the end token
        if predicted_id == tokenizer_en.vocab_size + 1:
            return tf.squeeze(output, axis=0), attention_weights

        # Concatenate predicted_id to the output, which is fed back to the decoder as its input.
        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights
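`evaluate` calls `create_masks`, which is defined earlier in the tutorial and not shown here. The pure-Python sketch below illustrates the two kinds of mask it combines. It assumes the tutorial's conventions: 1 marks a position that must not be attended to, and token id 0 is padding. It drops the batch and head dimensions that the TensorFlow version adds. The function names are illustrative.

```python
from typing import List

def padding_mask(seq: List[int], pad_id: int = 0) -> List[int]:
    """1 where the token is padding (masked out), 0 elsewhere."""
    return [1 if tok == pad_id else 0 for tok in seq]

def look_ahead_mask(size: int) -> List[List[int]]:
    """1 above the diagonal: position i may not attend to any later position j > i."""
    return [[1 if j > i else 0 for j in range(size)] for i in range(size)]

def combined_mask(tar: List[int], pad_id: int = 0) -> List[List[int]]:
    """Element-wise max of the look-ahead mask and the target padding mask."""
    pad = padding_mask(tar, pad_id)
    ahead = look_ahead_mask(len(tar))
    return [[max(ahead[i][j], pad[j]) for j in range(len(tar))] for i in range(len(tar))]

print(look_ahead_mask(3))        # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
print(combined_mask([5, 7, 0]))  # [[0, 1, 1], [0, 0, 1], [0, 0, 1]]
```

During `evaluate`, a single unpadded sentence is decoded, so neither input contains padding (under the id-0 assumption). The padding masks are therefore all zeros, and the combined mask reduces to the look-ahead mask.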
def plot_attention_weights(attention, sentence, result, layer):
    fig = plt.figure(figsize=(16, 8))
    sentence = tokenizer_pt.encode(sentence)
    attention = tf.squeeze(attention[layer], axis=0)
    for head in range(attention.shape[0]):
        ax = fig.add_subplot(2, 4, head + 1)
        # Plot the attention weights of this head
        ax.matshow(attention[head][:-1, :], cmap='viridis')
        fontdict = {'fontsize': 10}
        # x-axis: the Portuguese input with its start/end tokens;
        # y-axis: the predicted English subwords (ids >= vocab_size are the start/end tokens)
        x_labels = ['<start>'] + [tokenizer_pt.decode([i]) for i in sentence] + ['<end>']
        y_labels = [tokenizer_en.decode([i]) for i in result if i < tokenizer_en.vocab_size]
        # One tick per label, so the tick and label counts always match
        ax.set_xticks(range(len(x_labels)))
        ax.set_yticks(range(len(y_labels)))
        ax.set_ylim(len(result) - 1.5, -0.5)
        ax.set_xticklabels(x_labels, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(y_labels, fontdict=fontdict)
        ax.set_xlabel('Head {}'.format(head + 1))
    plt.tight_layout()
    plt.show()
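In `plot_attention_weights`, row i of an attention matrix belongs to decoder position i. Because of the look-ahead mask, that position only sees `result[:i+1]`, so its row is the attention used to predict `result[i+1]`. That is why the y-axis labels skip the start token. The pure-Python sketch below spells out this pairing. `attention_row_labels` and the toy vocabulary are hypothetical, and `result` is assumed to be exactly what `evaluate` returns: it starts with the start token and contains no end token.

```python
from typing import Callable, List, Tuple

def attention_row_labels(result_ids: List[int], vocab_size: int,
                         decode: Callable[[int], str]) -> List[Tuple[int, str]]:
    """Pair each attention row index with the token predicted from that row."""
    pairs = []
    for row, token_id in enumerate(result_ids[1:]):
        # Row `row` is the query at position `row`; the look-ahead mask limits it to
        # result_ids[:row + 1], and its output is used to predict result_ids[row + 1].
        if token_id < vocab_size:  # skip special start/end ids, as the plot labels do
            pairs.append((row, decode(token_id)))
    return pairs

# Hypothetical 3-subword vocabulary: ids 0..2, start token 3, end token 4.
toy_vocab = {0: "<pad>", 1: "hello ", 2: "world"}
print(attention_row_labels([3, 1, 2], vocab_size=3, decode=toy_vocab.__getitem__))
# [(0, 'hello '), (1, 'world')]
```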
def translate(sentence, plot=''):
    result, attention_weights = evaluate(sentence)
    predicted_sentence = tokenizer_en.decode([i for i in result
                                              if i < tokenizer_en.vocab_size])
    print('Input: {}'.format(sentence))
    print('Predicted translation: {}'.format(predicted_sentence))
    if plot:
        plot_attention_weights(attention_weights, sentence, result, plot)
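The `plot` argument of `translate` is a key into the `attention_weights` dictionary returned by `evaluate`. This section assumes the Decoder defined earlier in the tutorial names its entries `decoder_layer{i}_block1` (masked self-attention) and `decoder_layer{i}_block2` (attention over the encoder output), counting layers from 1. Under that assumption, only the block2 matrices have one column per Portuguese input token, which is what the x-axis of `plot_attention_weights` labels. The hypothetical helper below simply builds those key names.

```python
from typing import List

def cross_attention_keys(num_layers: int) -> List[str]:
    """Keys of the encoder-decoder ("block2") attention weights, one per decoder layer.

    Assumes the naming scheme 'decoder_layer{i}_block2' with i starting at 1.
    """
    return ["decoder_layer{}_block2".format(i) for i in range(1, num_layers + 1)]

print(cross_attention_keys(2))  # ['decoder_layer1_block2', 'decoder_layer2_block2']

# With the trained model from this tutorial (assuming num_layers == 4), one would call:
# translate("este é um problema que temos que resolver.", plot=cross_attention_keys(4)[-1])
```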
Try translating a few sentences:
translate("este é um problema que temos que resolver.")
print("Real translation: this is a problem we have to solve .")
translate("os meus vizinhos ouviram sobre esta ideia.")
print("Real translation: and my neighboring homes heard about this idea .")
translate("vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.")
print("Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .")
Output:
Input: este é um problema que temos que resolver.
Predicted translation: so this is a problem that we have to solve ...c . to fix .
Real translation: this is a problem we have to solve .
Input: os meus vizinhos ouviram sobre esta ideia.
Predicted translation: my neighbors heard about this idea of an idea .
Real translation: and my neighboring homes heard about this idea .
Input: vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.
Predicted translation: so i 'm going to spend a few of you could share with you a few really magic stories that happen .
Real translation: so i 'll just share with you some stories very quickly of some magical things that have happened .
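The comparison above is done by eye. One simple way to put a number on it is clipped unigram precision. This is the 1-gram building block of BLEU, not full BLEU, which also combines higher-order n-grams and a brevity penalty. The sketch below works on whitespace-tokenized strings, and the function name is illustrative. It is applied to the first example above.

```python
from collections import Counter

def clipped_unigram_precision(candidate: str, reference: str) -> float:
    """Share of candidate tokens found in the reference, where each token's
    count is clipped to the number of times it occurs in the reference."""
    cand_tokens = candidate.split()
    ref_counts = Counter(reference.split())
    matched = sum(min(n, ref_counts[tok]) for tok, n in Counter(cand_tokens).items())
    return matched / len(cand_tokens) if cand_tokens else 0.0

pred = "so this is a problem that we have to solve ...c . to fix ."
real = "this is a problem we have to solve ."
print(clipped_unigram_precision(pred, real))  # 0.6
```

Precision alone does not penalize reference words that the prediction leaves out. It is only a rough complement to reading the outputs, not a replacement for the BLEU scores used in the paper.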