機器學習實戰-之SVM核函數與案例


在現實任務中,原始樣本空間中可能不存在這樣可以將樣本正確分為兩類的超平面,但是我們知道如果原始空間的維數是有限的,也就是說屬性數是有限的,則一定存在一個高維特征空間能夠將樣本划分。

事實上,在做任務中,我們並不知道什么樣的核函數是合適的。但是核函數的選擇卻對支持向量機的性能有着至關重要的作用。如果核函數選擇不合適,則意味着樣本映射到一個不合適的特征空間,這樣就有可能導致性能不佳。故“核函數選擇”是非常重要的一項任務。

對於線性數據集的分類來說,我們當然會選擇線性核函數。但如果要分割非線性數據集,我們該如何做呢?答案是,我們可以改變損失函數中的核函數。我們今天就以高斯核函數來進行案例說明:

#導入庫
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
sess=tf.Session()
#生成模擬數據:得到兩個同心圓數據,每個不同的環代表不同的類,分為類-1或者1
(x_vals,y_vals)=datasets.make_circles(n_samples=500,factor=.5,noise=.1)
y_vals=np.array([1 if y==1 else -1 for y in y_vals])
class1_x=[x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y=[x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x=[x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y=[x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]
#聲明批量大小、占位符以及變量b
batch_size=250
x_data=tf.placeholder(shape=[None,2],dtype=tf.float32)
y_target=tf.placeholder(shape=[None,1],dtype=tf.float32)
prediction_grid=tf.placeholder(shape=[None,2],dtype=tf.float32)
b=tf.Variable(tf.random_normal(shape=[1,batch_size]))
#創建高斯函數
gamma=tf.constant(-50.0)
dist=tf.reduce_sum(tf.square(x_data),1)
dist=tf.reshape(dist,[-1,1])
sq_dists=tf.add(tf.subtract(dist,tf.multiply(2.,tf.matmul(x_data,tf.transpose(x_data)))),tf.transpose(dist))
my_kernel=tf.exp(tf.multiply(gamma,tf.abs(sq_dists)))
#PS:線性核函數的表達式可以為:my_kernel=tf.matmul(x_data,tf.transpose(x_data))
#聲明對偶問題,為了最大化,這里采用最小損失函數的負數:tf.negative()
model_output=tf.matmul(b,my_kernel)
first_term=tf.reduce_sum(b)
b_vec_cross=tf.matmul(tf.transpose(b),b)
y_target_cross=tf.matmul(y_target,tf.transpose(y_target))
second_term=tf.reduce_sum(tf.multiply(my_kernel,tf.multiply(b_vec_cross,y_target_cross)))
loss=tf.negative(tf.subtract(first_term,second_term))

  

#創建預測函數和准確度函數
rA=tf.reshape(tf.reduce_sum(tf.square(x_data),1),[-1,1])
rB=tf.reshape(tf.reduce_sum(tf.square(prediction_grid),1),[-1,1])
pred_sq_dist=tf.add(tf.subtract(rA,tf.multiply(2.,tf.matmul(x_data,tf.transpose(prediction_grid)))),tf.transpose(rB))
pred_kernel=tf.exp(tf.multiply(gamma,tf.abs(pred_sq_dist)))

prediction_output=tf.matmul(tf.multiply(tf.transpose(y_target),b),pred_kernel)
prediction=tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy=tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction),tf.squeeze(y_target)),tf.float32))
#創建優化器
my_opt=tf.train.GradientDescentOptimizer(0.001)
train_step=my_opt.minimize(loss)
#初始化變量
init=tf.global_variables_initializer()
sess.run(init)
#迭代訓練,記錄每次迭代的損失向量和准確度
loss_vec=[]
batch_accuracy=[]
for i in range(7500):
    rand_index=np.random.choice(len(x_vals),size=batch_size)
    rand_x=x_vals[rand_index]
    rand_y=np.transpose([y_vals[rand_index]])
    sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y})
    temp_loss=sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})
    loss_vec.append(temp_loss)

    acc_temp=sess.run(accuracy,feed_dict={x_data:rand_x,y_target:rand_y,prediction_grid:rand_x})
    batch_accuracy.append(acc_temp)
    if(i+1)%500==0:
        print('step#'+str(i+1))
        print('loss='+str(temp_loss))

#創建數據點網格用於后續的數據空間可視化分類
x_min,x_max=x_vals[:,0].min()-1,x_vals[:,0].max()+1
y_min,y_max=x_vals[:,1].min()-1,x_vals[:,1].max()+1
xx,yy=np.meshgrid(np.arange(x_min,x_max,0.02),
                 np.arange(y_min,y_max,0.02))
grid_points=np.c_[xx.ravel(),yy.ravel()]
[grid_predictions]=sess.run(prediction,feed_dict={x_data:rand_x,
                                                 y_target:rand_y,
                                                 prediction_grid:grid_points})
grid_predictions=grid_predictions.reshape(xx.shape)

  

#繪制預測結果
plt.contourf(xx,yy,grid_predictions,cmap=plt.cm.Paired,alpha=0.8)
plt.plot(class1_x,class1_y,'ro',label='得病')
plt.plot(class2_x,class2_y,'kx',label='沒得病')
plt.legend(loc='lower right')
plt.ylim([-1.5,1.5])
plt.xlim([-1.5,1.5])
plt.show()
#繪制批量結果准確度
plt.plot(batch_accuracy,'k-',label='精確度')
plt.title('批量精確度')
plt.xlabel('迭代次數')
plt.ylabel('精確度')
plt.legend(loc='lower right')
plt.show()

#繪制損失函數
plt.plot(loss_vec,'k-')
plt.title('損失函數/迭代')
plt.xlabel('迭代次數')
plt.ylabel('損失誤差')
plt.show()

  訓練效果與分類結果:

 

更多技術干貨請關注:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM