pytorch對一下常用的公開數據集有很方便的API接口,但是當我們需要使用自己的數據集訓練神經網絡時,就需要自定義數據集,在pytorch中,提供了一些類,方便我們定義自己的數據集合
- torch.utils.data.Dataset:所有繼承他的子類都應該重寫 __len()__ , __getitem()__ 這兩個方法
- __len()__ :返回數據集中數據的數量
- __getitem()__ :返回支持下標索引方式獲取的一個數據
- torch.utils.data.DataLoader:對數據集進行包裝,可以設置batch_size、是否shuffle....
第一步
自定義的 Dataset 都需要繼承 torch.utils.data.Dataset 類,並且重寫它的兩個成員方法:
- __len()__:讀取數據,返回數據和標簽
- __getitem()__:返回數據集的長度
from torch.utils.data import Dataset class AudioDataset(Dataset): def __init__(self, ...): """類的初始化""" pass def __getitem__(self, item): """每次怎么讀數據,返回數據和標簽""" return data, label def __len__(self): """返回整個數據集的長度""" return total
注意事項:Dataset只負責數據的抽象,一次調用getiitem只返回一個樣本
案例:
文件目錄結構
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:讀取p225文件夾中的音頻數據
class AudioDataset(Dataset): def __init__(self, data_folder, sr=16000, dimension=8192): self.data_folder = data_folder self.sr = sr self.dim = dimension # 獲取音頻名列表 self.wav_list = [] for root, dirnames, filenames in os.walk(data_folder): for filename in fnmatch.filter(filenames, "*.wav"): # 實現列表特殊字符的過濾或篩選,返回符合匹配“.wav”字符列表 self.wav_list.append(os.path.join(root, filename)) def __getitem__(self, item): # 讀取一個音頻文件,返回每個音頻數據 filename = self.wav_list[item] wb_wav, _ = librosa.load(filename, sr=self.sr) # 取 幀 if len(wb_wav) >= self.dim: max_audio_start = len(wb_wav) - self.dim audio_start = np.random.randint(0, max_audio_start) wb_wav = wb_wav[audio_start: audio_start + self.dim] else: wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant") return wb_wav, filename def __len__(self): # 音頻文件的總數 return len(self.wav_list)
注意事項:19-24行:每個音頻的長度不一樣,如果直接讀取數據返回出來的話,會造成維度不匹配而報錯,因此只能每次取一個音頻文件讀取一幀,這樣顯然並沒有用到所有的語音數據,
第二步
實例化 Dataset 對象
Dataset= AudioDataset("./p225", sr=16000)
如果要通過batch讀取數據的可直接跳到第三步,如果你想一個一個讀取數據的可以看我接下來的操作
# 實例化AudioDataset對象 train_set = AudioDataset("./p225", sr=16000) for i, data in enumerate(train_set): wb_wav, filname = data print(i, wb_wav.shape, filname) if i == 3: break # 0 (8192,) ./p225\p225_001.wav # 1 (8192,) ./p225\p225_002.wav # 2 (8192,) ./p225\p225_003.wav # 3 (8192,) ./p225\p225_004.wav
第三步
如果想要通過batch讀取數據,需要使用DataLoader進行包裝
為何要使用DataLoader?
- 深度學習的輸入是mini_batch形式
- 樣本加載時候可能需要隨機打亂順序,shuffle操作
- 樣本加載需要采用多線程
pytorch提供的 DataLoader 封裝了上述的功能,這樣使用起來更方便。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
參數:
- dataset:加載的數據集(Dataset對象)
- batch_size:每個批次要加載多少個樣本(默認值:1)
- shuffle:每個epoch是否將數據打亂
- sampler:定義從數據集中抽取樣本的策略。如果指定,則不能指定洗牌。
- batch_sampler:類似於sampler,但每次返回一批索引。與batch_size、shuffle、sampler和drop_last相互排斥。
- num_workers:使用多進程加載的進程數,0代表不使用多線程
- collate_fn:如何將多個樣本數據拼接成一個batch,一般使用默認拼接方式
- pin_memory:是否將數據保存在pin memory區,pin memory中的數據轉到GPU會快一些
- drop_last:dataset中的數據個數可能不是batch_size的整數倍,drop_last為True會將多出來不足一個batch的數據丟棄
返回:數據加載器
案例:
# 實例化AudioDataset對象 train_set = AudioDataset("./p225", sr=16000) train_loader = DataLoader(train_set, batch_size=8, shuffle=True) for (i, data) in enumerate(train_loader): wav_data, wav_name = data print(wav_data.shape) # torch.Size([8, 8192]) print(i, wav_name) # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav', # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
我們來吃幾個栗子消化一下:
栗子1
這個例子就是本文一直舉例的,栗子1只是合並了一下而已
文件目錄結構
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:讀取p225文件夾中的音頻數據
import fnmatch import os import librosa import numpy as np from torch.utils.data import Dataset from torch.utils.data import DataLoader class Aduio_DataLoader(Dataset): def __init__(self, data_folder, sr=16000, dimension=8192): self.data_folder = data_folder self.sr = sr self.dim = dimension # 獲取音頻名列表 self.wav_list = [] for root, dirnames, filenames in os.walk(data_folder): for filename in fnmatch.filter(filenames, "*.wav"): # 實現列表特殊字符的過濾或篩選,返回符合匹配“.wav”字符列表 self.wav_list.append(os.path.join(root, filename)) def __getitem__(self, item): # 讀取一個音頻文件,返回每個音頻數據 filename = self.wav_list[item] print(filename) wb_wav, _ = librosa.load(filename, sr=self.sr) # 取 幀 if len(wb_wav) >= self.dim: max_audio_start = len(wb_wav) - self.dim audio_start = np.random.randint(0, max_audio_start) wb_wav = wb_wav[audio_start: audio_start + self.dim] else: wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant") return wb_wav, filename def __len__(self): # 音頻文件的總數 return len(self.wav_list) train_set = Aduio_DataLoader("./p225", sr=16000) train_loader = DataLoader(train_set, batch_size=8, shuffle=True) for (i, data) in enumerate(train_loader): wav_data, wav_name = data print(wav_data.shape) # torch.Size([8, 8192]) print(i, wav_name) # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav', # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
注意事項:
- 27-33行:每個音頻的長度不一樣,如果直接讀取數據返回出來的話,會造成維度不匹配而報錯,因此只能每次取一個音頻文件讀取一幀,這樣顯然並沒有用到所有的語音數據,
- 48行:我們在__getitem__中並沒有將numpy數組轉換為tensor格式,可是第48行顯示數據是tensor格式的。這里需要引起注意
栗子2
相比於案例1,案例二才是重點,因為我們不可能每次只從一音頻文件中讀取一幀,然后讀取另一個音頻文件,通常情況下,一段音頻有很多幀,我們需要的是按順序的讀取一個batch_size的音頻幀,先讀取第一個音頻文件,如果滿足一個batch,則不用讀取第二個batch,如果不足一個batch則讀取第二個音頻文件,來補充。
我給出以下幾種建議:
建議一:
如果你模型需要讀取的不是簡單的音頻,而是經過較復雜特征處理后的數據,特征處理還挺需要時間的,我建議你用這種方法
先按順序讀取每個音頻文件,以窗長8192、幀移4096對語音進行分幀,然后拼接。得到(幀數,幀長,1)(frame_num, frame_len, 1)的數組保存到h5中。然后用上面講到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 讀取數據。
具體實現代碼:
第一步:創建一個H5_generation腳本,讀取語音並進行特征處理,最后將特征轉換為h5格式文件。(大家根據自己的研究領域進行相應的特征提取,我這個是語音頻帶擴展的窄帶和寬帶特征提取代碼,你們能看懂我想要表達的思想就行):

# Author:凌逆戰 # -*- coding: utf-8 -*- """ 方法:重采樣,高頻部分不會恢復,時間維度對不上,因此在重采樣之前需要給原音頻裁切取整 得到訓練數據為8000Hz,Ground True為16kHz。 """ import fnmatch import os import h5py import librosa import argparse import numpy as np parser = argparse.ArgumentParser() parser.add_argument('--sr', type=int, default=16000, help='音頻采樣率') parser.add_argument('--wav_dir', default="F:/dataset/VCTK-Corpus/wav48/p225", help='存放wav文件的路徑') parser.add_argument('--h5_dir', default="./single_speaker225_resample_r=2.h5", help='輸出 h5存檔的路徑') parser.add_argument('--scale', type=int, default=2, help='縮放因子') # 2、4、6 parser.add_argument('--dimension', type=int, default=8192, help='patch的維度') parser.add_argument('--stride', type=int, default=4096, help='提取patch時候的步幅') parser.add_argument('--batch_size', type=int, default=64, help='我們產生的 patches 是batch size的倍數') args = parser.parse_args() # 如果是TIMIT數據集 # train_set_shape:(48576, 8192, 1) # test_set_shape:(17728, 8192, 1) # python data_preprocess_resample.py --wav_dir "F:/dataset/TIMIT/TRAIN" --h5_dir "./TIMIT_resample_train_r=2.h5" # python data_preprocess_resample.py --wav_dir "F:/dataset/TIMIT/TEST" --h5_dir "./TIMIT_resample_test_r=2.h5" def preprocess(args, h5_file, save_wav): # 列出所有要處理的文件 列表 wav_list = [] for root, dirnames, filenames in os.walk(args.wav_dir): for filename in fnmatch.filter(filenames, "*.wav"): # 實現列表特殊字符的過濾或篩選,返回符合匹配“.wav”字符列表 wav_list.append(os.path.join(root, filename)) num_files = len(wav_list) # num_files音頻文件的個數 print("音頻的個數為:", num_files) # patches to extract and their size / 要提取的補丁及其大小 dim = args.dimension # patch的維度 default=8192 wb_stride = args.stride # 提取patch時候的步幅 default=3200 wb_patches = list() # 寬帶音頻補丁空列表 nb_patches = list() # 窄帶音頻補丁空列表 for j, wav_path in enumerate(wav_list): if j % 10 == 0: # 每隔10次打印一下文件的索引和文件路徑名 print('%d/%d' % (j, num_files)) wb_wav, _ = librosa.load(wav_path, sr=args.sr) # 加載音頻文件 采樣率 sr = 16000 # 裁剪,使其與縮放比率一起工作,結果:能被縮放比例整除,因為不能整除的已經被減去了 wav_len = len(wb_wav) wb_wav = wb_wav[: wav_len - (wav_len % args.scale)] # 生成低分辨率版本 nb_wav = librosa.core.resample(wb_wav, args.sr, args.sr / args.scale) # 下采樣率 16000-->8000 nb_wav = librosa.core.resample(nb_wav, args.sr / args.scale, args.sr) # 上采樣率 8000-->16000,並不恢復高頻部分 # 生成補丁 max_i = len(wb_wav) - dim + 1 for i in range(0, max_i, wb_stride): wb_patch = np.array(wb_wav[i: i + dim]) nb_patch = np.array(nb_wav[i: i + dim]) wb_patches.append(wb_patch.reshape((dim, 1))) nb_patches.append(nb_patch.reshape((dim, 1))) # 裁剪補丁,使其成為小批量的倍數 num_wb_patches = len(wb_patches) num_nb_patches = len(nb_patches) print("num_wb_patches", num_wb_patches) # 852 print("num_nb_patches", num_nb_patches) # 852 print('batch_size:', args.batch_size) # batch_size: 64 # num_wb_patches要能夠被batch整除,保留能夠被整除的,這樣才能保證每個樣本都能被訓練到 num_to_keep_wb = num_wb_patches // args.batch_size * args.batch_size wb_patches = np.array(wb_patches[:num_to_keep_wb]) num_to_keep_nb = num_nb_patches // args.batch_size * args.batch_size nb_patches = np.array(nb_patches[:num_to_keep_nb]) print('hr_patches shape:', wb_patches.shape) # (832, 8192, 1) print('lr_patches shape:', nb_patches.shape) # (832, 8192, 1) # 創建 hdf5 文件 data_set = h5_file.create_dataset('data', nb_patches.shape, np.float32) # lr label_set = h5_file.create_dataset('label', wb_patches.shape, np.float32) # hr data_set[...] = nb_patches # ...代替了前面兩個冒號, data_set[...]=data_set[:,:] label_set[...] = wb_patches if save_wav: librosa.output.write_wav('resample_train_wb.wav', wb_patches[40].flatten(), args.sr, norm=False) librosa.output.write_wav('resample_train_nb.wav', nb_patches[40].flatten(), args.sr, norm=False) print(wb_patches[40].shape) # (8192, 1) print(nb_patches[40].shape) # (8192, 1) print('保存了兩個示例') if __name__ == '__main__': # 創造訓練 with h5py.File(args.h5_dir, 'w') as f: preprocess(args, f, save_wav=True)
第二步:通過Dataset從h5格式文件中讀取數據
import numpy as np from torch.utils.data import Dataset from torch.utils.data import DataLoader import h5py def load_h5(h5_path): # load training data with h5py.File(h5_path, 'r') as hf: print('List of arrays in input file:', hf.keys()) X = np.array(hf.get('data'), dtype=np.float32) Y = np.array(hf.get('label'), dtype=np.float32) return X, Y class AudioDataset(Dataset): """數據加載器""" def __init__(self, data_folder): self.data_folder = data_folder self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1) def __getitem__(self, item): # 返回一個音頻數據 X = self.X[item] Y = self.Y[item] return X, Y def __len__(self): return len(self.X) train_set = AudioDataset("./speaker225_resample_train.h5") train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True) for (i, wav_data) in enumerate(train_loader): X, Y = wav_data print(i, X.shape) # 0 torch.Size([64, 8192, 1]) # 1 torch.Size([64, 8192, 1]) # ...
- 優點:我把復雜的操作統一讓H5_generation.py文件來執行,模型訓練的時候直接讀取H5文件就行,不用在訓練模型的時候再進行特征提取,一勞永逸,節省時間。
- 缺點:最后能夠一步解決就最好了
我嘗試在__init__中生成h5文件,但是會導致內存爆炸,就很奇怪,因此我只好分開了
建議二:
如果你的模型輸入就是語音波形,或者特征處理非常簡單,我強烈建議你一步到位,不要去什么生成h5文件,
import os import time import numpy as np from torch.utils.data import Dataset, DataLoader import librosa class AudioData(Dataset): def __init__(self, dimension=8192, stride=4096, fs=16000, scale=2, data_path="./train"): super(AudioData, self).__init__() self.dimension = dimension self.stride = stride self.scale = scale self.fs = fs self.wavs_path = [os.path.join(data_path, wav_name) for wav_name in os.listdir(data_path)] self.wb_list = [] self.split() def get_nb(self, wb_wav): nb_wav = librosa.core.resample(wb_wav, self.fs, self.fs / self.scale) # 下采樣率 16000-->8000 nb_wav = librosa.core.resample(nb_wav, self.fs / self.scale, self.fs) # 上采樣率 8000-->16000,並不恢復高頻部分 return nb_wav def split(self): for wav_path in self.wavs_path: wav, _ = librosa.load(path=wav_path, sr=self.fs) wav_length = len(wav) # 音頻長度 if wav_length < self.stride: # 如果語音長度小於4096 continue if wav_length < self.dimension: # 如果語音長度小於8192 diffe = self.dimension - wav_length wb_wav = np.pad(wav, (0, diffe), mode="constant") self.wb_list.append(wb_wav) else: # 如果音頻大於 8192 start_index = 0 while True: if start_index + self.dimension > wav_length: break wb_frame = wav[start_index:start_index + self.dimension] self.wb_list.append(wb_frame) start_index += self.stride def __len__(self): return len(self.wb_list) def __getitem__(self, index): return self.wb_list[index], self.get_nb(self.wb_list[index]) if __name__ == "__main__": start_time = time.time() data = AudioData() print(len(data)) # 3420 train_loader = DataLoader(data, batch_size=32, shuffle=True, drop_last=True) end_time = time.time() print("用了%d的時間" % (end_time-start_time)) # 24秒 for wb, nb in train_loader: print("寬帶", wb.shape) # torch.Size([32, 8192]) print("窄帶", nb.shape) # torch.Size([32, 8192]) break
- 優點:一步到位
- 缺點:每次實例化Dataset都要較長時間,程序允許完后,內存就釋放了,下次還需要又要從頭開始。
建議二的低效版:
看完了建議二,不看這個版本也行,但是為了讓大家思考如果更加高效的

# Author:凌逆戰 # -*- coding:utf-8 -*- """ 作用: """ import os import time import numpy as np from torch.utils.data import Dataset, DataLoader import librosa class AudioData(Dataset): def __init__(self, dimension=8192, stride=4096, fs=16000, scale=2, data_path="./train"): super(AudioData, self).__init__() self.dimension = dimension self.stride = stride self.scale = scale self.fs = fs self.wavs_path = [os.path.join(data_path, wav_name) for wav_name in os.listdir(data_path)] self.wb_list = [] self.nb_list = [] self.preprocess() def get_nb(self, wb_wav): nb_wav = librosa.core.resample(wb_wav, self.fs, self.fs / self.scale) # 下采樣率 16000-->8000 nb_wav = librosa.core.resample(nb_wav, self.fs / self.scale, self.fs) # 上采樣率 8000-->16000,並不恢復高頻部分 return nb_wav def preprocess(self): for wav_path in self.wavs_path: wav, _ = librosa.load(path=wav_path, sr=self.fs) wav_length = len(wav) # 音頻長度 if wav_length < self.stride: # 如果語音長度小於4096 continue if wav_length < self.dimension: # 如果語音長度小於8192 diffe = self.dimension - wav_length wb_wav = np.pad(wav, (0, diffe), mode="constant") nb_wav = self.get_nb(wb_wav) self.wb_list.append(wb_wav) self.nb_list.append(nb_wav) else: # 如果音頻大於 8192 start_index = 0 while True: if start_index + self.dimension > wav_length: break wb_frame = wav[start_index:start_index + self.dimension] nb_frame = self.get_nb(wb_frame) self.wb_list.append(wb_frame) self.nb_list.append(nb_frame) start_index += self.stride def __len__(self): return len(self.wb_list) def __iter__(self): for index in range(len(self.wb_list)): yield self.wb_list[index], self.nb_list[index] def __getitem__(self, index): return self.wb_list[index], self.nb_list[index] if __name__ == "__main__": start_time = time.time() data = AudioData() print(len(data)) # 3420 train_loader = DataLoader(data, batch_size=32, shuffle=True, drop_last=True) end_time = time.time() print("用了%d的時間" % (end_time-start_time)) # 61秒 for wb, nb in train_loader: print("寬帶", wb.shape) print("窄帶", nb.shape) break
這個方法用了61秒完成數據讀取,原因是什么大家可以自己去思考,不建議用這個方法
參考
pytorch學習(四)—自定義數據集(講的比較詳細)