scikit-learn一般實例之八:多標簽分類


本例模擬一個多標簽文檔分類問題.數據集基於下面的處理隨機生成:

  • 選取標簽的數目:泊松(n~Poisson,n_labels)
  • n次,選取類別C:多項式(c~Multinomial,theta)
  • 選取文檔長度:泊松(k~Poisson,length)
  • k次,選取一個單詞:多項式(w~Multinomial,theta_c)

在上面的處理中,拒絕抽樣用來確保n大於2,文檔長度不為0.同樣,我們拒絕已經被選取的類別.被同事分配給兩個分類的文檔會被兩個圓環包圍.

通過投影到由PCA和CCA選取進行可視化的前兩個主成分進行分類.接着通過元分類器使用兩個線性核的SVC來為每個分類學習一個判別模型.注意,PCA用於無監督降維,CCA用於有監督.

注:在下面的繪制中,"無標簽樣例"不是說我們不知道標簽(就像半監督學習中的那樣),而是這些樣例根本沒有標簽~~~

# coding:utf-8

import numpy as np
from pylab import *

from sklearn.datasets import make_multilabel_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.preprocessing import LabelBinarizer
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA


myfont = matplotlib.font_manager.FontProperties(fname="Microsoft-Yahei-UI-Light.ttc")
mpl.rcParams['axes.unicode_minus'] = False



def plot_hyperplane(clf, min_x, max_x, linestyle, label):
    # 獲得分割超平面
    w = clf.coef_[0]
    a = -w[0] / w[1]
    xx = np.linspace(min_x - 5, max_x + 5)  # 確保線足夠長
    yy = a * xx - (clf.intercept_[0]) / w[1]
    plt.plot(xx, yy, linestyle, label=label)


def plot_subfigure(X, Y, subplot, title, transform):
    if transform == "pca":
        X = PCA(n_components=2).fit_transform(X)
    elif transform == "cca":
        X = CCA(n_components=2).fit(X, Y).transform(X)
    else:
        raise ValueError

    min_x = np.min(X[:, 0])
    max_x = np.max(X[:, 0])

    min_y = np.min(X[:, 1])
    max_y = np.max(X[:, 1])

    classif = OneVsRestClassifier(SVC(kernel='linear'))
    classif.fit(X, Y)

    plt.subplot(2, 2, subplot)
    plt.title(title,fontproperties=myfont)

    zero_class = np.where(Y[:, 0])
    one_class = np.where(Y[:, 1])
    plt.scatter(X[:, 0], X[:, 1], s=40, c='gray')
    plt.scatter(X[zero_class, 0], X[zero_class, 1], s=160, edgecolors='b',
               facecolors='none', linewidths=2, label=u'類別-1')
    plt.scatter(X[one_class, 0], X[one_class, 1], s=80, edgecolors='orange',
               facecolors='none', linewidths=2, label=u'類別-2')

    plot_hyperplane(classif.estimators_[0], min_x, max_x, 'k--',
                    u'類別-1的\n邊界')
    plot_hyperplane(classif.estimators_[1], min_x, max_x, 'k-.',
                    u'類別-2的\n邊界')
    plt.xticks(())
    plt.yticks(())

    plt.xlim(min_x - .5 * max_x, max_x + .5 * max_x)
    plt.ylim(min_y - .5 * max_y, max_y + .5 * max_y)
    if subplot == 2:
        plt.xlabel(u'第一主成分',fontproperties=myfont)
        plt.ylabel(u'第二主成分',fontproperties=myfont)
        plt.legend(loc="upper left",prop=myfont)


plt.figure(figsize=(8, 6))

X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                      allow_unlabeled=True,
                                      random_state=1)

plot_subfigure(X, Y, 1, u"有無標簽樣例 + CCA", "cca")
plot_subfigure(X, Y, 2, u"有無標簽樣例 + PCA", "pca")

X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
                                      allow_unlabeled=False,
                                      random_state=1)

plot_subfigure(X, Y, 3, u"沒有無標簽樣例 + CCA", "cca")
plot_subfigure(X, Y, 4, u"沒有無標簽樣例 + PCA", "pca")

plt.subplots_adjust(.04, .02, .97, .94, .09, .2)
plt.suptitle(u"多標簽分類", size=20,fontproperties=myfont)
plt.show()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM