利用keras從實例掌握深度學(xué)習(xí)圖像分類

0. 背景

之前寫過一個 keras 進行圖像分類的教程,同時也便于自己使用,進行了開源。經(jīng)過一段時間的學(xué)習(xí),雖然已不再使用 keras 和 tensorflow 作為深度學(xué)習(xí)框架進行項目開發(fā),但是 keras 的簡潔性還是值得新手選擇使用的。這里完善一下教程和代碼。建議:框架只是工具而已,還是多學(xué)理論和論文的好

完整代碼:keras_image_classifier

完整教程:利用 keras 從實例掌握深度學(xué)習(xí)圖像分類

個人博客:超杰

要求:

keras==2.2.0 tensorflow==1.8.0 (盡量保持一致,版本不同帶來問題,請自行谷歌)

1. 更新

2018年12月29日 第一次更新:更新全部文檔和代碼

2. 聲明

開源只是幫助一些需要幫助的人,如果有疑問歡迎咨詢,但是代碼已做過線下調(diào)試,確認無誤才進行發(fā)布的,在使用過程中如果遇到一些 bug 請自行谷歌或百度,如果仍有疑問歡迎聯(lián)系我,聯(lián)系方式:zhuchaojie@buaa.edu.cn

另外深度學(xué)習(xí)也好機器學(xué)習(xí)也好,請學(xué)點編程吧。。。。

3. 數(shù)據(jù)格式

為了便于大家使用相同的數(shù)據(jù)進行訓(xùn)練,從而熟悉整個過程,然后再轉(zhuǎn)而到自己的實際項目上,這里了提供公開數(shù)據(jù)集交通標志數(shù)據(jù)集 ,下載鏈接 traffic-sign
數(shù)據(jù)存儲格式:

  • data/
    • train/
      • 00000/
      • 00001/
      • 00002/
      • ...
    • test/
      • 00000/
      • 00001/
      • 00002/

關(guān)于數(shù)據(jù)集的劃分:為了驗證模型的效果,我們需要再另設(shè)驗證集,以便在訓(xùn)練過程中驗證模型效果,但是最終的評測結(jié)果,需要在測試集上進行,所以我們使用 sklearn.model_selection.train_test_split() 函數(shù)對訓(xùn)練集進行劃分,比例采用 訓(xùn)練集:驗證集 = 7:3

4. 項目結(jié)構(gòu)介紹

  • checkpoints/
  • config.py
  • model.py
  • data.py
  • main.py

4.1 checkpoints

主要用來存放訓(xùn)練好的模型權(quán)重,默認只保存模型權(quán)重不保存整個網(wǎng)絡(luò)結(jié)構(gòu)。

4.2 config.py

超參數(shù)文件,本文件定義了整個項目要用到的超參數(shù),具體如下:

### define global configs  ###

class DefaultConfigs(object):
    #1. string configs
    train_data = "../data/train/"
    val_data = "../data/val/"  # if exists else use train_test_split to generate val dataset
    test_data = "../data/all/traffic-sign/test/00003/"  # for competitions
    model_name = "NASNetMobile"
    weights_path = "./checkpoints/model.h5"#save weights for predict
    
    #2. numerical configs
    lr = 0.001
    epochs = 50
    num_classes = 62
    image_size = 224
    batch_size = 16
    channels = 3
    gpu = "0"
    
config = DefaultConfigs()

4.3 model.py

模型搭建文件 。由于日常任務(wù)中常使用預(yù)訓(xùn)練好的模型并進行 finetune ,這里只提供使用該版本,至于自己搭建模型,不是本教程的重點所在,如果有需要的請自行谷歌。

提醒:這里為了方便進行更復(fù)雜的分類任務(wù),提供了一個 ensemble 版本的模型搭建過程。


from keras.applications import NASNetMobile
from keras.applications import resnet50
from keras.layers import GlobalAveragePooling2D,GlobalMaxPooling2D,Dense,Flatten,Input,Concatenate,Dropout
from keras.models import Model
from keras.losses import mae, sparse_categorical_crossentropy, binary_crossentropy,categorical_crossentropy
from keras.optimizers import Adam
from config import config

def get_model():
    inputs = Input((config.image_size, config.image_size, config.channels))
    base_model = NASNetMobile(include_top=False, input_shape=(config.image_size, config.image_size, config.channels))#, weights=None
    x = base_model(inputs)
    out1 = GlobalMaxPooling2D()(x)  # GMP feature
    out2 = GlobalAveragePooling2D()(x) # GAP feature
    out3 = Flatten()(x)                # Flatten feature
    out = Concatenate(axis=-1)([out1, out2, out3])  #concate all feature
    out = Dropout(0.5)(out)
    out = Dense(config.num_classes, activation="softmax", name="classifier")(out)
    model = Model(inputs, out)
    model.compile(optimizer=Adam(0.0001), loss=categorical_crossentropy, metrics=['acc'])

    return model

4.4 data.py

自定義的data generator模塊。在之前,個人喜歡將所有數(shù)據(jù)直接加載到內(nèi)存中,然后再進訓(xùn)練,這樣做的好處就是訓(xùn)練過程減弱了大量數(shù)據(jù)的頻繁調(diào)度問題,但是如果數(shù)據(jù)量過大,內(nèi)存吃不消就行不通了,而且在調(diào)試過程中也很麻煩,這里使用 python 的 yield 機制,能夠避免小機無法運行的問題。

本模塊共包含三個部分:get_files() augument() create_train()

  1. get_files() 詳情如下:

主要功能是循環(huán)讀取每個文件夾下的圖片,并根據(jù)路徑信息提取類別信息。

