DL4J實戰(zhàn)之二:鳶尾花分類

歡迎訪問我的GitHub

https://github.com/zq2599/blog_demos

內(nèi)容:所有原創(chuàng)文章分類匯總及配套源碼,涉及Java、Docker、Kubernetes、DevOPS等;

本篇概覽

  • 本文是《DL4J》實戰(zhàn)的第二篇,前面做好了準(zhǔn)備工作,接下來進入正式實戰(zhàn),本篇內(nèi)容是經(jīng)典的入門例子:鳶尾花分類
  • 下圖是一朵鳶尾花,我們可以測量到它的四個特征:花瓣(petal)的寬和高,花萼(sepal)的 寬和高:
在這里插入圖片描述
  • 鳶尾花有三種:Setosa、Versicolor、Virginica
  • 今天的實戰(zhàn)是用前饋神經(jīng)網(wǎng)絡(luò)Feed-Forward Neural Network (FFNN)就行鳶尾花分類的模型訓(xùn)練和評估,在拿到150條鳶尾花的特征和分類結(jié)果后,我們先訓(xùn)練出模型,再評估模型的效果:
在這里插入圖片描述

源碼下載

名稱 鏈接 備注
項目主頁 https://github.com/zq2599/blog_demos 該項目在GitHub上的主頁
git倉庫地址(https) https://github.com/zq2599/blog_demos.git 該項目源碼的倉庫地址,https協(xié)議
git倉庫地址(ssh) git@github.com:zq2599/blog_demos.git 該項目源碼的倉庫地址,ssh協(xié)議
  • 這個git項目中有多個文件夾,《DL4J實戰(zhàn)》系列的源碼在<font color="blue">dl4j-tutorials</font>文件夾下,如下圖紅框所示:
在這里插入圖片描述
  • <font color="blue">dl4j-tutorials</font>文件夾下有多個子工程,本次實戰(zhàn)代碼在<font color="blue">dl4j-tutorials</font>目錄下,如下圖紅框:
在這里插入圖片描述

編碼

  • 在<font color="blue">dl4j-tutorials</font>工程下新建子工程<font color="red">classifier-iris</font>,其pom.xml如下:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>dlfj-tutorials</artifactId>
        <groupId>com.bolingcavalry</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>classifier-iris</artifactId>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>com.bolingcavalry</groupId>
            <artifactId>commons</artifactId>
            <version>${project.version}</version>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
        </dependency>
    </dependencies>
</project>
  • 上述pom.xml有一處需要注意的地方,就是<font color="blue">${nd4j.backend}</font>參數(shù)的值,該值在決定了后端線性代數(shù)計算是用CPU還是GPU,本篇為了簡化操作選擇了CPU(因為個人的顯卡不同,代碼里無法統(tǒng)一),對應(yīng)的配置就是<font color="red">nd4j-native</font>;

  • 源碼全部在Iris.java文件中,并且代碼中已添加詳細注釋,就不再贅述了:

package com.bolingcavalry.classifier;

import com.bolingcavalry.commons.utils.DownloaderUtility;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;

/**
 * @author will (zq2599@gmail.com)
 * @version 1.0
 * @description: 鳶尾花訓(xùn)練
 * @date 2021/6/13 17:30
 */
@SuppressWarnings("DuplicatedCode")
@Slf4j
public class Iris {

    public static void main(String[] args) throws  Exception {

        //第一階段:準(zhǔn)備

        // 跳過的行數(shù),因為可能是表頭
        int numLinesToSkip = 0;
        // 分隔符
        char delimiter = ',';

        // CSV讀取工具
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);

        // 下載并解壓后,得到文件的位置
        String dataPathLocal = DownloaderUtility.IRISDATA.Download();

        log.info("鳶尾花數(shù)據(jù)已下載并解壓至 : {}", dataPathLocal);

        // 讀取下載后的文件
        recordReader.initialize(new FileSplit(new File(dataPathLocal,"iris.txt")));

