PyTorch系列三:邏輯回歸

1.介紹

  • logistic回歸是一種廣義線性回歸,因此與多重線性回歸分析有很多相同之處。它們的模型形式基本上相同,都具有 w'x+b,其中w和b是待求參數(shù),其區(qū)別在于他們的因變量不同,多重線性回歸直接將w‘x+b作為因變量,即y =w'x+b,而logistic回歸則通過函數(shù)L將w‘x+b對應(yīng)一個隱狀態(tài)p,p =L(w'x+b),然后根據(jù)p 與1-p的大小決定因變量的值。如果L是logistic函數(shù),就是logistic回歸,如果L是多項式函數(shù)就是多項式回歸。
  • logistic回歸的因變量可以是二分類的,也可以是多分類的,但是二分類的更為常用,也更加容易解釋,多類可以使用softmax方法進行處理。實際中最為常用的就是二分類的logistic回歸。

2.模型訓(xùn)練

# -*- coding: utf-8 -*-

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets

import time

# 定義超參數(shù)
batch_size    = 32
num_epoches   = 10
learning_rate = 1e-3

# 下載訓(xùn)練集 MNIST 手寫數(shù)字訓(xùn)練集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset  = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader   = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

class Logstic_Regression(nn.Module):
    """邏輯回歸模型定義"""

    def __init__(self, in_dim, n_class):
        super(Logstic_Regression, self).__init__()
        self.logstic = nn.Linear(in_dim, n_class)

    def forward(self, x):
        # 前向傳播
        output = self.logstic(x)
        return output

# 模型初始化
model = Logstic_Regression(28 * 28, 10)  # 圖片大小是28x28

# 定義loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# 模型訓(xùn)練
for epoch in range(num_epoches):
    print '#' * 45
    print 'Epoch {}'.format(epoch + 1)
    since = time.time()
    running_loss = 0.0
    running_acc  = 0.0
    for i, data in enumerate(train_loader, 1):
        img, label = data
        img = img.view(img.size(0), -1)

        img   = Variable(img)
        label = Variable(label)

        # 前向傳播
        out = model(img)
        loss = criterion(out, label)
        running_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        num_correct = (pred == label).sum()
        running_acc += num_correct.item()

        # 后向傳播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 300 == 0:
            print '[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epoches, running_loss / (batch_size * i),
                running_acc / (batch_size * i))
    print 'Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, running_loss / (len(train_dataset)), running_acc / (len(
            train_dataset)))

    # 模型評估
    model.eval()
    eval_loss = 0.
    eval_acc  = 0.
    for data in test_loader:
        img, label = data
        img = img.view(img.size(0), -1)

        with torch.no_grad():
            img   = Variable(img)
            label = Variable(label)

        out  = model(img)
        loss = criterion(out, label)
        eval_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        num_correct = (pred == label).sum()
        eval_acc += num_correct.item()
    print 'Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
        test_dataset)), eval_acc / (len(test_dataset)))
    print 'Time:{:.1f} s'.format(time.time() - since)

# 模型保存
torch.save(model.state_dict(), './Logistic_Regression.model')
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 機器學(xué)習(xí)是做NLP和計算機視覺這類應(yīng)用算法的基礎(chǔ),雖然現(xiàn)在深度學(xué)習(xí)模型大行其道,但是懂一些傳統(tǒng)算法的原理和它們之間...
    在河之簡閱讀 20,941評論 4 65
  • 邏輯回歸 邏輯回歸(Logistic regression 或logit regression),即邏輯模型(英語...
    ChZ_CC閱讀 36,122評論 4 45
  • 以西瓜書為主線,以其他書籍作為參考進行補充,例如《統(tǒng)計學(xué)習(xí)方法》,《PRML》等 第一章 緒論 1.2 基本術(shù)語 ...
    danielAck閱讀 4,940評論 0 5
  • 歲月如梭,白駒過隙,心中焦躁如同火燒,這時候,不要閑著沒事,亂跟童言無忌的孩子發(fā)飆。雖說,他們是祖國的花骨朵,有時...
    土伽丘閱讀 359評論 0 0
  • (和水鬼君詩《端午》而作) 汨羅江水清又純 滋潤三湘眾鄉(xiāng)親 晝夜?jié)L滾流不盡 ...
    張春發(fā)_66a0閱讀 779評論 8 9

友情鏈接更多精彩內(nèi)容