提醒:windows系統(tǒng)請自行修改類別獲取方法,如果不會,請百度或谷歌

修改位置 labels.append(int(file.split("/")[-2]))

def get_files(root,mode):
    #for test
    if mode == "test":
        files = []
        for img in os.listdir(root):
            files.append(root + img)
        files = pd.DataFrame({"filename":files})
        return files
    elif mode != "test":
        #for train and val       
        all_data_path,labels = [],[]
        image_folders = list(map(lambda x:root+x,os.listdir(root)))
        all_images = list(chain.from_iterable(list(map(lambda x:glob(x+"/*"),image_folders))))
        print("loading train dataset")
        for file in tqdm(all_images):
            all_data_path.append(file)
            labels.append(int(file.split("/")[-2]))
        all_files = pd.DataFrame({"filename":all_data_path,"label":labels})
        return all_files
    else:
        print("check the mode please!")

  1. augument()

使用的線上數(shù)據(jù)增強方式,見開源 imgaug,此處不做重復(fù)介紹。

  1. create_train()

使用python的yield機制對數(shù)據(jù)進行加載,函數(shù)定義如下:


class data_generator:
    def __init__(self,data_lists,mode,augument=True):
        self.mode = mode
        self.augment = augument
        self.all_files = data_lists
    def create_train(self):
        images = []
        dataset_info = self.all_files.values
        #embed()
        """
        if not self.mode == "test":
            for index,row in all_files.iterrows():
                images.append((row["filename"],row["label"]))
        else:
            for index,row in all_files.iterrows():
                images.append((row["filename"]))
        """
        while 1:
            shuffle(dataset_info)
            #print(dataset_info)
            for start in range(0,len(dataset_info),config.batch_size):
                end = min(start + config.batch_size,len(dataset_info))
                batch_images = []
                X_train_batch = dataset_info[start:end]
                batch_labels = np.zeros((len(X_train_batch),config.num_classes))
                for i in range(len(X_train_batch)):
                    #print(X_train_batch[i])
                    image = cv2.imread(X_train_batch[i][0])
                    image = cv2.resize(image,(config.image_size,config.image_size),interpolation=cv2.INTER_NEAREST)

                    if self.augument:
                        image = self.augument(image)
                    batch_images.append(image / 255.)
                    if not self.mode == "test":
                        batch_labels[i][X_train_batch[i][1]] = 1
                        #print(np.array(batch_images).shape)
                yield np.array(batch_images, np.float32), batch_labels

4.5 main.py

包含功能:

  • create_callbacks(),主要實現(xiàn)提前停止訓(xùn)練權(quán)重保存、學(xué)習(xí)率衰減
  • train(),訓(xùn)練函數(shù) ,執(zhí)行訓(xùn)練
  • test(),測試函數(shù),返回每個圖片的預(yù)測類別
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.preprocessing.image import img_to_array
from model import get_model
from config import config
from data import data_generator,get_files
import os
import warnings
import numpy as np
from IPython import embed
warnings.filterwarnings("ignore")
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

def create_callbacks():

    early_stop = EarlyStopping(
        monitor         =       "val_acc",
        mode            =       "auto",
        patience        =       30,
        verbose         =       1
    )

    checkpoint = ModelCheckpoint(
        filepath            =       config.weights_path,
        monitor             =       "val_acc",
        save_best_only      =       True,
        save_weights_only   =       True,
        mode                =       "max",
        verbose             =        1
    )

    lr_reducer = ReduceLROnPlateau(
        monitor         =       "val_acc",
        mode            =       "max",
        epsilon         =       0.01,
        factor          =       0.1,
        patience        =       5,
    )
    return [early_stop,checkpoint,lr_reducer]

def train(callbacks):
    #1. compile
    print("--> Compiling the model...")
    model = get_model()
    # load raw train data
    raw_train_data_lists = get_files(config.train_data,"train")
    #split raw train data to train and val
    train_data_lists,val_data_lists = train_test_split(raw_train_data_lists,test_size=0.3)
    # for train
    train_datagen = data_generator(train_data_lists,"train",augument=True).create_train()
    #embed()
    # val data
    val_datagen = data_generator(val_data_lists,"val",augument=True).create_train()  # if model can predict better on augumented data ,the model should be more reboust
    history = model.fit_generator(
        train_datagen,
        validation_data = val_datagen,
        epochs = config.epochs,
        verbose = 1,
        callbacks = callbacks,
        steps_per_epoch=len(train_data_lists) // config.batch_size,
        validation_steps=len(val_data_lists) // config.batch_size
    )
def test():
    test_data_lists = get_files(config.test_data,"test")
    test_datagen = data_generator(test_data_lists,"test",augument=False).create_train()
    model = get_model()
    model.load_weights(config.weights_path)
    predicted_labels = np.argmax(model.predict_generator(test_datagen,steps=len(test_data_lists) / 16),axis=-1)  
    print(predicted_labels) 
if __name__ == "__main__":
    if not os.path.exists("./checkpoints/"):
        os.mkdir("./checkpoints/")
    callbacks = create_callbacks()
    mode = "test"
    if mode == "train":
        train(callbacks)
    elif mode =="test":
        test()
    else:
        print("check mode!")

5 使用方法

  1. 下載數(shù)據(jù)集并解壓 traffic-sign
  2. 修改 config.pytrain_data,test_data 路徑,其中 test_data 示例:test_data = "../data/all/traffic-sign/test/00003/"
  3. 訓(xùn)練:python main.py
  4. 預(yù)測:修改 main.py 中 76 行的 mode = "test" 并執(zhí)行 python main.py
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容