MaDi's Blog

一個紀錄自己在轉職軟體工程師路上的學習小空間

0%

機器學習-正規化(Regularization)

實務上常常會遇到overfitting的問題,所謂overfitting的意思是指在訓練過程中表現很好,但是在測試的結果卻表現很差。

也就是說模型訓練的太過複雜,把簡單的問題想得太複雜,導致產生一堆參數。
於是,就有人提出一個疑問,能不能夠讓複雜的模型倒退回簡單的模型呢?

而這就是 正規化(Regularization) 在做的事情。

正規化(Regularization)

為何要正規化?

實務上常常會遇到overfitting的問題,所謂overfitting的意思是指在訓練過程中表現很好,但是在測試的結果卻表現很差。

也就是說模型訓練的太過複雜,把簡單的問題想得太複雜,導致產生一堆參數。

如下圖所示,我們希望模型學出來是一條回歸的直線,但是他卻學成藍線,因為帶有太多參數導致線型彎彎曲曲的。

於是,就有人提出一個疑問,能不能夠讓複雜的模型倒退回簡單的模型呢?

而這就是 正規化(Regularization) 在做的事情。

L1(Lasso)、L2(ridge)

正規化背後的數學是在原先的loss function(square error或cross entropy)後面額外增加一個正規化的term,通常這個term不考慮bias,因為正規化的目的是希望function可以更平滑,但這點bias無法幫上忙。

正規化分為兩種,分別是L1、L2

L1正規化是把模型裏頭所有的參數都取絕對值。

數學上,因為嚴格來說絕對值不能微分,所以就粗略地把>0的微分為1,<0的微分為-1,以sgn函數表示。

把加上正規化這項term的新的loss function去對每個參數wi做偏微分之後,會發現 **每次update參數wi的時候都會在式子後面扣掉一個 ηλsgn(wi) **,讓參數wi接近於0。

總而言之,L1正規化能夠將模型的複雜度簡化,將沒有用的權重設為0,留下模型認為重要的權重。

L2正規化是把模型裏頭所有的參數都取平方求和。

數學上,把加上正規化這項term的新的loss function去對每個參數wi做偏微分之後,會發現每次update參數wi的時候都會在wi前面乘上一個(1−ηλ),又因為 η 跟 λ 都很小,所以(1−ηλ)大概是0.99左右,小於1但很接近於1。

乘上這個接近1的值,意味著當update的次數愈多,參數wi會愈接近於0(但不會等於0)。

所以L2正規化會讓weight每次都變小一點,這就稱做權重衰減(weight decay)。

總而言之,L2一樣能夠將模型簡化,但不會只留下某個權重,而是削弱所有權重(但仍保留),讓所有權重與神經元都處於活動狀態。

比較L1與L2正規化:

L1正規化: 有可能導致零權重,因刪除更多特徵而使模型稀疏。

L2正規化: 會對更大的權重值造成更大的影響,將使權重值保持較小。

總結一句話來說,我們在訓練模型的過程會update參數,讓參數離0愈來愈遠,而regularization就是把參數拉回來一點,不要離0太遠。

實務上較常用L2正規化,但在CNN裡面適合的是L1,因為結果較稀疏(sparse)。

程式碼中是用參數penalty去選擇用哪一種正規化。

參考: