機器學習_學習筆記系列(11):正則化(Regularization)、Lasso Regression 和Ridge Regression

--

上一篇我們提到,利用多項式迴歸,將複雜的資料來做擬合,但是我們可以發現到,如果我們用的次方數太大,雖然對於訓練集的擬合效果很好,但是對於測試集的預測效果卻非常差,而這就是所謂的overfitting問題。那對於這種情況我們該怎麼樣避免overfitting?這個我們就需要用到正則化(Regularization)。

而這次我們拿比較簡單的case,就是預測「台灣人口成長率變化」,所以這個問題應該不難理解,

年份為我們的輸入資料,人口成長率為我們的輸出資料

然後我們的資料來源為中華民國內政部戶政司全球資訊網[1]。其記載了民國35年到民國108年的人口資料。而這裡我們把民國36年到108年的資料,

隨機抽取80%作為我們的訓練集,而剩下的20%作為我們的測試集。

之後我們將年份作為x軸,人口成長率作為y軸畫出

我們可以發現一件事情,台灣的人口成長率快要變成負的啦!!!。而在這裡呢,為了讓我們在做機器學習運算的時候不會出現計算數值過大使電腦不能運算的情況,我們會將數據做正歸化(Normalize)的動作。

為什麼要做呢?還記得我們做多項式回歸的時候,我們需要將我們的輸入資料x做特徵轉換,所以我們現在可以看到,如果我們對我們的model做5次方的特徵轉換,讀到西元2000年的數據,5次方轉換後數值會直接變成3200兆,如果次方數提到更高,會使得數值遠遠超過電腦可以儲存的長度。所以我們在機器學習一定會做Normalize這個動作,也就是

其中x為我們原始數據,mu為所有x數據的平均,sigma為所有x數據的標準差。所以我們同時對x輸入資料,和y輸出資料做Normalize就會變成

接著我們分別用多項式回歸擬合

我們可以看到在高維度的特徵轉換,回歸線為了要擬合training set的數據點,會開始出現各種扭曲的線條,使得最後在testing set上的錯誤率很高,我們把1到6次方的結果全部畫出來

我們可以發現從4次方以上就開始出現過度擬合的狀況,使得測試集錯誤率明顯上升,所以我們必須進一步使用正則化來限制我們的回歸曲線的

Regularization

在機器學習中,我們是透過微分cost function來更新我們的權重,而在這裡我們regularization要做的就是在我們的cost function後面加上新的一項,來限制我們最後出來的擬合線。

而對於Regularization有兩種方式,一種是用L1 norm做,另一種是L2 norm做。關於L1和L2的概念,這個在高等微積分中會有比較深刻和詳細的定義。我這邊就簡單說明一下

假設

對於L1 norm(曼哈頓範數),其定義為

對於L2 norm(歐幾里德範數),其定義為

所以今天x=(x1,x2),把他畫成圖長這樣

這裡如果我們設訂一個參數C則

而在我們原本的cost function 為:

而將L1 norm和L2 norm應用在我們的cost function上為

接著一樣,如果不知道為什麼這樣設計,我們暴力展開就對了

對於L1

對於L2

套入更新方程式

由上述兩個更新方程式,我可以看到,每次權重更新時,如果權重為正,下一次更新的值會比原本沒引入regularization時更新的值小,相對的權重為負,下一次更新的值會比原本沒引入regularization時更新的值大。意思就是我們的權重被L1和L2限制住了。

更具體來說,我們可以發現每次做特徵轉換的時候,當出現overfitting現象,我們可以看到回歸曲線變得非常複雜,而這些複雜的曲線,有些是來自於方程式中的高次方項所造成,所以我們引入L1 norm和L2 norm的目的,是希望限制我們的權重大小,讓曲線變得不會那麼複雜。

回到一開始cost function的式子

我們的目的是最小化cost function,而在這裡我們可以把他當作同時最小化兩個函數,一個為最小化預測值和實際值差距的J,另一個為我們regularization的函數。若我們今天的假設方程式h(x)=w0+w1*x,把他的最佳化問題視覺化為

我們可以看到,原本我們沒有用regularization得出的解為w,而因為今天這個解出現overfitting的情況,所以我們用L1或L2 norm限制,讓我們最後得出的解為w’,如此一來就能減少過度擬合的情況發生。

所以如果把它應用到高維度的特徵轉換可以得到

而當我們用L1 norm套用再回歸分析上,叫做Lasso regression,而用L2 norm叫做Ridge regression。接著我們用python把他實現化

我們可以看到用regularization後,曲線會變得比較平滑一點,對於模型的泛化能力也相對提高,使我們的測試集錯誤率可以得到明顯的下降

Python Sample Code

Github:

Reference:

[1] 中華民國內政部戶政司全球資訊 https://www.ris.gov.tw/app/portal/346

***本系列完全沒有開任何營利***

作者:劉智皓

linkedin: CHIH-HAO LIU

--

--

劉智皓 (Chih-Hao Liu)
劉智皓 (Chih-Hao Liu)

Written by 劉智皓 (Chih-Hao Liu)

豬屎屋AI Engineer,熱愛AI研究、LLM/SD模型、RAG應用、CUDA/HIP加速運算、訓練推論加速,同時也是5G技術愛好者,研讀過3GPP/ETSI/O-RAN Spec和O-RAN/ONAP/OAI開源軟體。