mxnet自定義dataloader加載自己的數據


實際上關於pytorch加載自己的數據之前有寫過一篇博客,但是最近接觸了mxnet,發現關於這方面的教程很少

如果要加載自己定義的數據的話,看mxnet關於mnist的示例基本上能夠推測一二

看pytorch與mxnet他們加載數據方式的對比

上圖左邊是pytorch的,右圖是mxnet

實際上,mxnet與pytorch他們的datalayer有着相似之處,為什么這樣說呢?直接看上面的代碼,基本上都是輸入圖像的路徑,然后輸出一個可以供loader調用的可以迭代的對象,所以無論是pytorch或者是mxnet,如果要有自己的數據,只需要在自己的數據那一部分繼承與修改ImageFolderDataset這個函數就行,就是直接繼承dataset.Dataset類即可

對於pytorch而言,它使用了find_class這樣一個函數;而對於mxnet而言,它在類內部定義了一個_list_images的函數,事實上我並沒有發現它有多大作用,只需要讓__getitem__這個函數能從items這個list中取值即可,list中的每個元素是一個tuple,一個是文件的名字,另外一個是文件所對應的label。

只需要繼承這一個類即可

直接擼代碼

這個是我參加kaggle比賽的一段代碼,盡管並不收斂,但請不要在意這些細節

  1 # -*-coding:utf-8-*-
  2 from mxnet import autograd
  3 from mxnet import gluon
  4 from mxnet import image
  5 from mxnet import init
  6 from mxnet import nd
  7 from mxnet.gluon.data import vision
  8 import numpy as np
  9 from mxnet.gluon.data import dataset
 10 import os
 11 import warnings
 12 import random
 13 from mxnet import gpu
 14 from mxnet.gluon.data.vision import datasets
 15 
 16 class MyImageFolderDataset(dataset.Dataset):
 17     def __init__(self, root, label, flag=1, transform=None):
 18         self._root = os.path.expanduser(root)
 19         self._flag = flag
 20         self._label = label
 21         self._transform = transform
 22         self._exts = ['.jpg', '.jpeg', '.png']
 23         self._list_images(self._root, self._label)
 24 
 25     def _list_images(self, root, label):  # label是一個list
 26         self.synsets = []
 27         self.synsets.append(root)
 28         self.items = []
 29         #file = open(label)
 30         #lines = file.readlines()
 31         #random.shuffle(lines)
 32         c = 0
 33         for line in label:
 34             cls = line.split()
 35             fn = cls.pop(0)
 36             fn = fn + '.jpg'
 37             # print(os.path.join(root, fn))
 38             if os.path.isfile(os.path.join(root, fn)):
 39                 self.items.append((os.path.join(root, fn), float(cls[0])))
 40                 # print((os.path.join(root, fn), float(cls[0])))
 41             else:
 42                 print('what')
 43             c = c + 1
 44         print('the total image is ', c)
 45 
 46     def __getitem__(self, idx):
 47         img = image.imread(self.items[idx][0], self._flag)
 48         label = self.items[idx][1]
 49         if self._transform is not None:
 50             return self._transform(img, label)
 51         return img, label
 52 
 53     def __len__(self):
 54         return len(self.items)
 55 
 56 
 57 def _get_batch(batch, ctx):  # 可以在循環中直接for i, data, label,函數主要把data放在ctx上
 58     """return data and label on ctx"""
 59     if isinstance(batch, mx.io.DataBatch):
 60         data = batch.data[0]
 61         label = batch.label[0]
 62     else:
 63         data, label = batch
 64     return (gluon.utils.split_and_load(data, ctx),
 65             gluon.utils.split_and_load(label, ctx),
 66             data.shape[0])
 67 
 68 def transform_train(data, label):
 69     im = image.imresize(data.astype('float32') / 255, 256, 256)
 70     auglist = image.CreateAugmenter(data_shape=(3, 256, 256), resize=0,
 71                         rand_crop=False, rand_resize=False, rand_mirror=True,
 72                         mean=None, std=None,
 73                         brightness=0, contrast=0,
 74                         saturation=0, hue=0,
 75                         pca_noise=0, rand_gray=0, inter_method=2)
 76     for aug in auglist:
 77         im = aug(im)
 78     # 將數據格式從"高*寬*通道"改為"通道*高*寬"。
 79     im = nd.transpose(im, (2, 0, 1))
 80     return (im, nd.array([label]).asscalar().astype('float32'))
 81 
 82 
 83 def transform_test(data, label):
 84     im = image.imresize(data.astype('float32') / 255, 256, 256)
 85     im = nd.transpose(im, (2, 0, 1))  # 之前沒有運行此變換
 86     return (im, nd.array([label]).asscalar().astype('float32'))
 87 
