PyTorch學(xué)習(xí)筆記3 - 使用PyTorch實(shí)現(xiàn)Logistic回歸

本篇筆記的完整代碼:https://github.com/ChenWentai/PyTorch/blob/master/task3_logistic.py

1. 準(zhǔn)備數(shù)據(jù)

這次任務(wù)使用Logistic解決二分類(lèi)問(wèn)題。對(duì)于Logistic回歸,數(shù)據(jù)的標(biāo)簽為0和1(而不是1和-1),其中y=0的訓(xùn)練數(shù)據(jù)由均值為2,方差為1正態(tài)分布產(chǎn)生,y=1的訓(xùn)練數(shù)據(jù)由均值為-2, 方差為1的正態(tài)分布產(chǎn)生。

此處數(shù)據(jù)參考Liam Coder的博客https://blog.csdn.net/out_of_memory_error/article/details/81275651

import torch
from torch.autograd import Variable

N = torch.ones(100, 2) #訓(xùn)練樣本數(shù)
x0 = Variable(torch.normal(2*N, 1))
y0 = Variable(torch.zeros(100, 1))
x1 = Variable(torch.normal(-2*N, 1))
y1 = Variable(torch.ones(100, 1))
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
y = torch.cat((y0, y1), 0).type(torch.FloatTensor)

#作出散點(diǎn)圖
fig, ax = plt.subplots()
labels = ['class 0','class 1']
ax.scatter(x.numpy()[0:len(x0),0], x.numpy()[0:len(x0),1], label=labels[0])
ax.scatter(x.numpy()[len(x0):len(x),0], x.numpy()[len(x0):len(x),1], label=labels[1])
ax.legend()

數(shù)據(jù)分布如下:


數(shù)據(jù)分布

2. 使用Pytorch Tensor實(shí)現(xiàn)Logistic回歸

Logistic回歸采用最大似然法求解參數(shù)的最優(yōu)值。 似然函數(shù)如下:

L(w, b) = \sum_{i=1}^{N}[y_{i}log\pi(x_{i}) + (1-y_{i})log(1-\pi(x_{i}))]

其中 i = 1, ..., N 表示有N個(gè)樣本, \pi(x) = 1/(1+exp(-wx)) 是Logistic函數(shù)。通過(guò)梯度下降法可以求得參數(shù) w 的最優(yōu)值。注意,此處的 w 包含了偏置 b.

(1)梯度下降求解參數(shù) wb
#初始化w和b
w = Variable(torch.zeros(2, 1), requires_grad = True)
b = Variable(torch.zeros(1, 1), requires_grad = True)
EPOCHS = 200
likelihood = []
lr = 0.01
for epoch in range(EPOCHS):
    A = 1/(1+torch.exp(-(x.mm(w)+b))) #Logistic函數(shù)
    J =  -torch.mean(y*torch.log(A) + (1-y)*torch.log(1-A)) #對(duì)數(shù)似然函數(shù)
    likelihood.append(-J.data.numpy().item())
    J.backward() #求似然函數(shù)對(duì)w和b的梯度
    w.data = w.data - lr * w.grad.data #更新w
    w.grad.data.zero_()
    b.data = b.data - lr * b.grad.data #更新b
    b.grad.data.zero_()
(2)作出似然函數(shù)J的圖像:
#
import matplotlib.pyplot as plt
plt.plot(likelihood)
plt.ylabel("lieklihood")
plt.xlabel("epoch")
plt.show()
似然函數(shù)J

P.S. 這里似然函數(shù)的公式為J = -torch.mean(y*torch.log(A) + (1-y)*torch.log(1-A)), 由前述的求和項(xiàng)改為了平均值。個(gè)人觀點(diǎn)是為了適應(yīng)PyTorch的求導(dǎo)規(guī)則。如果使用torch.sum(),在梯度下降的過(guò)程中會(huì)出現(xiàn)似然函數(shù)為nan的現(xiàn)象。具體原因有待進(jìn)一步探究。

(3) 作出分類(lèi)邊界圖像: w_{1}x_{1}+w_{2}x_{2}+b=0
xa = list(range(-4, 5))
xb = []
for item in xa:
    xb.append(-(b.data + item*w[0])/w[1])
fig, ax = plt.subplots()
labels = ['class 0','class 1']
ax.scatter(x.numpy()[0:len(x0),0], x.numpy()[0:len(x0),1], label=labels[0])
ax.scatter(x.numpy()[len(x0):len(x),0], x.numpy()[len(x0):len(x),1], label=labels[1])
ax.legend()
plt.plot(xa, xb)
plt.show()
分類(lèi)邊界

3. 使用nn.Module實(shí)現(xiàn)Logistic回歸

(1)搭建nn模型,梯度下降求解參數(shù)wb
from torch import nn
class Logistic(nn.Module):
    def __init__(self):
        super(Logistic, self).__init__()
        self.linear = nn.Linear(2,1)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        y_pred = self.linear(x)
        y_pred = self.sigmoid(y_pred)
        return y_pred
model = Logistic()

criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr= 0.001)
EPOCHS = 1000
costs = []
for epoch in range(EPOCHS):
    x = Variable(x)
    y = Variable(y)
    out = model(x)
    loss = criterion(out, y)
    costs.append(loss.data.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
(2)作出損失函數(shù)的圖像:
#
import matplotlib.pyplot as plt
plt.plot(costs)
plt.show(range(len(costs)), costs)
loss-epoch圖像
(3) 作出分類(lèi)邊界圖像: w_{1}x_{1}+w_{2}x_{2}+b=0
w1, w2 = model.linear.weight[0]
b = model.linear.bias.item()
plot_x = range(-5, 6, 1)
plot_y = [-(w1*item+b)/w2 for item in plot_x]

fig, ax = plt.subplots()
labels = ['class 0','class 1']
ax.scatter(x.numpy()[0:len(x0),0], x.numpy()[0:len(x0),1], label=labels[0])
ax.scatter(x.numpy()[len(x0):len(x),0], x.numpy()[len(x0):len(x),1], label=labels[1])
ax.legend()
ax.plot(plot_x, plot_y)

分類(lèi)邊界
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

友情鏈接更多精彩內(nèi)容