For deep learning tasks, training speed determines how quickly you can iterate on a model, and training speed in turn depends on the cost of data preprocessing plus the network's forward and backward passes.
For recognition tasks, the batch size is usually large and data augmentation is needed, so the training bottleneck often ends up in data loading and preprocessing, especially for small networks.
For data-loading time, a crude but effective fix is to use an SSD, or to copy the data straight into the /tmp folder (trading memory for time).
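As a minimal sketch of the /tmp staging idea (the source path here is hypothetical, and the dataset is assumed to fit in /tmp), the copy can be done once before training starts:

```python
import shutil
from pathlib import Path

def stage_to_tmp(src: str, dst: str = '/tmp/dataset') -> Path:
    """Copy the dataset into /tmp (often tmpfs, i.e. RAM-backed) once,
    so subsequent epochs read from memory instead of a slow disk."""
    dst_path = Path(dst)
    if not dst_path.exists():        # skip the copy if already staged
        shutil.copytree(src, dst_path)
    return dst_path

# root_dir = stage_to_tmp('/data/frames')  # then pass root_dir to the loader below
```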
For preprocessing time, NVIDIA's official DALI toolkit can move the preprocessing onto the CPU/GPU for acceleration. Note that DALI is not bundled with PyTorch itself; it is installed as a separate pip package from NVIDIA's package index (e.g. `pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110`) and hooks into PyTorch through `nvidia.dali.plugin.pytorch`.
The official DALI tutorials are fairly basic; in practice you usually need to define a task-specific DataLoader and combine it with distributed training. Below is an example of a DALI-based DataLoader that returns image sequences and applies a common set of preprocessing operations uniformly across each sequence.
```python
import random
from pathlib import Path

import numpy as np
import torch
from sklearn.utils import shuffle

import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
class TRAIN_INPUT_ITER(object):
    """External source for training: each list entry yields two clips, a
    positive clip inside [start_f, end_f] and a boundary clip that is
    relabelled as background unless it belongs to a task-specific class."""

    def __init__(self, batch_size, num_class, seq_len, sample_rate,
                 num_shards=1, shard_id=0, root_dir=Path(''), list_file='',
                 is_training=True):
        self.batch_size = batch_size
        self.num_class = num_class
        self.seq_len = seq_len
        self.sample_rate = sample_rate
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.train = is_training
        self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
        self.root_dir = root_dir
        self.epoch = 0
        with open(list_file, 'r') as f:
            self.ori_lines = f.readlines()

    def __iter__(self):
        self.i = 0
        # Each shard iterates over its own contiguous slice of the list.
        self.n = len(self.ori_lines) // self.num_shards
        return self

    def _read_clip(self, batch, dir_name, begin_frame):
        """Append seq_len encoded frames to the batch, repeating the last
        existing frame when a file is missing."""
        last_frame = None
        for k in range(self.seq_len):
            filename = self.root_dir / dir_name / self.image_name_formatter(
                begin_frame + self.sample_rate * k)
            if filename.exists():
                last_frame = filename
            elif last_frame is None:
                raise IOError('{} does not exist'.format(filename))
            # Feed the raw encoded bytes; DALI decodes them in the pipeline.
            batch[k].append(np.frombuffer(last_frame.read_bytes(), dtype=np.uint8))

    def __next__(self):
        batch = [[] for _ in range(self.seq_len)]
        labels = []
        for _ in range(self.batch_size):
            if self.train and self.i % self.n == 0:
                # Reshuffle once per epoch.  Seeding with the epoch index keeps
                # every shard's permutation identical (so shards stay disjoint)
                # while still changing the order between epochs.
                self.ori_lines = shuffle(self.ori_lines, random_state=self.epoch)
                self.epoch += 1
                bucket = len(self.ori_lines) // self.num_shards
                self.lines = self.ori_lines[self.shard_id * bucket:(self.shard_id + 1) * bucket]
            line = self.lines[self.i].strip()
            dir_name, start_f, end_f, label = line.split(' ')
            start_f, end_f, label = int(start_f), int(end_f), int(label)
            # Clip 1: randomly placed inside the annotated interval.
            begin_frame = random.randint(start_f, max(end_f - self.sample_rate * self.seq_len, start_f))
            begin_frame = max(1, begin_frame)
            self._read_clip(batch, dir_name, begin_frame)
            # Clip 2: straddles either the start or the end boundary of the interval.
            if random.randint(0, 1) == 0:
                end_frame = start_f + random.randint(0, self.sample_rate * self.seq_len // 2)
                begin_frame = max(1, end_frame - self.sample_rate * self.seq_len)
            else:
                begin_frame = end_f - random.randint(0, self.sample_rate * self.seq_len // 2)
                begin_frame = max(1, begin_frame)
            self._read_clip(batch, dir_name, begin_frame)
            labels.append(np.array([label], dtype=np.uint8))
            # Boundary clips keep their label only for classes 8 and 9
            # (task-specific ids); otherwise they become the background class.
            if label == 8 or label == 9:
                labels.append(np.array([label], dtype=np.uint8))
            else:
                labels.append(np.array([self.num_class - 1], dtype=np.uint8))
            self.i = (self.i + 1) % self.n
        return (batch, labels)

    next = __next__
class VAL_INPUT_ITER(object):
    """External source for validation: one clip per list entry, shuffled once
    with a fixed seed so every run sees the same order."""

    def __init__(self, batch_size, num_class, seq_len, sample_rate,
                 num_shards=1, shard_id=0, root_dir=Path(''), list_file='',
                 is_training=False):
        self.batch_size = batch_size
        self.num_class = num_class
        self.seq_len = seq_len
        self.sample_rate = sample_rate
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.train = is_training
        self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
        self.root_dir = root_dir
        with open(list_file, 'r') as f:
            self.ori_lines = f.readlines()
        self.ori_lines = shuffle(self.ori_lines, random_state=0)

    def __iter__(self):
        self.i = 0
        self.n = len(self.ori_lines) // self.num_shards
        return self

    def __next__(self):
        batch = [[] for _ in range(self.seq_len)]
        labels = []
        for _ in range(self.batch_size):
            if self.i % self.n == 0:
                # (Re)build this shard's slice at the start of each pass.
                bucket = len(self.ori_lines) // self.num_shards
                self.lines = self.ori_lines[self.shard_id * bucket:(self.shard_id + 1) * bucket]
            line = self.lines[self.i].strip()
            dir_name, start_f, end_f, label = line.split(' ')
            start_f, end_f, label = int(start_f), int(end_f), int(label)
            begin_frame = random.randint(start_f, max(end_f - self.sample_rate * self.seq_len, start_f))
            begin_frame = max(1, begin_frame)
            last_frame = None
            for k in range(self.seq_len):
                filename = self.root_dir / dir_name / self.image_name_formatter(
                    begin_frame + self.sample_rate * k)
                if filename.exists():
                    last_frame = filename
                elif last_frame is None:
                    raise IOError('{} does not exist'.format(filename))
                # Feed the raw encoded bytes; DALI decodes them in the pipeline.
                batch[k].append(np.frombuffer(last_frame.read_bytes(), dtype=np.uint8))
            labels.append(np.array([label], dtype=np.uint8))
            self.i = (self.i + 1) % self.n
        return (batch, labels)

    next = __next__
class HybridPipe(Pipeline):
    """DALI pipeline fed by one ExternalSource per frame of the sequence."""

    def __init__(self, batch_size, num_class, seq_len, sample_rate, num_shards,
                 shard_id, root_dir, list_file, num_threads, device_id=0,
                 dali_cpu=True, size=(224, 224), is_gray=True, is_training=True):
        super(HybridPipe, self).__init__(batch_size, num_threads, device_id,
                                         seed=12 + device_id)
        if is_training:
            # The training iterator emits two clips per list entry,
            # so it is driven with half the pipeline batch size.
            self.external_data = TRAIN_INPUT_ITER(batch_size // 2, num_class, seq_len,
                                                  sample_rate, num_shards, shard_id,
                                                  root_dir, list_file, is_training)
        else:
            self.external_data = VAL_INPUT_ITER(batch_size, num_class, seq_len,
                                                sample_rate, num_shards, shard_id,
                                                root_dir, list_file, is_training)
        self.seq_len = seq_len
        self.training = is_training
        self.iterator = iter(self.external_data)
        # One ExternalSource per frame position in the sequence, plus labels.
        self.inputs = [ops.ExternalSource() for _ in range(seq_len)]
        self.input_labels = ops.ExternalSource()
        self.is_gray = is_gray
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
        if self.is_gray:
            self.space_converter = ops.ColorSpaceConversion(device='gpu',
                                                            image_type=types.RGB,
                                                            output_type=types.GRAY)
        self.resize = ops.Resize(device='gpu', size=size)
        self.cast_fp32 = ops.Cast(device='gpu', dtype=types.FLOAT)
        if self.training:
            # Random crop / flip / rotation / colour-jitter parameter generators.
            self.crop_coin = ops.CoinFlip(probability=0.5)
            self.crop_pos_x = ops.Uniform(range=(0., 1.))
            self.crop_pos_y = ops.Uniform(range=(0., 1.))
            self.crop_h = ops.Uniform(range=(256 * 0.85, 256))
            self.crop_w = ops.Uniform(range=(256 * 0.85, 256))
            self.crmn = ops.CropMirrorNormalize(device='gpu', output_layout=types.NHWC)
            self.u_rotate = ops.Uniform(range=(-8, 8))
            self.rotate = ops.Rotate(device='gpu', keep_size=True)
            self.brightness = ops.Uniform(range=(0.9, 1.1))
            self.contrast = ops.Uniform(range=(0.9, 1.1))
            self.saturation = ops.Uniform(range=(0.9, 1.1))
            self.hue = ops.Uniform(range=(-0.3, 0.3))
            self.color_jitter = ops.ColorTwist(device='gpu')
        else:
            self.crmn = ops.CropMirrorNormalize(device='gpu', crop=(224, 224),
                                                output_layout=types.NHWC)

    def define_graph(self):
        self.batch_data = [i() for i in self.inputs]
        self.labels = self.input_labels()
        out = self.decode(self.batch_data)
        out = [out_elem.gpu() for out_elem in out]
        if self.training:
            # Jitter brightness/contrast/saturation/hue with random strengths.
            out = self.color_jitter(out, brightness=self.brightness(),
                                    contrast=self.contrast(),
                                    saturation=self.saturation(), hue=self.hue())
        if self.is_gray:
            out = self.space_converter(out)
        if self.training:
            out = self.rotate(out, angle=self.u_rotate())
            # The same random crop/flip parameters are shared by every frame,
            # so the whole sequence is transformed consistently.
            out = self.crmn(out, crop_h=self.crop_h(), crop_w=self.crop_w(),
                            crop_pos_x=self.crop_pos_x(), crop_pos_y=self.crop_pos_y(),
                            mirror=self.crop_coin())
        else:
            out = self.crmn(out)
        out = self.resize(out)
        if not self.training:
            # Ensure float32 output on the validation path.
            out = self.cast_fp32(out)
        return (*out, self.labels)

    def iter_setup(self):
        try:
            batch_data, labels = next(self.iterator)
            for i in range(self.seq_len):
                self.feed_input(self.batch_data[i], batch_data[i])
            self.feed_input(self.labels, labels)
        except StopIteration:
            # Reset the source so the next epoch starts cleanly.
            self.iterator = iter(self.external_data)
            raise StopIteration
def dali_loader(batch_size,
                num_class,
                seq_len,
                sample_rate,
                num_shards,
                shard_id,
                root_dir,
                list_file,
                num_workers,
                device_id,
                dali_cpu=True,
                size=(224, 224),
                is_gray=True,
                is_training=True):
    pipe = HybridPipe(batch_size, num_class, seq_len, sample_rate, num_shards,
                      shard_id, root_dir, list_file, num_workers,
                      device_id=device_id, dali_cpu=dali_cpu, size=size,
                      is_gray=is_gray, is_training=is_training)
    # Output names: one per frame position, plus the label.
    names = [f'data{i}' for i in range(seq_len)] + ['label']
    loader = DALIGenericIterator(pipe, names, pipe.external_data.n,
                                 last_batch_padded=True, fill_last_batch=True)
    return loader
```
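Assuming the listing above is in scope, here is a minimal sketch of driving the loader in a training loop. The settings (paths, batch size, `seq_len=8`, single GPU) are hypothetical; the dict keys match the `data{i}`/`label` output map defined in `dali_loader`:

```python
# Hypothetical settings; adjust paths and sizes to your own dataset.
seq_len = 8
loader = dali_loader(batch_size=32, num_class=10, seq_len=seq_len, sample_rate=2,
                     num_shards=1, shard_id=0,
                     root_dir=Path('/tmp/dataset'), list_file='train_list.txt',
                     num_workers=4, device_id=0, dali_cpu=False,
                     is_gray=True, is_training=True)

for epoch in range(10):
    for data in loader:
        batch = data[0]   # DALIGenericIterator yields one dict per pipeline
        # Stack per-frame outputs into (N, T, H, W, C); layout is NHWC as configured.
        frames = torch.stack([batch[f'data{i}'] for i in range(seq_len)], dim=1)
        labels = batch['label'].squeeze(-1).long()
        # ... forward / backward as usual ...
    loader.reset()        # re-arm the external source for the next epoch
```

For distributed training, each process would pass its own `device_id`/`shard_id` (e.g. the local rank) and set `num_shards` to the world size, so every GPU reads a disjoint slice of the list file.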