logistic模型能夠對數據進行二分類。
比如我們有兩組二維空間數據,最終要求的是一個分類直線,可以設定為計算w(1)+w(2)*x+w(3)*y=0這樣的直線。
問題就變為了如何求w的問題。
網上有很多推導,這里就不推導了,不過還是要寫幾個關鍵公式。
可以設定logistic函數為:
設定損失函數為:
對J中w求導得到迭代方向:
然后不斷迭代就行了:
下面代碼中y就是data(:,4),即我們的標簽項;x就是data(:,1:3)。
matlab程序如下:
clear all;close all;clc; mu1=[0 0]; S1=[0.5 0.1]; data1=mvnrnd(mu1,S1,100); plot(data1(:,1),data1(:,2),'r.'); hold on; mu2=[1.5 1.5]; S2=[0.4 0.3]; data2=mvnrnd(mu2,S2,100); plot(data2(:,1),data2(:,2),'g.'); data1 = [data1 zeros(length(data1),1)]; data2 = [data2 ones(length(data2),1)]; %兩組數據打標簽 data = [data1;data2]; %數據組合 data = [ones(length(data),1) data]; %數據第一列增加其次項 w = rand(1,3); alpha = 0.01; for i=1:1000 w = w + alpha*(data(:,4)' - 1./(1+exp(-(w*data(:,1:3)'))))*data(:,1:3); %交叉熵求導迭代 end x = min(data(:,2))-1:0.1:max(data(:,2))+1; y = (-w(1)-w(2)*x)/w(3); plot(x,y,'b');
結果如下: