通過迭代器獲取數據


%pylab inline
from keras.datasets import mnist
import mxnet as mx
from mxnet import nd
from mxnet import autograd 
import random
from mxnet import gluon

(x_train, y_train), (x_test, y_test) = mnist.load_data()
num_examples = x_train.shape[0]
num_inputs = x_train.shape[1] * x_train.shape[2]
batch_size = 64

1. 自定義數據迭代器

def data_iter1(X, Y, batch_size):
    num_samples = X.shape[0]
    idx = list(range(num_samples))
    random.shuffle(idx)
    
    X = nd.array(X)
    Y = nd.array(Y)
    for i in range(0, num_examples, batch_size):
        j = nd.array(idx[i: min(i + batch_size, num_examples)])
        yield nd.take(X, j), nd.take(Y, j)

2. Gluon 迭代器

dataset = gluon.data.ArrayDataset(x_train, y_train)
data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)

3. 從迭代器中獲取數據

for data, label in data_iter:
    print(data.shape, label.shape)
    break
(64, 28, 28) (64,)
for data, label in data_iter1(x_train, y_train, batch_size):
    print(data.shape, label.shape)
    break
(64, 28, 28) (64,)

更多精彩見:使用 迭代器 獲取 Cifar 等常用數據集


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM