PyTorch使用DALI進行預處理加速


對於深度學習任務,訓練速度決定了模型的迭代速度,而訓練速度又取決於數據預處理以及網絡前向與後向傳播的耗時。
對於識別任務,batch size通常較大且需要做數據增強,因此訓練速度的瓶頸常常出現在數據讀取和預處理上,對小網絡而言尤其如此。
要縮短數據讀取耗時,粗暴且有效的解決辦法是使用固態硬盤,或者將數據直接拷貝至基於內存的文件系統(如/dev/shm,或掛載為tmpfs的/tmp文件夾),以內存空間換取時間。
對於數據預處理的耗時,則可以使用NVIDIA官方開發的DALI(Data Loading Library)預處理加速工具包,將預處理放在CPU/GPU上進行加速。需要注意,PyTorch本身並不內置DALI:NVIDIA NGC提供的PyTorch容器鏡像(如搭載PyTorch 1.6的版本)已預裝DALI,在其他環境中則需按官方文檔通過pip另行安裝。

官方的DALI教程較為簡單,實際訓練通常要根據任務需要自定義DataLoader,並與分佈式訓練結合使用。這裡將展示一個使用DALI定義DataLoader的例子,其功能是返回序列圖像,並對同一序列中的所有圖像做一致的常見預處理操作。
`

from nvidia.dali.plugin.pytorch import DALIGenericIterator

from nvidia.dali.types import DALIImageType
import cv2
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from sklearn.utils import shuffle
import numpy as np
from torchvision import transforms
import torch.utils.data as torchdata
import random
from pathlib import Path
import torch

class TRAIN_INPUT_ITER(object):
    """Infinite external-source iterator producing paired training clips.

    Each line of ``list_file`` is expected to be
    ``"<dir_name> <start_frame> <end_frame> <label>"``. Every line consumed
    yields two clips of ``seq_len`` frames spaced ``sample_rate`` apart, read
    as raw encoded bytes (``uint8`` arrays):

    1. a clip placed uniformly inside ``[start_frame, end_frame]``, labelled
       with the line's ``label``;
    2. a clip placed around the segment start (ending at most
       ``sample_rate * seq_len // 2`` frames after it) or around the segment
       end (starting at most that many frames before it), labelled with
       ``label`` when it is in ``boundary_labels`` and ``num_class - 1``
       otherwise.

    ``__next__`` therefore returns ``2 * batch_size`` samples as
    ``(batch, labels)``, where ``batch[k]`` holds the k-th frame of every clip.
    The iterator never raises ``StopIteration``; it wraps around its shard.
    """

    # Labels whose boundary clip keeps its own label; any other boundary clip
    # gets ``num_class - 1`` (presumably a background/"other" class). The
    # meaning of 8 and 9 is dataset-specific and not documented here.
    boundary_labels = (8, 9)

    def __init__(self, batch_size, num_class, seq_len, sample_rate, num_shards=1, shard_id=0,
                 root_dir=Path(''), list_file='', is_training=True):
        """Store the sampling parameters and read every line of ``list_file``.

        Args:
            batch_size: list lines consumed per ``__next__`` call (each yields two samples).
            num_class: number of classes; ``num_class - 1`` labels most boundary clips.
            seq_len: frames per clip.
            sample_rate: frame stride inside a clip.
            num_shards: number of equal contiguous buckets the list is split into.
            shard_id: index of the bucket read by this iterator.
            root_dir: directory (``str`` or ``Path``) under which each line's
                ``dir_name`` holds ``image_XXXXX.jpg`` frames.
            list_file: path of the annotation list.
            is_training: reshuffle the list at the start of every pass over the shard.
        """
        self.batch_size = batch_size
        self.num_class = num_class
        self.seq_len = seq_len
        self.sample_rate = sample_rate
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.train = is_training
        self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
        # Accept str as well as Path: frame paths are built with the ``/`` operator.
        self.root_dir = Path(root_dir)
        with open(list_file, 'r') as f:
            self.ori_lines = f.readlines()

    def __iter__(self):
        """Restart at the first line of this shard and return ``self``.

        Raises:
            ValueError: if the list has fewer lines than ``num_shards``, which
                would leave every shard empty.
        """
        self.i = 0
        # Bucket size per shard; trailing lines that do not fill a bucket are never read.
        self.n = len(self.ori_lines) // self.num_shards
        if self.n == 0:
            raise ValueError(
                f'{len(self.ori_lines)} list lines cannot be split into '
                f'{self.num_shards} non-empty shards')
        return self

    def _read_clip(self, dir_name, begin_frame):
        """Return the encoded bytes of ``seq_len`` frames starting at ``begin_frame``.

        A missing frame is replaced by the most recent existing frame of the
        clip; if the first frame is missing, ``IOError`` is raised.
        """
        frames = []
        last_frame = None
        for k in range(self.seq_len):
            filename = self.root_dir / dir_name / self.image_name_formatter(
                begin_frame + self.sample_rate * k)
            if filename.exists():
                last_frame = filename
            elif last_frame is None:
                message = '{} does not exist'.format(filename)
                print(message)
                raise IOError(message)
            # read_bytes() closes the file, unlike the bare open() it replaces.
            frames.append(np.frombuffer(last_frame.read_bytes(), dtype=np.uint8))
        return frames

    def __next__(self):
        """Return ``(batch, labels)`` for the next ``batch_size`` list lines."""
        batch = [[] for _ in range(self.seq_len)]
        labels = []
        span = self.sample_rate * self.seq_len
        for _ in range(self.batch_size):
            if self.i % self.n == 0:
                # Start of a pass over the shard. random_state=0 gives every
                # shard the same permutation (given the same list file), so
                # the contiguous buckets stay disjoint.
                if self.train:
                    self.ori_lines = shuffle(self.ori_lines, random_state=0)
                self.lines = self.ori_lines[self.shard_id * self.n:(self.shard_id + 1) * self.n]
            dir_name, start_f, end_f, label = self.lines[self.i].strip().split()
            start_f, end_f, label = int(start_f), int(end_f), int(label)

            # Clip 1: uniformly placed inside the annotated segment.
            begin_frame = max(1, random.randint(start_f, max(end_f - span, start_f)))
            for column, frame in zip(batch, self._read_clip(dir_name, begin_frame)):
                column.append(frame)

            # Clip 2: around the segment start or around the segment end.
            if random.randint(0, 1) % 2 == 0:
                end_frame = start_f + random.randint(0, span // 2)
                begin_frame = max(1, end_frame - span)
            else:
                begin_frame = max(1, end_f - random.randint(0, span // 2))
            for column, frame in zip(batch, self._read_clip(dir_name, begin_frame)):
                column.append(frame)

            labels.append(np.array([label], dtype=np.uint8))
            boundary_label = label if label in self.boundary_labels else self.num_class - 1
            labels.append(np.array([boundary_label], dtype=np.uint8))

            self.i = (self.i + 1) % self.n
        return (batch, labels)

    # Python 2 style alias; HybridPipe.iter_setup calls ``.next()``.
    next = __next__


class VAL_INPUT_ITER(object):
    """Infinite external-source iterator producing one clip per list line.

    Each line of ``list_file`` is expected to be
    ``"<dir_name> <start_frame> <end_frame> <label>"``. Every line consumed
    yields one clip of ``seq_len`` frames spaced ``sample_rate`` apart, read as
    raw encoded bytes (``uint8`` arrays) and paired with ``label``.
    ``__next__`` returns ``batch_size`` samples as ``(batch, labels)``, where
    ``batch[k]`` holds the k-th frame of every clip. The iterator never raises
    ``StopIteration``; it wraps around its shard.

    The list is shuffled once with a fixed seed and split into ``num_shards``
    equal contiguous buckets; lines beyond
    ``num_shards * (len(lines) // num_shards)`` are never read.
    """

    def __init__(self, batch_size, num_class, seq_len, sample_rate, num_shards=1, shard_id=0,
                 root_dir=Path(''), list_file='', is_training=False):
        """Store the sampling parameters, read ``list_file`` and shuffle it once.

        Args:
            batch_size: list lines (and samples) per ``__next__`` call.
            num_class: number of classes (stored; not used by this iterator).
            seq_len: frames per clip.
            sample_rate: frame stride inside a clip.
            num_shards: number of equal contiguous buckets the list is split into.
            shard_id: index of the bucket read by this iterator.
            root_dir: directory (``str`` or ``Path``) under which each line's
                ``dir_name`` holds ``image_XXXXX.jpg`` frames.
            list_file: path of the annotation list.
            is_training: additionally reshuffle at the start of every pass.
        """
        self.batch_size = batch_size
        self.num_class = num_class
        self.seq_len = seq_len
        self.sample_rate = sample_rate
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.train = is_training
        self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
        # Accept str as well as Path: frame paths are built with the ``/`` operator.
        self.root_dir = Path(root_dir)
        with open(list_file, 'r') as f:
            self.ori_lines = f.readlines()
        # Fixed seed: every shard gets the same permutation (given the same
        # list file), so the contiguous buckets stay disjoint.
        self.ori_lines = shuffle(self.ori_lines, random_state=0)

    def __iter__(self):
        """Restart at the first line of this shard and return ``self``.

        Raises:
            ValueError: if the list has fewer lines than ``num_shards``, which
                would leave every shard empty.
        """
        self.i = 0
        self.n = len(self.ori_lines) // self.num_shards
        if self.n == 0:
            raise ValueError(
                f'{len(self.ori_lines)} list lines cannot be split into '
                f'{self.num_shards} non-empty shards')
        return self

    def _read_clip(self, dir_name, begin_frame):
        """Return the encoded bytes of ``seq_len`` frames starting at ``begin_frame``.

        A missing frame is replaced by the most recent existing frame of the
        clip; if the first frame is missing, ``IOError`` is raised.
        """
        frames = []
        last_frame = None
        for k in range(self.seq_len):
            filename = self.root_dir / dir_name / self.image_name_formatter(
                begin_frame + self.sample_rate * k)
            if filename.exists():
                last_frame = filename
            elif last_frame is None:
                message = '{} does not exist'.format(filename)
                print(message)
                raise IOError(message)
            # read_bytes() closes the file, unlike the bare open() it replaces.
            frames.append(np.frombuffer(last_frame.read_bytes(), dtype=np.uint8))
        return frames

    def __next__(self):
        """Return ``(batch, labels)`` for the next ``batch_size`` list lines."""
        batch = [[] for _ in range(self.seq_len)]
        labels = []
        span = self.sample_rate * self.seq_len
        for _ in range(self.batch_size):
            if self.i % self.n == 0:
                # Start of a pass over the shard: optionally reshuffle, then
                # take this shard's contiguous bucket.
                if self.train:
                    self.ori_lines = shuffle(self.ori_lines, random_state=0)
                self.lines = self.ori_lines[self.shard_id * self.n:(self.shard_id + 1) * self.n]
            dir_name, start_f, end_f, label = self.lines[self.i].strip().split()
            start_f, end_f, label = int(start_f), int(end_f), int(label)
            # NOTE(review): the clip start is random even for validation, so
            # validation inputs differ between runs; confirm this is intended.
            begin_frame = max(1, random.randint(start_f, max(end_f - span, start_f)))
            for column, frame in zip(batch, self._read_clip(dir_name, begin_frame)):
                column.append(frame)
            labels.append(np.array([label], dtype=np.uint8))
            self.i = (self.i + 1) % self.n
        return (batch, labels)

    # Python 2 style alias; HybridPipe.iter_setup calls ``.next()``.
    next = __next__

class HybridPipe(Pipeline):
    """DALI pipeline that decodes and augments clips of ``seq_len`` JPEG frames.

    Encoded frames come from ``TRAIN_INPUT_ITER`` (training) or
    ``VAL_INPUT_ITER`` (validation) through one ``ExternalSource`` per frame
    position plus one for the labels. ``define_graph`` returns ``seq_len``
    image outputs followed by the label output.
    """
    def __init__(self, batch_size, num_class,seq_len, sample_rate, num_shards,shard_id,root_dir, list_file, num_threads, device_id=0, dali_cpu=True,size = (224,224),is_gray = True,is_training = True):
        """Create the external-source iterator and all DALI operators.

        Args:
            batch_size: samples per DALI batch.
            num_class: number of classes, forwarded to the external iterator.
            seq_len: frames per clip (one pipeline output per frame position).
            sample_rate: frame stride inside a clip.
            num_shards: number of data-parallel shards of the list file.
            shard_id: shard read by this pipeline.
            root_dir: directory under which each list line's ``dir_name`` holds its frames.
            list_file: path of the annotation list.
            num_threads: number of DALI CPU worker threads.
            device_id: GPU used by the pipeline; also offsets the DALI seed.
            dali_cpu: decode on the CPU if True, otherwise with device 'mixed'.
            size: output size passed to ``ops.Resize``.
            is_gray: convert RGB frames to single-channel grayscale.
            is_training: enable random augmentation and the paired training sampler.
        """
        # Per-device DALI seed. It does not seed Python's ``random`` module,
        # which the external iterators use for frame sampling.
        super(HybridPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        if is_training:
            # TRAIN_INPUT_ITER emits two samples (segment clip + boundary clip)
            # per list line, so batch_size // 2 lines fill one DALI batch.
            self.external_data = TRAIN_INPUT_ITER(batch_size//2, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)
        else:
            self.external_data = VAL_INPUT_ITER(batch_size, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)
        # self.external_data = VAL_INPUT_ITER(batch_size, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)
        self.seq_len = seq_len
        self.training = is_training
        self.iterator = iter(self.external_data)
        # One external input per frame position of the clip.
        self.inputs = [ops.ExternalSource() for _ in range(seq_len)]
        self.input_labels = ops.ExternalSource()
        self.is_gray = is_gray

        decoder_device = 'cpu' if dali_cpu else 'mixed'

        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
        if self.is_gray:
            self.space_converter = ops.ColorSpaceConversion(device='gpu',image_type=types.RGB,output_type=types.GRAY)
        self.resize = ops.Resize(device='gpu', size=size)
        self.cast_fp32 = ops.Cast(device='gpu',dtype = types.FLOAT)
        if self.training:
            # Random crop window: position in [0, 1] and side 217.6-256 px.
            # NOTE(review): 256 is hard-coded independently of ``size``;
            # assumes decoded frames are at least ~256 px — TODO confirm.
            self.crop_coin = ops.CoinFlip(probability=0.5)
            self.crop_pos_x = ops.Uniform(range=(0., 1.))
            self.crop_pos_y = ops.Uniform(range=(0., 1.))
            self.crop_h = ops.Uniform(range=(256*0.85,256))
            self.crop_w = ops.Uniform(range=(256*0.85,256))
            # No mean/std are given, so (per DALI's documented defaults) this
            # only crops, mirrors and converts, without normalizing values.
            self.crmn = ops.CropMirrorNormalize(device="gpu",output_layout=types.NHWC)

            # Random rotation of up to +/-8 degrees, keeping the input size.
            self.u_rotate = ops.Uniform(range=(-8, 8))
            self.rotate = ops.Rotate(device='gpu',keep_size=True)

            # Only brightness and contrast are passed to ColorTwist in
            # define_graph; the saturation and hue generators are unused.
            self.brightness = ops.Uniform(range=(0.9,1.1))
            self.contrast = ops.Uniform(range=(0.9,1.1))
            self.saturation = ops.Uniform(range=(0.9,1.1))
            self.hue = ops.Uniform(range=(-0.3,0.3))
            self.color_jitter = ops.ColorTwist(device='gpu')
        else:
            # Validation: fixed 224x224 crop (DALI's default crop position is the centre).
            self.crmn = ops.CropMirrorNormalize(device="gpu",crop=(224,224),output_layout=types.NHWC)
    

    def define_graph(self):
        """Build the per-frame graph; every frame input goes through the same ops.

        Each operator is called with the list of ``seq_len`` frame nodes. The
        random-argument nodes (brightness, contrast, angle, crop, mirror) are
        created once here and shared by all frames, presumably so every frame
        of a clip gets identical augmentation parameters (TODO confirm for the
        DALI version in use).
        """
        self.batch_data = [i() for i in self.inputs]
        self.labels = self.input_labels()
        out = self.decode(self.batch_data)
        # NOTE(review): with dali_cpu=False the decoder runs with device
        # 'mixed' and already produces GPU output; confirm .gpu() is accepted there.
        out = [out_elem.gpu() for out_elem in out]
        if self.training:
            out = self.color_jitter(out,brightness=self.brightness(),contrast=self.contrast())
        if self.is_gray:
            out = self.space_converter(out)
        if self.training:
            out = self.rotate(out,angle=self.u_rotate())
            out = self.crmn(out,crop_h=self.crop_h(),crop_w=self.crop_w(),crop_pos_x=self.crop_pos_x(),crop_pos_y=self.crop_pos_y(),mirror=self.crop_coin())
        else:
            out = self.crmn(out)
        out = self.resize(out)
        # Only the validation path is explicitly cast to float here.
        if not self.training:
            out = self.cast_fp32(out)
        return (*out, self.labels)
    
    def iter_setup(self):
        """Pull the next host batch from the external iterator and feed it to DALI."""
        try:
            (batch_data, labels) = self.iterator.next()
            # batch_data[i] holds the encoded i-th frame of every clip in the batch.
            for i in range(self.seq_len):
                self.feed_input(self.batch_data[i], batch_data[i])
            self.feed_input(self.labels, labels)

        except StopIteration:
            # TRAIN_INPUT_ITER / VAL_INPUT_ITER wrap around and never raise
            # StopIteration; this branch only matters for a finite iterator.
            self.iterator = iter(self.external_data)
            raise StopIteration

def dali_loader(batch_size,
                num_class,
                seq_len,
                sample_rate,
                num_shards,
                shard_id,
                root_dir,
                list_file,
                num_workers,
                device_id,
                dali_cpu=True,
                size=(224, 224),
                is_gray=True,
                is_training=True):
    """Build a DALI-backed PyTorch iterator over clips of ``seq_len`` frames.

    Args:
        batch_size: samples per batch produced by the pipeline.
        num_class: number of classes, forwarded to the external-source iterator.
        seq_len: frames per clip.
        sample_rate: frame stride inside a clip.
        num_shards: number of data-parallel shards the list file is split into.
        shard_id: index of the shard read by this process.
        root_dir: directory under which each list line's ``dir_name`` holds its frames.
        list_file: path of the annotation list.
        num_workers: number of DALI worker threads (the pipeline's ``num_threads``).
        device_id: GPU index used by the pipeline.
        dali_cpu: decode JPEGs on the CPU if True, otherwise with the 'mixed' decoder.
        size: output frame size passed to the resize operator.
        is_gray: convert frames to grayscale.
        is_training: enable augmentation and the paired training sampler.

    Returns:
        A ``DALIGenericIterator`` whose outputs are named ``data0`` ...
        ``data{seq_len - 1}`` (one per frame position) followed by ``label``.
    """
    pipe = HybridPipe(batch_size, num_class, seq_len, sample_rate, num_shards, shard_id, root_dir,
                      list_file, num_workers, device_id=device_id,
                      dali_cpu=dali_cpu, size=size, is_gray=is_gray, is_training=is_training)
    # Output names must match HybridPipe.define_graph: seq_len frames, then the label.
    names = [f'data{i}' for i in range(seq_len)] + ['label']
    # NOTE(review): ``n`` counts list lines in this shard, but in training each
    # line yields two samples, so one iterator epoch of ``n`` samples covers
    # only about half of the shard's lines; confirm this is intended.
    loader = DALIGenericIterator(pipe, names, pipe.external_data.n,
                                 last_batch_padded=True, fill_last_batch=True)
    return loader

r`


免責聲明!

本站轉載的文章為個人學習借鑑使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM