今天發現一個用 numpy
隨機化數組的技巧。
需求
我有兩個數組( ndarray ):train_datasets 和 train_labels。其中,train_datasets 的每一行和 train_labels 是一一對應的。現在我要將數組打亂並用於訓練,打亂后要求兩者的行與行之間必須保持原來的對應關系。
實現
一般的實現思路,應該是先將 train_datasets(或 train_labels )打亂,並記錄被打亂的行號,再通過行號調整 train_labels (或 train_datasets )的行次序,這樣兩者的對應關系能保持一致。但代碼實現起來會很繁瑣,而如果用上 numpy
的話,可以三行代碼搞定。
首先,假設我們用如下訓練數據(訓練數據和標簽都是三個):
>>> train_data = np.ndarray(shape=(3,1,2), dtype=np.int32, buffer=np.asarray((1,2,3,4,5,6), dtype=np.int32))
>>> train_label = np.ndarray(shape=(3,), dtype=np.int32, buffer=np.asarray((1,2,3), dtype=np.int32))
>>> train_data
array([[[1, 2]],
[[3, 4]],
[[5, 6]]], dtype=int32)
>>> train_label
array([1, 2, 3], dtype=int32)
下面,我們用三行代碼打亂樣本數據:
>>> permutation = np.random.permutation(train_label.shape[0])
>>> shuffled_dataset = train_data[permutation, :, :]
>>> shuffled_labels = train_label[permutation]
稍微解釋一下代碼:
利用 np.random.permutation
函數,我們可以獲得打亂后的行號,輸出permutation
為:array([2, 1, 0])
。
然后,利用 numpy array
內置的操作 train_data[permutation, :, :]
,我們可以獲得打亂行號后的新的訓練數據。
我們看看訓練數據和標簽是不是對應的:
>>> shuffled_dataset
array([[[5, 6]],
[[3, 4]],
[[1, 2]]], dtype=int32)
>>> shuffled_labels
array([3, 2, 1], dtype=int32)
沒錯,完全按照 permutation
[2, 1, 0] 的順序重新調整了。
學會這種技巧,媽媽再也不擔心我加班了🤓