batch_size = 16  # mini-batch size shared by both DataLoaders and trainer.step below
root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'  # dataset root directory
 90 def random_choose_data(label_path):
 91     f = open(label_path)
 92     lines = f.readlins()
 93     random.shuffle(lines)
 94     total_number = len(lines)
 95     train_number = total_number/10*7
 96     train_list = lines[:train_number]
 97     test_list = lines[train_number:]
 98     return (train_list, test_list)
 99 
# Label file: one "<image_id> <label>" entry per line.
label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
train_list, test_list = random_choose_data(label_path)
loader = gluon.data.DataLoader
# NOTE(review): root already ends in .../image, so this joins to .../image/image — verify on disk
train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')
softmax_cross_entropy = gluon.loss.L2Loss()  # NOTE(review): misleading name — this is an L2 (MSE) loss, not softmax cross-entropy
108 
109 
110 from mxnet.gluon import nn
111 
# AlexNet-style convolutional network for regression on 256x256 inputs.
net = nn.Sequential()
with net.name_scope():
    net.add(
        # Stage 1
        nn.Conv2D(channels=96, kernel_size=11,
                  strides=4, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # Stage 2
        nn.Conv2D(channels=256, kernel_size=5,
                  padding=2, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # Stage 3
        nn.Conv2D(channels=384, kernel_size=3,
                  padding=1, activation='relu'),
        nn.Conv2D(channels=384, kernel_size=3,
                  padding=1, activation='relu'),
        nn.Conv2D(channels=256, kernel_size=3,
                  padding=1, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # Stage 4
        nn.Flatten(),
        nn.Dense(4096, activation="relu"),
        nn.Dropout(.5),
        # Stage 5
        nn.Dense(4096, activation="relu"),
        nn.Dropout(.5),
        # Stage 6
        nn.Dense(14950)  # 14950-way output head (original comment claimed "1 value" — mismatch)
    )
141 
142 from mxnet import init
143 from mxnet import gluon
144 import mxnet as mx
145 import utils
146 import datetime
147 from time import time
148 
ctx = utils.try_gpu()  # project helper (not shown here); presumably GPU if available, else CPU
net.initialize(ctx=ctx, init=init.Xavier())

mse_loss = gluon.loss.L2Loss()  # MSE loss used inside train() below

# utils.train(train_data, test_data, net, loss,
#             trainer, ctx, num_epochs=10)
#def train(train_data, test_data, net, loss, trainer, ctx, num_epochs, print_batches=None):
num_epochs = 10       # overridden to 100 further down before train() is called
print_batches = 100   # unused by the training loop below
"""Train a network"""
print("Start training on ", ctx)
if isinstance(ctx, mx.Context):
    ctx = [ctx]  # NOTE(review): ctx is rebound to a single Context again before train() runs
def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):
    """Train `net` with SGD + momentum, decaying the learning rate by
    `lr_decay` every `lr_period` epochs, then save the parameters.

    Parameters:
        net         -- gluon Block, already initialized on `ctx`.
        train_data  -- iterable of (data, label) batches.
        valid_data  -- unused: no validation is performed in this body.
        num_epochs  -- number of passes over train_data.
        lr, wd      -- initial learning rate and weight decay.
        ctx         -- a single mx.Context batches are copied to.
        lr_period, lr_decay -- lr schedule: lr *= lr_decay every lr_period epochs.
    """
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0.0
        # step-decay schedule
        if epoch > 0 and epoch % lr_period == 0:
            trainer.set_learning_rate(trainer.learning_rate*lr_decay)
        for data, label in train_data:
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data.as_in_context(ctx))
                loss = mse_loss(output, label)
            loss.backward()
            # NOTE(review): uses the module-level batch_size; with
            # last_batch='keep' the final batch may be smaller, so its
            # gradient is under-scaled
            trainer.step(batch_size)  # do the update, Trainer needs to know the batch size of the data to normalize
            # the gradient by 1/batch_size
            train_loss += nd.mean(loss).asscalar()
            print(nd.mean(loss).asscalar())  # per-batch loss, for monitoring
        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        epoch_str = ('Epoch %d. Train loss: %f, ' % (epoch, train_loss / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr' + str(trainer.learning_rate))
    # NOTE(review): assumes ./model/ already exists; save() does not create it
    net.collect_params().save('./model/alexnet.params')
ctx = utils.try_gpu()  # re-pick the device (undoes the [ctx] list-wrapping done earlier)
num_epochs = 100       # full run length; overrides the 10 set above
learning_rate = 0.001
weight_decay = 5e-4
lr_period = 10         # decay the learning rate every 10 epochs...
lr_decay = 0.1         # ...by a factor of 10

train(net, train_data, test_data, num_epochs, learning_rate,
      weight_decay, ctx, lr_period, lr_decay)
View Code

 

請看這一段

 1 class MyImageFolderDataset(dataset.Dataset):
 2     def __init__(self, root, label, flag=1, transform=None):
 3         self._root = os.path.expanduser(root)
 4         self._flag = flag
 5         self._label = label
 6         self._transform = transform
 7         self._exts = ['.jpg', '.jpeg', '.png']
 8         self._list_images(self._root, self._label)
 9 
10     def _list_images(self, root, label):  # label是一個list
11         self.synsets = []
12         self.synsets.append(root)
13         self.items = []
14         #file = open(label)
15         #lines = file.readlines()
16         #random.shuffle(lines)
17         c = 0
18         for line in label:
19             cls = line.split()
20             fn = cls.pop(0)
21             fn = fn + '.jpg'
22             # print(os.path.join(root, fn))
23             if os.path.isfile(os.path.join(root, fn)):
24                 self.items.append((os.path.join(root, fn), float(cls[0])))
25                 # print((os.path.join(root, fn), float(cls[0])))
26             else:
27                 print('what')
28             c = c + 1
29         print('the total image is ', c)
30 
31     def __getitem__(self, idx):
32         img = image.imread(self.items[idx][0], self._flag)
33         label = self.items[idx][1]
34         if self._transform is not None:
35             return self._transform(img, label)
36         return img, label
37 
38     def __len__(self):
39         return len(self.items)
batch_size = 16  # mini-batch size shared by both DataLoaders
root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'  # dataset root directory
def random_choose_data(label_path):
    """Shuffle the lines of the label file and split them 70% / 30%.

    Parameters
    ----------
    label_path : str
        Path to a text file with one label entry per line.

    Returns
    -------
    (train_list, test_list) : tuple of lists of raw lines (with newlines).
    """
    # 'with' closes the file; original leaked the handle
    with open(label_path) as f:
        lines = f.readlines()  # original had f.readlins() — AttributeError
    random.shuffle(lines)
    total_number = len(lines)
    # integer count: '/' yields a float on Python 3 and float slice
    # indices raise TypeError
    train_number = total_number * 7 // 10
    train_list = lines[:train_number]
    test_list = lines[train_number:]
    return (train_list, test_list)
51 
# Label file: one "<image_id> <label>" entry per line.
label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
train_list, test_list = random_choose_data(label_path)

loader = gluon.data.DataLoader
# NOTE(review): root already ends in .../image, so this joins to .../image/image — verify on disk
train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')
View Code

MyImageFolderDataset是dataset.Dataset的子類,主要是重載索引運算__getitem__,並且返回image以及其對應的label即可;前面的_list_images函數只要能夠返回items這個list就行。關於運算符重載,先給自己挖個坑

可以說和pytorch非常像了,就連沐神在講課的時候還在說,其實在寫mxnet的時候,借鑒了很多pytorch的內容

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM