數據增強的方式有很多,比如對圖像進行幾何變換(如翻轉、旋轉、變形、縮放等)、顏色變換(包括噪聲、模糊、顏色變換、檫除、填充等),將有限的數據,進行充分的利用。這里將介紹的僅僅是對圖像數據進行任意方向的移動操作(上下左右)來擴充數據。
這里將使用scipy中的shift變換工具(from scipy.ndimage.interpolation import shift)
常用的參數:input輸入圖像數據為ndarray類型的,
shift參數代表表示各個維度的偏移量[1,1]表示第一個第二個維度均偏移1,
cval參數代表偏移后原來位置用什么來填充
from scipy.ndimage.interpolation import shift def shift_digit(digit_array,dx,dy,new = 0): return shift(digit_array.reshape(28,28),[dy,dx],cval = new).reshape(784) plot_digit(shift_digit(some_digit,5,1,new =100))
一個簡單的數據偏移完成,接下來對整個訓練集進行擴充
X_train_expanded = [X_train] y_train_expanded = [y_train] for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)): shifted_image = np.apply_along_axis(shift_digit,axis = 1,arr = X_train,dx = dx,dy = dy) X_train_expanded.append(shifted_image) y_train_expanded.append(y_train) X_train_expanded = np.concatenate(X_train_expanded) y_train_expanded = np.concatenate(y_train_expanded) X_train_expanded.shape,y_train_expanded.shape
數據增加大了30萬之多,有了更多的數據,接下來進行訓練、預測,計算精度
knn_clf.fit(X_train_expanded,y_train_expanded)
y_knn_expanded_pred = knn_clf.predict(X_test)
accuracy_score(y_test,y_knn_expanded_pred)
另一種表示方式:
def shift_image(image,dx,dy): image = image.reshape((28,28)) shifted_image = shift(image,[dy,dx],cval = 0,mode = 'constant') return shifted_image.reshape([-1])
image = X_train[1000] shifted_image_down = shift_image(image,0,5) shifted_image_left = shift_image(image,-5,0) plt.figure(figsize=(12,3)) plt.subplot(131) plt.title("Original",fontsize= 14) plt.imshow(image.reshape(28,28),interpolation='nearest',cmap = 'Greys') plt.subplot(132) plt.title("shifted down",fontsize= 14) plt.imshow(shifted_image_down.reshape(28,28),interpolation='nearest',cmap = 'Greys') plt.subplot(133) plt.title("shifted left",fontsize= 14) plt.imshow(shifted_image_left.reshape(28,28),interpolation='nearest',cmap = 'Greys') plt.show()
X_train_augmented = [image for image in X_train] y_train_augmented = [label for label in y_train] for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)): for image,label in zip(X_train,y_train): X_train_augmented.append(shift_image(image,dx,dy)) y_train_augmented.append(label)
X_train_augmented = np.array(X_train_augmented)
y_train_augmented = np.array(y_train_augmented)
#打亂順序
shuffle_idx = np.random.permutation(len(X_train_augmented)) X_train_augmented = X_train_augmented[shuffle_idx] y_train_augmented = y_train_augmented[shuffle_idx]
knn_clf = KNeighborsClassifier(**grid_search.best_params_)
knn_clf.fit(X_train_augmented,y_train_augmented)
y_pred = knn_clf.predict(X_test)
accuracy_score(y_test,y_pred)
此時准確率已達到97%以上
關於knn_clf = KNeighborsClassifier(**grid_search.best_params_)中的**犯傻了很久,**代表着該參數中包含了多個參數,在C++中也會有這種參數表示,
也可參看python中*args與**kwargs的介紹(https://pythontips.com/2013/08/04/args-and-kwargs-in-python-explained/)
當然,scipy.ndimage.interpolation 也包含了其他的數據增強的方法,如旋轉、縮放等(參考:https://blog.csdn.net/songchunxiao1991/article/details/88531086)