统计学习方法习题笔记——第2章 感知机


第2章思维导图

感知机的三要素

感知机(perceptron)是二类分类的线性分类模型,其输入为实例的特征向量,输出为实例的类别,取+1和-1二值。
模型:假设输入空间(特征空间)是\(\mathcal{X}\subseteq R^n\),输出空间是\(\mathcal{Y}=\{+1,-1\}\)。输入\(x\in\mathcal{X}\)表示实例的特征向量,对应于输入空间的点;输出\(y\in\mathcal{Y}\)表示实例的类别。由输入空间到输出空间的如下函数$$f(x)=sign(w\cdot x+b)$$称为感知机。其中\(w,b\)为感知机模型的参数,\(w\in R^n\)叫做权值或者权值向量,\(b\in R\)叫做偏置,\(w\cdot x\)表示\(w\)\(x\)的内积,sign是符号函数,即

\[sign(x)= \begin{cases} +1, &x\geqslant 0 \\ -1, & x\lt 0 \\ \end{cases}\]

感知机属于线性分类模型,属于判别模型。感知机的假设空间是定义在特征空间的所有线性分类器,即\(\{f|f(x)=w\cdot x + b\}\)
感知机的直观理解
感知机可以理解为特征空间\(R^n\)中的一个超平面S,其中\(w\)是超平面的法向量,\(b\)是超平面的截距。

策略:感知机的学习策略是在假设空间中选取使如下损失函数最小的模型参数,其中\(M\)为误分类点集合。

\[L(w,b)=-\sum_{x_i\in M}y_i(w\cdot x_i + b) \]

算法:感知机的学习方法就是求解损失函数最小的最优化问题。具体地,给定一个训练数据集

\[T=\{(x_1,y_1),(x_2,y_2),...,(x_N,y_N)\} \]

其中\(x_i\in \mathcal{X}=R^n\)\(y_i\in \mathcal{Y}=\{-1,1\}\)\(i=1,2,...,N\),求解参数\(w,b\),使其为以下损失函数的极小化问题的解

\[\mathop{\min}\limits_{w,b}L(w,b)=-\sum_{x_i\in M}y_i(w\cdot x_i + b) \]

感知机采用随机梯度下降法,首先任取一个超平面\(w_0,b_0\),然后使用随机梯度下降法不断极小化目标函数,极小化过程不是一次使\(M\)中所有误分类点梯度下降,而是只选择一个随机误分类点使其梯度下降。
损失函数的梯度为:

\[ \nabla_wL(w,b)=-\sum_{x_i\in M}y_i x_i \\ \nabla_bL(w,b)=-\sum_{x_i \in M}y_i \]

随机选取一个误分类点\((x_i,y_i)\),对\(w,b\)更新:

\[w\leftarrow w + \eta y_i x_i \]

\[b \leftarrow b + \eta y_i \]

其中\(\eta\in (0, 1]\)叫作步长,读音(eta),也叫学习率。
感知机的算法收敛性可由\(Novikoff\)定理证明。

习题

2.1 Minsky 与 Papert 指出:感知机因为是线性模型,所以不能表示复杂的函数,如异或 (XOR)。验证感知机为什么不能表示异或。

\(x_1\) \(x_2\) \(x_1\bigoplus x_2\) \(f(x_1,x_2)\)
0 0 -1 \(b\)
0 1 1 \(sign(w_2+b)\)
1 0 1 \(sign(w_1+b)\)
1 1 -1 \(sign(w_1+w_2+b)\)

如上表所示,假设\(b<0\),则\(w_1>b,w_2>b\),从而\(w_1+w_2+b>0\),这与\(1\bigoplus 1=-1\)矛盾。故感知机不能表示异或。

2.2 模仿例题 2.1,构建从训练数据求解感知机模型的例子。
基于Numpy库的感知机实现。

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np


class BinaryPerceptron(object):

    def __init__(self, n_features=2):
        self.w = np.random.randn(n_features)
        self.b = np.random.rand(1)
        self.history = dict()

    def sign(self, x):
        # sign 函数
        return np.array(x >= 0, np.int32) * 2 - 1

    def fit(self, X, y, max_epochs=20, learning_rate=0.1):
        """
        X: NxD
        y: N
        """
        ws = []
        bs = []
        for epoch in range(max_epochs):
            pred_ = np.sum(self.w * X, axis=1) + self.b
            y_ = self.sign(pred_)

            # 计算准确率和损失
            accuracy = sum((y == y_)) / len(y)
            misclass = np.nonzero((y != y_))
            loss = -np.sum(pred_[misclass] * y[misclass])
            print(f'[{epoch}/{max_epochs}] Accuracy: {accuracy} Loss: {loss:.2f}')
            # 更新权重
            if (accuracy == 1.0): break
            idx = np.random.choice(len(X), 1, p= (y != y_) / sum(y != y_))[0]
            self.w = self.w + learning_rate * X[idx] * y[idx]
            self.b = self.b + learning_rate * y[idx]

            ws.append(self.w)
            bs.append(self.b)

        self.history['ws'] = np.stack(ws)
        self.history['bs'] = np.stack(bs)


if __name__ == '__main__':

    # 创建数据
    X = np.concatenate((np.random.randn(50, 2)-1, (np.random.randn(50, 2) + 1)), axis=0)
    y = np.concatenate((-np.ones(50, dtype=np.int32), np.ones(50, np.int32)), axis=0)
    # 训练模型
    epochs = 20
    perceptron = BinaryPerceptron(2)
    perceptron.fit(X, y, epochs)

    # 绘制动态过程
    history = perceptron.history
    ws, bs = history['ws'], history['bs']

    fig, ax = plt.subplots()
    # 原始数据
    ax.scatter(X[:50,0], X[:50,1])
    ax.scatter(X[50:,0], X[50:,1])

    # 初始分割线
    x1 = np.arange(-3, 3, 0.01)
    x2 = (-ws[0][0] * x1 - bs[0]) / ws[0][1]
    line, = ax.plot(x1, x2)

    # 更新分割线
    def animate(i):
        line.set_ydata((-ws[i][0] * x1 - bs[i]) / ws[i][1])  # update the data.
        return line,

    ani = animation.FuncAnimation(fig, animate, frames=epochs, interval=200, blit=True, repeat=False, repeat_delay=2000)
    ani.save('perceptron_animation.gif')

    plt.show()

参考资料

  1. 《统计学习方法》李航
  2. DataWhale统计学习方法习题解答


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM