神經網絡的訓練

數據驅動

提取特征量

傳統(tǒng)算法/其他機器學習算法/神經網絡,灰色部分表示沒有人為介入

訓練數據和測試數據

  • 使用訓練數據進行學習,尋找最優(yōu)的參數
  • 使用測試數據評價模型的實際能力
  • 僅用一個數據集去學習和評價參數,容易出現過擬合

損失函數

  • 神經網絡中以某個指標為線索尋找最優(yōu)權重參數,所用的指標為損失函數
  • 均方誤差:E=\frac{1}{2}\sum_k(y_k-t_k)^2,其中y_k 是表示神經網絡的輸出,t_k 表示監(jiān)督數據,k 表示數據的維數
  • 交叉熵誤差:E=-\sum_kt_klog\ y_k,其中y_k是神經網絡輸出對應分類的概率,t_k表示監(jiān)督數據,k表示數據的維數,log表示自然對數
  • 批數據的交叉熵誤差:E=-\frac{1}{N}\sum_n\sum_kt_{nk}log\ y_{nk}
  • mini-batch: 從訓練數據中選出一批數據(稱為mini-batch),用這批數據進行學習,這種方式稱為mini-batch

梯度下降法

梯度(gradient): \nabla f = (\frac{\partial f}{\partial x_1},\frac{\partial f}{\partial x_2},...,\frac{\partial f}{\partial x_n}),梯度是一個向量
梯度下降法:
x_0' = x_0-\eta\frac{\partial f}{\partial x_0} \\ x_1' = x_1-\eta\frac{\partial f}{\partial x_1} \\ 表示成向量形式:\vec {x'} = \vec{x} - \eta·\nabla f\qquad其中\(zhòng)eta稱為學習率 \\ 循環(huán)執(zhí)行這個步驟,當|x'-x|<\epsilon時,凸函數f找到誤差范圍內的最小值

正向傳播的二層神經網絡

步驟

  1. 確定概率預測函數
    predict : (Param,X) \rightarrow \{\\ W_1,W_2,b_1,b_2 = Param[W_1],Param[W_2],Param[b_1],Param[b_1] \\ A_1 = X·W_1^T + b_1 \\ Z_1 = sigmoid(A_1,axis=1) \\ A_2 = Z_1·W_2^T + b_2 \\ Z_2 = softmax(A_2,axis=1) \\ y = Z_2 \\ return\ y\}
  2. 確定T函數(預測正確為1,否則為0)
    t : (Param,X,y) \rightarrow \{ \\ p = predict(Param,X) \\ y_{predict} = argmax(p,axis=1)\\ return\ int(y_{predict} == p,axis=1)\}
  3. 確定交叉熵誤差作為神經網絡損失函數
    loss :(X,Param,y)\rightarrow\{ \\ t_{val} = t(Param,X,y) \\ p = predict(Param,X) \\ return - t_{val}·log(p) \}
  4. Xt作為參數,確定loss函數對變量ParamParam=Param_0處的梯度
    G_0 = \nabla loss(Param)|_{Param = Param_0,X=X_0,y=y_0}
  5. 使用mni_batch更新參數Param,\eta為學習率
    G_{itr} = \nabla loss(Param)|_{Param = Param_{itr},x=x_{batch},y=y_{batch}} \\ Param_{itr+1} = Param_{itr} -G_{itr} * \eta
  6. (可視化時可選)
    周期性地記錄train_acc和test_acc,最后用這兩組記錄繪制學習曲線
  7. 達到最大迭代次數后返回Param_{maxitr}

流程圖表示如下


正向傳播的二層神經網絡

小結

  1. 機器學習中使用數據集分為訓練數據集和測試數據集
  2. 神經網絡用訓練數據進行學習,并用測試數據評價學習到的模型的泛化能力
  3. 神經網絡的學習以損失函數為指標,更新權重參數,以使損失函數的值減小
  4. 利用某個給定的微小值的差分求導數的過程(導數定義),稱為數值微分
  5. 利用數值微分,可以計算權重參數的梯度
  6. 數值微分雖然簡單但是費時間,另一種高速計算導數的方法叫做誤差的反向傳播法
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容