【深度學習實踐】02. Softmax 回歸

雖然稱作 Softmax 回歸,但本質上 Softmax 處理的是分類問題。分類,顧名思義,即對樣本數(shù)據(jù)集進行“哪一個”的解答,通常類別的范圍在分類前是已知的,例如一封電子郵件,我們需要將其分類為垃圾郵件或者非垃圾郵件。通常,機器學習實踐者用分類這個詞來描述以下兩個有微妙差別的問題:

  1. 只對樣本的“硬性”類別感興趣,即想要知道其屬于哪個類別;
  2. “軟性”的類別,即需要得到屬于每個類別的概率(置信度)。

這兩者的界限往往比較模糊。即使我們只關心所屬的類別,我們仍然會使用類別概率的模型。以下展示分類問題的簡單模型,每個 o 代表每個類別的置信度:


分類模型

通常我們通過 獨熱編碼(one-hot encoding)對標簽類別進行編碼。獨熱編碼是一個向量,它的分量和類別一樣多。類別對應的分量設置為1,其他所有分量設置為0。在我們的例子中,標簽y將是一個三維向量,其中(1, 0, 0)對應于“貓”、(0, 1, 0)對應于“雞”、(0, 0, 1)對應于“狗”:
y \in \{(1, 0, 0), (0, 1, 0), (0, 0, 1)\}.

為什么不適用 1,2,3 代表貓、雞,狗的類別,轉化為回歸問題呢?主要是因為三種類別互相之間沒有順序,而1和2的距離比1和3的距離更近,在回歸問題中先驗不同。但如果我們試圖預測\{\text{嬰兒}, \text{兒童},\text{青少年}, \text{青年人}, \text{中年人}, \text{老年人}\},那么將這個問題轉變?yōu)榛貧w問題,并且保留這種格式是有意義的。

softmax

既然我們需要根據(jù)輸出類別的置信度來判斷預測類別,并且通過獨熱編碼來計算損失函數(shù),那么接下來我們需要將輸出的各類別的實數(shù)域上的值縮放到“概率分布”中,也就是確保各類別的輸出和為 1,相對關系保持不變(原先相對較大的值在轉化后仍然較大,負數(shù)要轉化為正值)。softmax 函數(shù)可以幫助我們完成次轉化:
\hat{\mathbf{y}} = \mathrm{softmax}(\mathbf{o}) \quad \text{其中}\quad \hat{y}_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)}
這里,對于所有的j總有0 \leq \hat{y}_j \leq 1。因此,\hat{\mathbf{y}}可以視為一個正確的概率分布。softmax運算不會改變未規(guī)范化的預測\mathbf{o}之間的順序,只會確定分配給每個類別的概率。因此,在預測過程中,我們仍然可以用下式來選擇最有可能的類別。
\operatorname*{argmax}_j \hat y_j = \operatorname*{argmax}_j o_j.

盡管softmax是一個非線性函數(shù),但softmax回歸的輸出仍然由輸入特征的仿射變換決定。因此,softmax回歸是一個線性模型。

損失函數(shù)

我們的標簽使用了獨熱編碼,模型輸出通過softmax轉化為概率分布,此時我們要計算損失函數(shù),即需要衡量兩者概率之間的區(qū)別。這里我們使用最大似然估計,根據(jù)最大似然估計,對于任何標簽\mathbf{y}和模型預測\hat{\mathbf{y}},損失函數(shù)為:
l (\mathbf{y}, \hat{\mathbf{y}}) = - \sum_{j=1}^q y_j \log \hat{y}_j

通常此又被成為交叉熵損失。由于\mathbf{y}是一個長度為q的獨熱編碼向量,所以除了類別項以外的所有項j都消失了。由于所有\hat{y}_j都是預測的概率,所以它們的對數(shù)永遠不會大于0。因此,如果正確地預測實際標簽,即如果實際標簽P(\mathbf{y} \mid \mathbf{x})=1,則損失函數(shù)不能進一步最小化。注意,這往往是不可能的。例如,數(shù)據(jù)集中可能存在標簽噪聲(比如某些樣本可能被誤標),或輸入特征沒有足夠的信息來完美地對每一個樣本分類(簡單來說,模型輸出不可能與獨熱編碼的標簽完全相同)。

對于 損失函數(shù)的導數(shù):
\partial_{o_j} l(\mathbf{y}, \hat{\mathbf{y}}) =\mathrm{softmax}(\mathbf{o})_j - y_j.
換句話說,導數(shù)是我們softmax模型分配的概率與獨熱標簽之間的差異。這使梯度計算在實踐中變得容易很多。

Pytorch 實現(xiàn)

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 我們在線性層前定義了展平層(flatten),來調整網絡輸入的形狀
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
輸出結果
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容