深度學(xué)習(xí)模型處理多標(biāo)簽(multi_label)分類任務(wù)——keras實(shí)戰(zhàn)

最近在讀論文的的過(guò)程中接觸到多標(biāo)簽分類(multi-label classification)的任務(wù),必須要強(qiáng)調(diào)的是多標(biāo)簽(multi-label)分類任務(wù) 和 多分類(multi-class)任務(wù)的區(qū)別:

  • 多標(biāo)簽分類任務(wù)指的是一條數(shù)據(jù)可能有一個(gè)或者多個(gè)標(biāo)簽,舉個(gè)例子:比如一個(gè)病人的體檢報(bào)告,它可能被標(biāo)記上,高血壓,高血糖等多個(gè)標(biāo)簽。
  • 多分類任務(wù)指的是一條數(shù)據(jù)只有一個(gè)標(biāo)簽,但是標(biāo)簽有多種類別。機(jī)器學(xué)習(xí)中比較經(jīng)典的iris鳶尾花數(shù)據(jù)集就是標(biāo)準(zhǔn)的多分類任務(wù),一條數(shù)據(jù)喂給模型,模型需判斷它是3個(gè)類別中的哪一個(gè)。

這里筆者強(qiáng)調(diào)一下多標(biāo)簽分類任務(wù)的兩個(gè)特點(diǎn):

  • 類別標(biāo)的數(shù)量是不確定的,有些樣本可能只有一個(gè)類標(biāo),有些樣本可能存在多個(gè)類別標(biāo)簽。
  • 類別標(biāo)簽之間可能存在相互依賴關(guān)系,還是拿我上述的例子來(lái)說(shuō):如果一個(gè)人患有高血壓,他有心血管疾病的概率也會(huì)變大,所以高血壓這個(gè)label和心血管疾病的那些labels是存在一些依賴關(guān)系的。

多標(biāo)簽分類算法簡(jiǎn)介

多標(biāo)簽分類算法比較常用的有ML-KNN、ML-DT、Rank-SVM、CML等。我就不多介紹這些基于傳統(tǒng)機(jī)器學(xué)習(xí)的方法,感興趣的同學(xué)可以自己去研究。這里主要介紹如何采用深度學(xué)習(xí)模型做多標(biāo)簽分類任務(wù),首先我們必須明確一下多標(biāo)簽分類模型的輸入和輸出。

模型輸入輸出

假設(shè)我們有一個(gè)體檢疾病判斷任務(wù):通過(guò)一份體檢報(bào)告判斷一個(gè)人是否患有以下五種?。河行蚺帕小猍高血壓,高血糖,肥胖,肺結(jié)核,冠心病]
輸入:一份體檢報(bào)告
輸出:[1,0,1,0,0 ] ,其中1代表該位置的患病,0代表沒(méi)患病。所以這個(gè)label的含義:患者有高血壓和肥胖。

模型架構(gòu)

接下來(lái)如何建立模型呢:
當(dāng)然你可以對(duì)label的每一個(gè)維度分別進(jìn)行建模,訓(xùn)練5個(gè)二分類器。
但是這樣不僅是的label之間的依賴關(guān)系被破壞,而且還耗時(shí)耗力。接下來(lái)我們還是來(lái)看看深度神經(jīng)網(wǎng)絡(luò)是如何應(yīng)用于此問(wèn)題的。其架構(gòu)如下:

  • 采用神經(jīng)網(wǎng)絡(luò)做特征提取器,這部分不需要多解釋,就是一個(gè)深度學(xué)習(xí)網(wǎng)絡(luò);
  • 采用sigmoid做輸出層的激活函數(shù),若做體檢疾病判斷任務(wù)的話輸出層是5個(gè)節(jié)點(diǎn)對(duì)應(yīng)一個(gè)5維向量,這里沒(méi)有采用softmax,就是希望sigmoid 對(duì)每一個(gè)節(jié)點(diǎn)的值做一次激活,從而輸出每個(gè)節(jié)點(diǎn)分別是 1 概率;
  • 采用binary_crossentropy損失函數(shù)函數(shù),這樣使得模型在訓(xùn)練過(guò)程中不斷降低output和label之間的交叉熵。其實(shí)就相當(dāng)于模型使label為1的節(jié)點(diǎn)的輸出值更靠近1,label為0的節(jié)點(diǎn)的輸出值更靠近0。

有點(diǎn)類似 Structure Learing ,最終模型的輸出就是一個(gè)結(jié)構(gòu)序列。

實(shí)戰(zhàn)部分

導(dǎo)入必要的python包。

import scipy
import pandas as pd
from scipy.io import arff
from sklearn.model_selection import train_test_split
from keras.layers import Dense
import numpy as np
載入數(shù)據(jù)

數(shù)據(jù)是一份益生菌的數(shù)據(jù)集,數(shù)據(jù)集可以從這里下載。數(shù)據(jù)樣式如下圖所示: 共有117列,其中前103列是數(shù)據(jù)的feature,后14列是數(shù)據(jù)的label(都是0或者1)。

data, meta = scipy.io.arff.loadarff('yeast-train.arff')
df = pd.DataFrame(data)
data
數(shù)據(jù)預(yù)處理

這里沒(méi)有做太多的EDA,數(shù)據(jù)清洗等工作,只是將將數(shù)據(jù)的feature和label分開(kāi),同時(shí)將數(shù)據(jù)分成訓(xùn)練集合測(cè)試集。

X = df.iloc[:,0:103].values
y = df.iloc[:,103:117].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
y_train = y_train.astype(np.float64)
y_test = y_test.astype(np.float64)
模型定義

按照上文所描述的模型架構(gòu),搭建了一個(gè)2層的神經(jīng)網(wǎng)絡(luò),每層神經(jīng)元個(gè)數(shù)分別是500和100。

def deep_model(feature_dim,label_dim):
    from keras.models import Sequential
    from keras.layers import Dense
    model = Sequential()
    print("create model. feature_dim ={}, label_dim ={}".format(feature_dim, label_dim))
    model.add(Dense(500, activation='relu', input_dim=feature_dim))
    model.add(Dense(100, activation='relu'))
    model.add(Dense(label_dim, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
模型訓(xùn)練

定義好batch_size,訓(xùn)練輪數(shù)epoch,將處理好的數(shù)據(jù)喂給模型,就可以跑起來(lái)了。

def train_deep(X_train,y_train,X_test,y_test):
    feature_dim = X_train.shape[1]
    label_dim = y_train.shape[1]
    model = deep_model(feature_dim,label_dim)
    model.summary()
    model.fit(X_train,y_train,batch_size=16, epochs=5,validation_data=(X_test,y_test))
train_deep(X_train,y_train,X_test,y_test)

模型架構(gòu)和訓(xùn)練過(guò)程如下圖所示:模型訓(xùn)練5輪,驗(yàn)證集準(zhǔn)確率就達(dá)到0.8


train

結(jié)語(yǔ)

在現(xiàn)實(shí)生活中很多地方都會(huì)用到多標(biāo)簽分類,因?yàn)榫湍糜脩舢嬒駚?lái)說(shuō),一個(gè)人身上很少只被貼上一個(gè)標(biāo)簽,人是復(fù)雜的,通常是各種標(biāo)簽,各種人設(shè)的集合。所以模型必須學(xué)會(huì)如何分辨和識(shí)別一個(gè)帶有多個(gè)標(biāo)簽的復(fù)雜的事物,這樣的模型才會(huì)是更聰明的模型。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容