http://www.cnblogs.com/wzm-xu/p/4062266.html
多元線性回歸----Java簡單實現
學習 Andrew Ng 的機器學習課程之後的簡單實現。
課程地址:https://class.coursera.org/ml-007
不大會編輯公式,所以略去具體的推導,有疑惑的同學去看看Andrew 的課程吧,順帶一句,Andrew的課程實在是很贊。
如果還有疑問,feel free to contact me via email or QQ.
LinearRegression.java
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Multivariate linear regression fitted with batch gradient descent.
 *
 * <p>Training data is a whitespace-separated text file with one sample per
 * line; the last value on each line is the target {@code y} and every value
 * before it is a feature. Example file content:
 *
 * <pre>
 *   x1   x2    y
 *  1.0  2.0  7.2
 *  2.0  1.0  4.9
 *  3.0  0.0  2.6
 * </pre>
 *
 * <p>When the file is loaded a constant feature {@code x0 = 1.0} is prepended
 * to every sample so the intercept can be handled like any other parameter:
 * {@code h(x) = theta0 * x0 + theta1 * x1 + theta2 * x2 + ...}.
 */
public class LinearRegression {

    /** Step size used by {@link #LinearRegression(String)}. */
    private static final double DEFAULT_ALPHA = 0.001;

    /** Iteration count used by {@link #LinearRegression(String)}. */
    private static final int DEFAULT_ITERATION = 100000;

    /** Training samples, one per row: {@code x0, x1, ..., xn, y}. */
    private final double[][] trainData;

    /** Number of training samples. */
    private final int row;

    /** Number of columns in {@link #trainData} (file columns plus the added x0). */
    private final int column;

    /** Model parameters, one per feature including x0. */
    private final double[] theta;

    /** Gradient-descent step size. */
    private final double alpha;

    /** Number of gradient-descent iterations performed by {@link #trainTheta()}. */
    private final int iteration;

    /**
     * Loads training data using a step size of 0.001 and 100000 iterations.
     *
     * @param fileName path of the training data file
     * @throws UncheckedIOException if the file cannot be read
     * @throws IllegalArgumentException if the file content is malformed
     */
    public LinearRegression(String fileName) {
        this(fileName, DEFAULT_ALPHA, DEFAULT_ITERATION);
    }

    /**
     * Loads training data and initialises every parameter to 1.0.
     *
     * @param fileName  path of the training data file
     * @param alpha     gradient-descent step size
     * @param iteration number of iterations {@link #trainTheta()} will run
     * @throws UncheckedIOException if the file cannot be read
     * @throws IllegalArgumentException if rows have differing column counts
     *         or a value is not a valid number
     */
    public LinearRegression(String fileName, double alpha, int iteration) {
        List<double[]> samples = readSamples(fileName);
        int columnsInFile = samples.isEmpty() ? 0 : samples.get(0).length;

        this.row = samples.size();
        // +1 because the constant feature x0 is stored in front of the file columns.
        this.column = columnsInFile + 1;
        this.alpha = alpha;
        this.iteration = iteration;

        this.trainData = new double[row][column];
        for (int i = 0; i < row; i++) {
            trainData[i][0] = 1.0;
            System.arraycopy(samples.get(i), 0, trainData[i], 1, columnsInFile);
        }

        // One parameter per feature column (everything except the trailing y).
        this.theta = new double[column - 1];
        for (int i = 0; i < theta.length; i++) {
            theta[i] = 1.0;
        }
    }

    /**
     * Reads the training file in a single pass.
     *
     * <p>Blank lines are skipped and values may be separated by any run of
     * whitespace, so trailing newlines or tab-separated files do not break
     * loading.
     *
     * @param fileName path of the training data file
     * @return the parsed rows, in file order, exactly as they appear in the file
     */
    private static List<double[]> readSamples(String fileName) {
        List<double[]> samples = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            int expectedColumns = -1;
            int lineNumber = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] tokens = trimmed.split("\\s+");
                if (expectedColumns < 0) {
                    expectedColumns = tokens.length;
                } else if (tokens.length != expectedColumns) {
                    // A ragged row would otherwise shift which value is treated as y.
                    throw new IllegalArgumentException(fileName + ":" + lineNumber
                            + ": expected " + expectedColumns + " columns but found "
                            + tokens.length);
                }
                samples.add(parseRow(tokens, fileName, lineNumber));
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot read training data from " + fileName, e);
        }
        return samples;
    }

    /** Converts the tokens of one line to doubles, reporting the offending location on failure. */
    private static double[] parseRow(String[] tokens, String fileName, int lineNumber) {
        double[] values = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            try {
                values[i] = Double.parseDouble(tokens[i]);
            } catch (NumberFormatException e) {
                NumberFormatException detailed = new NumberFormatException(fileName + ":"
                        + lineNumber + ": column " + (i + 1) + " is not a number: \""
                        + tokens[i] + "\"");
                detailed.initCause(e);
                throw detailed;
            }
        }
        return values;
    }

    /**
     * Runs batch gradient descent for the configured number of iterations,
     * updating all parameters simultaneously on each step.
     */
    public void trainTheta() {
        for (int step = 0; step < iteration; step++) {
            double[] gradient = computeGradient();
            for (int j = 0; j < theta.length; j++) {
                theta[j] -= alpha * gradient[j];
            }
        }
    }

    /**
     * Computes the partial derivative of the mean squared-error cost with
     * respect to every parameter: {@code (1/m) * sum_i (h(x_i) - y_i) * x_ij}.
     *
     * <p>The residual {@code h(x_i) - y_i} does not depend on {@code j}, so it
     * is evaluated once per sample rather than once per parameter.
     */
    private double[] computeGradient() {
        double[] gradient = new double[theta.length];
        for (double[] sample : trainData) {
            double residual = 0.0;
            for (int k = 0; k < theta.length; k++) {
                residual += theta[k] * sample[k];
            }
            residual -= sample[column - 1];
            for (int j = 0; j < theta.length; j++) {
                gradient[j] += residual * sample[j];
            }
        }
        for (int j = 0; j < gradient.length; j++) {
            gradient[j] /= row;
        }
        return gradient;
    }

    /**
     * Returns a copy of the current parameters, ordered {@code theta0, theta1, ...}.
     *
     * @return a new array; modifying it does not affect the model
     */
    public double[] getTheta() {
        return theta.clone();
    }

    /** Prints the loaded training data, including the added x0 column, to standard output. */
    public void printTrainData() {
        System.out.println("Train Data:\n");
        for (int i = 0; i < column - 1; i++) {
            System.out.printf("%10s", "x" + i + " ");
        }
        System.out.printf("%10s", "y" + " \n");
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < column; j++) {
                System.out.printf("%10s", trainData[i][j] + " ");
            }
            System.out.println();
        }
        System.out.println();
    }

    /** Prints the current parameters on one line, each followed by a space, to standard output. */
    public void printTheta() {
        for (double value : theta) {
            System.out.print(value + " ");
        }
    }
}
TestLinearRegression.java
/**
 * Demo driver: loads the samples from the file "trainData" in the working
 * directory, prints them, fits the model and prints the resulting parameters.
 */
public class TestLinearRegression {
    public static void main(String[] args) {
        final double stepSize = 0.001;
        final int rounds = 1000000;
        final LinearRegression model = new LinearRegression("trainData", stepSize, rounds);

        model.printTrainData();
        model.trainTheta();
        model.printTheta();
    }
}
trainData文件中是訓練數據,默認最後一列是y,比如:
1.0 2.0 7.2
2.0 1.0 4.9
3.0 0.0 2.6
4.0 1.0 6.3
5.0 -1.0 1.0
6.0 0.0 4.7
7.0 -2.0 -0.6
前兩列是“feature”,最後一列,也就是第三列是y
Email: wuzimian2006@163.com
QQ: 726590906

