首頁 > 軟體

圖文詳解牛頓迭代演演算法原理及Python實現

2022-08-10 18:02:36

1.引例

給定如圖所示的某個函數,如何計算函數零點x0

在數學上我們如何處理這個問題?

最簡單的辦法是解方程f(x)=0,在代數學上還有著名的零點判定定理

如果函數y=f(x)在區間[a,b]上的圖象是連續不斷的一條曲線,並且有f(a)⋅f(b)<0,那麼函數y=f(x)在區間(a,b)內有零點,即至少存在一個c∈(a,b),使得f(c)=0,這個c也就是方程f(x)=0的根。

然而,數學上的方法並不一定適合工程應用,當函數形式複雜,例如出現超越函數形式;非解析形式,例如遞推關係時,精確的方程解析一般難以進行,因為代數上還沒發展出任意形式的求根公式。而零點判定定理求解效率也較低,需要不停試錯。

因此,引入今天的主題——牛頓迭代法,服務於工程數值計算。

2.牛頓迭代演演算法求根

記第k輪迭代後,自變數更新為xk,令目標函數f(x)在x=xk泰勒展開:

f(x)=f(xk​)+f′(xk​)(x−xk​)+o(x)

我們希望下一次迭代到根點,忽略泰勒餘項,令f(xk+1)=0,則

xk+1​=xk​−f(xk​)/f'(xk​)​

不斷重複運算即可逼近根點。

在幾何上,上面過程實際上是在做f(x)在x=xk處的切線,並求切線的零點,在工程上稱為區域性線性化。如圖所示,若xk在x0的左側,那麼下一次迭代方向向右。

若xk在x0的右側,那麼下一次迭代方向向左。

3.牛頓迭代優化

將優化問題轉化為求目標函數一階導數零點的問題,即可運用上面說的牛頓迭代法。

具體地,記第k輪迭代後,自變數更新為xk ,令目標函數f(x)在x=xk泰勒展開:

f(x)=f(xk​)+f′(xk​)(x−xk​)+1/2​f′′(xk​)(x−xk​)2+o(x)

兩邊求導得

f′(x)=f′(xk​)+f′′(xk​)(x−xk​)

令f′(xk+1​)=f′(xk​)+f′′(xk​)(xk+1​−xk​)=0,從而得到

xk+1​=xk​−f′(xk​)/f'′(xk​)​

對於向量x=[x1​​ x2​​⋯​xd​​]T,將上述迭代公式推廣為

xk+1​=xk​−[∇2f(xk​)]−1∇f(xk​)

 

其中∇2f(xk​)是Hessian矩陣,當其正定時可以保證牛頓優化演演算法往 減小的方向迭代

牛頓法的特點如下:

① 以二階速率向最優點收斂,迭代次數遠小於梯度下降法,優化速度快;

梯度下降法的解析參考圖文詳解梯度下降演演算法的原理及Python實現

②學習率為[∇2f(xk​)]−1 ,包含更多函數本身的資訊,迭代步長可實現自動調整,可視為自適應梯度下降演演算法;

③ 耗費CPU計算資源多,每次迭代需要計算一次Hessian矩陣,且無法保證Hessian矩陣可逆且正定,因而無法保證一定向最優點收斂。

在實際應用中,牛頓迭代法一般不能直接使用,會引入改進來規避其缺陷,稱為擬牛頓演演算法簇,其中包含大量不同的演演算法變種,例如共軛梯度法、DFP演演算法等等,今後都會介紹到。

4 程式碼實戰:Logistic迴歸

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from Logit import Logit

'''
* @breif: 從CSV中載入指定資料
* @param[in]: file -> 檔名
* @param[in]: colName -> 要載入的列名
* @param[in]: mode -> 載入模式, set: 列名與該列資料組成的字典, df: df型別
* @retval: mode模式下的返回值
'''
def loadCsvData(file, colName, mode='df'):
    assert mode in ('set', 'df')
    df = pd.read_csv(file, encoding='utf-8-sig', usecols=colName)
    if mode == 'df':
        return df
    if mode == 'set':
        res = {}
        for col in colName:
            res[col] = df[col].values
        return res

if __name__ == '__main__':
    # ============================
    # 讀取CSV資料
    # ============================
    csvPath = os.path.abspath(os.path.join(__file__, "../../data/dataset3.0alpha.csv"))
    dataX = loadCsvData(csvPath, ["含糖率", "密度"], 'df')
    dataY = loadCsvData(csvPath, ["好瓜"], 'df')
    label = np.array([
        1 if i == "是" else 0
        for i in list(map(lambda s: s.strip(), list(dataY['好瓜'])))
    ])

    # ============================
    # 繪製樣本點
    # ============================
    line_x = np.array([np.min(dataX['密度']), np.max(dataX['密度'])])
    mpl.rcParams['font.sans-serif'] = [u'SimHei']
    plt.title('對數機率迴歸模擬nLogistic Regression Simulation')
    plt.xlabel('density')
    plt.ylabel('sugarRate')
    plt.scatter(dataX['密度'][label==0],
                dataX['含糖率'][label==0],
                marker='^',
                color='k',
                s=100,
                label='壞瓜')
    plt.scatter(dataX['密度'][label==1],
                dataX['含糖率'][label==1],
                marker='^',
                color='r',
                s=100,
                label='好瓜')

    # ============================
    # 範例化對數機率迴歸模型
    # ============================
    logit = Logit(dataX, label)

    # 採用牛頓迭代法
    logit.logitRegression(logit.newtomMethod)
    line_y = -logit.w[0, 0] / logit.w[1, 0] * line_x - logit.w[2, 0] / logit.w[1, 0]
    plt.plot(line_x, line_y, 'g-', label="牛頓迭代法")

    # 繪圖
    plt.legend(loc='upper left')
    plt.show()

其中更新權重程式碼為

    '''
    * @breif: 牛頓迭代法更新權重
    * @param[in]: None
    * @retval: 優化引數的增量dw
    '''
    def newtomMethod(self):
        wTx = np.dot(self.w.T, self.X).reshape(-1, 1)
        p = Logit.sigmod(wTx)
        dw_1 = -self.X.dot(self.y - p)
        dw_2 = self.X.dot(np.diag((p * (1 - p)).reshape(self.N))).dot(self.X.T)
        dw = np.linalg.inv(dw_2).dot(dw_1)
        return dw

到此這篇關於圖文詳解牛頓迭代演演算法原理及Python實現的文章就介紹到這了,更多相關Python牛頓迭代演演算法內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com