統計學習方法習題筆記——第2章 感知機


第2章思維導圖

感知機的三要素

感知機(perceptron)是二類分類的線性分類模型,其輸入為實例的特征向量,輸出為實例的類別,取+1和-1二值。
模型:假設輸入空間(特征空間)是\(\mathcal{X}\subseteq R^n\),輸出空間是\(\mathcal{Y}=\{+1,-1\}\)。輸入\(x\in\mathcal{X}\)表示實例的特征向量,對應於輸入空間的點;輸出\(y\in\mathcal{Y}\)表示實例的類別。由輸入空間到輸出空間的如下函數$$f(x)=sign(w\cdot x+b)$$稱為感知機。其中\(w,b\)為感知機模型的參數,\(w\in R^n\)叫做權值或者權值向量,\(b\in R\)叫做偏置,\(w\cdot x\)表示\(w\)\(x\)的內積,sign是符號函數,即

\[sign(x)= \begin{cases} +1, &x\geqslant 0 \\ -1, & x\lt 0 \\ \end{cases}\]

感知機屬於線性分類模型,屬於判別模型。感知機的假設空間是定義在特征空間的所有線性分類器,即\(\{f|f(x)=w\cdot x + b\}\)
感知機的直觀理解
感知機可以理解為特征空間\(R^n\)中的一個超平面S,其中\(w\)是超平面的法向量,\(b\)是超平面的截距。

策略:感知機的學習策略是在假設空間中選取使如下損失函數最小的模型參數,其中\(M\)為誤分類點集合。

\[L(w,b)=-\sum_{x_i\in M}y_i(w\cdot x_i + b) \]

算法:感知機的學習方法就是求解損失函數最小的最優化問題。具體地,給定一個訓練數據集

\[T=\{(x_1,y_1),(x_2,y_2),...,(x_N,y_N)\} \]

其中\(x_i\in \mathcal{X}=R^n\)\(y_i\in \mathcal{Y}=\{-1,1\}\)\(i=1,2,...,N\),求解參數\(w,b\),使其為以下損失函數的極小化問題的解

\[\mathop{\min}\limits_{w,b}L(w,b)=-\sum_{x_i\in M}y_i(w\cdot x_i + b) \]

感知機采用隨機梯度下降法,首先任取一個超平面\(w_0,b_0\),然后使用隨機梯度下降法不斷極小化目標函數,極小化過程不是一次使\(M\)中所有誤分類點梯度下降,而是只選擇一個隨機誤分類點使其梯度下降。
損失函數的梯度為:

\[ \nabla_wL(w,b)=-\sum_{x_i\in M}y_i x_i \\ \nabla_bL(w,b)=-\sum_{x_i \in M}y_i \]

隨機選取一個誤分類點\((x_i,y_i)\),對\(w,b\)更新:

\[w\leftarrow w + \eta y_i x_i \]

\[b \leftarrow b + \eta y_i \]

其中\(\eta\in (0, 1]\)叫作步長,讀音(eta),也叫學習率。
感知機的算法收斂性可由\(Novikoff\)定理證明。

習題

2.1 Minsky 與 Papert 指出:感知機因為是線性模型,所以不能表示復雜的函數,如異或 (XOR)。驗證感知機為什么不能表示異或。

\(x_1\) \(x_2\) \(x_1\bigoplus x_2\) \(f(x_1,x_2)\)
0 0 -1 \(b\)
0 1 1 \(sign(w_2+b)\)
1 0 1 \(sign(w_1+b)\)
1 1 -1 \(sign(w_1+w_2+b)\)

如上表所示,假設\(b<0\),則\(w_1>b,w_2>b\),從而\(w_1+w_2+b>0\),這與\(1\bigoplus 1=-1\)矛盾。故感知機不能表示異或。

2.2 模仿例題 2.1,構建從訓練數據求解感知機模型的例子。
基於Numpy庫的感知機實現。

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np


class BinaryPerceptron(object):

    def __init__(self, n_features=2):
        self.w = np.random.randn(n_features)
        self.b = np.random.rand(1)
        self.history = dict()

    def sign(self, x):
        # sign 函數
        return np.array(x >= 0, np.int32) * 2 - 1

    def fit(self, X, y, max_epochs=20, learning_rate=0.1):
        """
        X: NxD
        y: N
        """
        ws = []
        bs = []
        for epoch in range(max_epochs):
            pred_ = np.sum(self.w * X, axis=1) + self.b
            y_ = self.sign(pred_)

            # 計算准確率和損失
            accuracy = sum((y == y_)) / len(y)
            misclass = np.nonzero((y != y_))
            loss = -np.sum(pred_[misclass] * y[misclass])
            print(f'[{epoch}/{max_epochs}] Accuracy: {accuracy} Loss: {loss:.2f}')
            # 更新權重
            if (accuracy == 1.0): break
            idx = np.random.choice(len(X), 1, p= (y != y_) / sum(y != y_))[0]
            self.w = self.w + learning_rate * X[idx] * y[idx]
            self.b = self.b + learning_rate * y[idx]

            ws.append(self.w)
            bs.append(self.b)

        self.history['ws'] = np.stack(ws)
        self.history['bs'] = np.stack(bs)


if __name__ == '__main__':

    # 創建數據
    X = np.concatenate((np.random.randn(50, 2)-1, (np.random.randn(50, 2) + 1)), axis=0)
    y = np.concatenate((-np.ones(50, dtype=np.int32), np.ones(50, np.int32)), axis=0)
    # 訓練模型
    epochs = 20
    perceptron = BinaryPerceptron(2)
    perceptron.fit(X, y, epochs)

    # 繪制動態過程
    history = perceptron.history
    ws, bs = history['ws'], history['bs']

    fig, ax = plt.subplots()
    # 原始數據
    ax.scatter(X[:50,0], X[:50,1])
    ax.scatter(X[50:,0], X[50:,1])

    # 初始分割線
    x1 = np.arange(-3, 3, 0.01)
    x2 = (-ws[0][0] * x1 - bs[0]) / ws[0][1]
    line, = ax.plot(x1, x2)

    # 更新分割線
    def animate(i):
        line.set_ydata((-ws[i][0] * x1 - bs[i]) / ws[i][1])  # update the data.
        return line,

    ani = animation.FuncAnimation(fig, animate, frames=epochs, interval=200, blit=True, repeat=False, repeat_delay=2000)
    ani.save('perceptron_animation.gif')

    plt.show()

參考資料

  1. 《統計學習方法》李航
  2. DataWhale統計學習方法習題解答


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM