Digit Recognizer
On the Kaggle website, under Competitions > Getting Started, there is a Digit Recognizer (handwritten digit recognition) competition that is well suited for beginners.

The data files train.csv and test.csv contain gray-scale images of hand-drawn digits, from zero through nine.
Each image is 28 pixels in height and 28 pixels in width, for a total of 784 pixels. Each pixel has a single pixel-value associated with it, indicating the lightness or darkness of that pixel, with higher numbers meaning darker. This pixel-value is an integer between 0 and 255, inclusive.
The training data set, (train.csv), has 785 columns. The first column, called "label", is the digit that was drawn by the user. The rest of the columns contain the pixel-values of the associated image.
Each pixel column in the training set has a name like pixelx, where x is an integer between 0 and 783, inclusive. To locate this pixel on the image, suppose that we have decomposed x as x = i * 28 + j, where i and j are integers between 0 and 27, inclusive. Then pixelx is located on row i and column j of a 28 x 28 matrix, (indexing by zero).
For example, pixel31 indicates the pixel that is in the fourth column from the left, and the second row from the top, as in the ascii-diagram below.
Visually, if we omit the "pixel" prefix, the pixels make up the image like this:
```
000 001 002 003 ... 026 027
028 029 030 031 ... 054 055
056 057 058 059 ... 082 083
 |   |   |   |  ...  |   |
728 729 730 731 ... 754 755
756 757 758 759 ... 782 783
```
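To make the index arithmetic concrete, each row of 784 pixel values can be reshaped into a 28 x 28 array with NumPy. This is a minimal sketch, assuming the data/train.csv path used by the code later in this post:

```python
import numpy as np
import pandas as pd

# Take the first training image as an example.
train = pd.read_csv('data/train.csv')
first_image = train.drop(['label'], axis=1).iloc[0].values

# Row-major reshape: pixelx lands at row i, column j, where x = i * 28 + j.
image = first_image.reshape(28, 28)

# pixel31: 31 = 1 * 28 + 3, i.e. second row from the top, fourth column from the left.
assert image[1, 3] == first_image[31]
```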
The test data set, (test.csv), is the same as the training set, except that it does not contain the "label" column.
Your submission file should be in the following format: For each of the 28000 images in the test set, output a single line containing the ImageId and the digit you predict. For example, if you predict that the first image is of a 3, the second image is of a 7, and the third image is of an 8, then your submission file would look like:
ImageId,Label
1,3
2,7
3,8
(27997 more lines)
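As a side note, a file in this format can also be written with pandas; the snippet below is only a sketch with placeholder predictions and an arbitrary output file name (the code further down in this post uses the csv module instead):

```python
import numpy as np
import pandas as pd

# Placeholder predictions for the 28000 test images (all zeros here).
predictions = np.zeros(28000, dtype=int)

submission = pd.DataFrame({
    'ImageId': np.arange(1, len(predictions) + 1),  # ImageId starts at 1
    'Label': predictions,
})
submission.to_csv('submission.csv', index=False)  # file name is arbitrary
```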
The evaluation metric for this contest is the categorization accuracy, or the proportion of test images that are correctly classified. For example, a categorization accuracy of 0.97 indicates that you have correctly classified all but 3% of the images.
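Because the test labels are hidden, the only way to estimate this metric before submitting is to hold out part of the training data. A minimal sketch, with an arbitrary 80/20 split and model settings:

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

train = pd.read_csv('data/train.csv')
X = train.drop(['label'], axis=1).values
y = train['label'].values

# Hold out 20% of the training data as a validation set.
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_tr, y_tr)

# Categorization accuracy = fraction of validation images classified correctly.
print(accuracy_score(y_val, clf.predict(X_val)))
```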
The training file consists of 42001 rows by 785 columns. Excluding the header row, the first column is the training label and the remaining columns are the training data. The test set contains 28000 images.
A random forest can be used to solve this.
Speed is acceptable: training finishes in a minute or two.
The resulting accuracy is 0.96657.
```python
# -*- coding: utf-8 -*-
'''
@author: Starry
@file: version3.py
@time: 2018/2/12 10:09
'''
from sklearn.ensemble import RandomForestClassifier
import numpy as np
import pandas as pd
import csv


def loadData():
    # Read the Kaggle csv files and split off the label column.
    train = pd.read_csv('data/train.csv')
    test = pd.read_csv('data/test.csv')
    trainData = train.drop(['label'], axis=1).values.astype(dtype=np.int64)
    trainLabel = np.array(train['label'])
    testData = test.values
    print('load Data finish!!!')
    return trainData, trainLabel, testData


def saveResult(testLabel, fileName):
    # Write the submission file; ImageId starts at 1.
    header = ['ImageId', 'Label']
    with open(fileName, 'w', newline='') as csvFile:
        writer = csv.writer(csvFile, delimiter=',')
        writer.writerow(header)
        for i, p in enumerate(testLabel):
            writer.writerow([str(i + 1), str(p)])


def RFClassify(trainData, trainLabel, testData):
    # Train a 256-tree random forest and predict the test set.
    nbCF = RandomForestClassifier(n_estimators=256, warm_start=True)
    nbCF.fit(trainData, np.ravel(trainLabel))
    testLabel = nbCF.predict(testData)
    saveResult(testLabel, 'output/output5.csv')
    print('finish!!!')


RFClassify(*loadData())
```
Parameter tuning
I set n_estimators to 256, but there is nothing special about that value, so what should it be?
Cross-validation can give a rough estimate.
```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt
import numpy as np


def cross_va(train_X, train_y):
    n_estimators = [100, 120, 140, 160, 180, 200]
    test_scores = []
    for n_est in n_estimators:
        print('n_estimators is %s now.' % n_est)
        clf = RandomForestClassifier(n_estimators=n_est, warm_start=True)
        # 5-fold cross-validation; the negated MSE score is converted to RMSE.
        test_score = np.sqrt(-cross_val_score(clf, train_X, train_y, cv=5,
                                              scoring='neg_mean_squared_error'))
        print(test_score)
        test_scores.append(np.mean(test_score))
    print(test_scores)
    plt.plot(n_estimators, test_scores)
    plt.title('n_estimator vs CV Error')
    plt.show()
```
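Note that the function above scores each fold with RMSE (the square root of the negated neg_mean_squared_error) even though this is a classification task, so the curve shows error rather than accuracy. An alternative sketch, not from the original code, is to cross-validate on accuracy directly:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt


def cross_va_accuracy(train_X, train_y, n_estimators=(100, 120, 140, 160, 180, 200)):
    """Plot the mean 5-fold cross-validated accuracy for each candidate value."""
    scores = []
    for n_est in n_estimators:
        clf = RandomForestClassifier(n_estimators=n_est)
        acc = cross_val_score(clf, train_X, train_y, cv=5, scoring='accuracy')
        scores.append(acc.mean())
    plt.plot(list(n_estimators), scores)
    plt.title('n_estimators vs CV accuracy')
    plt.show()
```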
Since each run is quite slow, I first tried [100, 150, 200, 250] to see which range works better.
This produced the following plot:

The minimum appears to lie somewhere between 100 and 200, so I then ran [100, 120, 140, 160, 180, 200],
which produced the plot below:
The minimum is around 180, but this is only an estimate, not a definitive answer.
Submitting with n_estimators=180 gave an accuracy of 0.96628, which is actually lower than with 256, so I submitted again with 185 and got 0.96700, a small improvement.
Submitting with 190 dropped back to 0.96628, so n_estimators of roughly 185 seems to be the better choice, with an accuracy of about 0.96700.
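Tuning by repeated submissions is slow; as a sketch of an alternative (not what was done above), GridSearchCV can pick n_estimators by cross-validated accuracy on the training data, assuming trainData and trainLabel from loadData() and an illustrative candidate grid:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Candidate values around the region found above; the grid is illustrative only.
param_grid = {'n_estimators': [170, 175, 180, 185, 190, 195]}

search = GridSearchCV(RandomForestClassifier(), param_grid,
                      scoring='accuracy', cv=5)
search.fit(trainData, np.ravel(trainLabel))  # trainData/trainLabel from loadData()
print(search.best_params_, search.best_score_)
```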