        // 每一行的內(nèi)容大概是這樣的:5.1,3.5,1.4,0.2,0
        // 一共五個字段,從零開始算的話,標(biāo)簽在第四個字段
        int labelIndex = 4;

        // 鳶尾花一共分為三類
        int numClasses = 3;

        // 一共150個樣本
        int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

        // 加載到數(shù)據(jù)集迭代器中
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);

        DataSet allData = iterator.next();

        // 洗牌(打亂順序)
        allData.shuffle();

        // 設(shè)定比例,150個樣本中,百分之六十五用于訓(xùn)練
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

        // 訓(xùn)練用的數(shù)據(jù)集
        DataSet trainingData = testAndTrain.getTrain();

        // 驗證用的數(shù)據(jù)集
        DataSet testData = testAndTrain.getTest();

        // 指定歸一化器:獨立地將每個特征值(和可選的標(biāo)簽值)歸一化為0平均值和1的標(biāo)準(zhǔn)差。
        DataNormalization normalizer = new NormalizerStandardize();

        // 先擬合
        normalizer.fit(trainingData);

        // 對訓(xùn)練集做歸一化
        normalizer.transform(trainingData);

        // 對測試集做歸一化
        normalizer.transform(testData);

        // 每個鳶尾花有四個特征
        final int numInputs = 4;

        // 共有三種鳶尾花
        int outputNum = 3;

        // 隨機數(shù)種子
        long seed = 6;

        //第二階段:訓(xùn)練
        log.info("開始配置...");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .activation(Activation.TANH)       // 激活函數(shù)選用標(biāo)準(zhǔn)的tanh(雙曲正切)
            .weightInit(WeightInit.XAVIER)     // 權(quán)重初始化選用XAVIER:均值 0, 方差為 2.0/(fanIn + fanOut)的高斯分布
            .updater(new Sgd(0.1))  // 更新器,設(shè)置SGD學(xué)習(xí)速率調(diào)度器
            .l2(1e-4)                          // L2正則化配置
            .list()                            // 配置多層網(wǎng)絡(luò)
            .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)  // 隱藏層
                .build())
            .layer(new DenseLayer.Builder().nIn(3).nOut(3)          // 隱藏層
                .build())
            .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)   // 損失函數(shù):負對數(shù)似然
                .activation(Activation.SOFTMAX)                     // 輸出層指定激活函數(shù)為:SOFTMAX
                .nIn(3).nOut(outputNum).build())
            .build();

        // 模型配置
        MultiLayerNetwork model = new MultiLayerNetwork(conf);

        // 初始化
        model.init();

        // 每一百次迭代打印一次分數(shù)(損失函數(shù)的值)
        model.setListeners(new ScoreIterationListener(100));

        long startTime = System.currentTimeMillis();

        log.info("開始訓(xùn)練");
        // 訓(xùn)練
        for(int i=0; i<1000; i++ ) {
            model.fit(trainingData);
        }
        log.info("訓(xùn)練完成,耗時[{}]ms", System.currentTimeMillis()-startTime);

        // 第三階段:評估

        // 在測試集上評估模型
        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output);

        log.info("評估結(jié)果如下\n" + eval.stats());
    }
}
  • 編碼完成后,運行main方法,可見順利完成訓(xùn)練并輸出了評估結(jié)果,還有混淆矩陣用于輔助分析:
在這里插入圖片描述
  • 至此,咱們的第一個實戰(zhàn)就完成了,通過經(jīng)典實例體驗的DL4J訓(xùn)練和評估的常規(guī)步驟,對重要API也有了初步認識,接下來會繼續(xù)實戰(zhàn),接觸到更多的經(jīng)典實例;

你不孤單,欣宸原創(chuàng)一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 數(shù)據(jù)庫+中間件系列
  6. DevOps系列

歡迎關(guān)注公眾號:程序員欣宸

微信搜索「程序員欣宸」,我是欣宸,期待與您一同暢游Java世界...
https://github.com/zq2599/blog_demos

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容