pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合
- torch.utils.data.Dataset:所有继承他的子类都应该重写 __len()__ , __getitem()__ 这两个方法
- __len()__ :返回数据集中数据的数量
- __getitem()__ :返回支持下标索引方式获取的一个数据
- torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....
第一步
自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:
- __len()__:读取数据,返回数据和标签
- __getitem()__:返回数据集的长度
from torch.utils.data import Dataset class AudioDataset(Dataset): def __init__(self, ...): """类的初始化""" pass def __getitem__(self, item): """每次怎么读数据,返回数据和标签""" return data, label def __len__(self): """返回整个数据集的长度""" return total
注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本
案例:
文件目录结构
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:读取p225文件夹中的音频数据
class AudioDataset(Dataset): def __init__(self, data_folder, sr=16000, dimension=8192): self.data_folder = data_folder self.sr = sr self.dim = dimension # 获取音频名列表 self.wav_list = [] for root, dirnames, filenames in os.walk(data_folder): for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表 self.wav_list.append(os.path.join(root, filename)) def __getitem__(self, item): # 读取一个音频文件,返回每个音频数据 filename = self.wav_list[item] wb_wav, _ = librosa.load(filename, sr=self.sr) # 取 帧 if len(wb_wav) >= self.dim: max_audio_start = len(wb_wav) - self.dim audio_start = np.random.randint(0, max_audio_start) wb_wav = wb_wav[audio_start: audio_start + self.dim] else: wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant") return wb_wav, filename def __len__(self): # 音频文件的总数 return len(self.wav_list)
注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
第二步
实例化 Dataset 对象
Dataset= AudioDataset("./p225", sr=16000)
如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作
# 实例化AudioDataset对象 train_set = AudioDataset("./p225", sr=16000) for i, data in enumerate(train_set): wb_wav, filname = data print(i, wb_wav.shape, filname) if i == 3: break # 0 (8192,) ./p225\p225_001.wav # 1 (8192,) ./p225\p225_002.wav # 2 (8192,) ./p225\p225_003.wav # 3 (8192,) ./p225\p225_004.wav
第三步
如果想要通过batch读取数据,需要使用DataLoader进行包装
为何要使用DataLoader?
- 深度学习的输入是mini_batch形式
- 样本加载时候可能需要随机打乱顺序,shuffle操作
- 样本加载需要采用多线程
pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
参数:
- dataset:加载的数据集(Dataset对象)
- batch_size:每个批次要加载多少个样本(默认值:1)
- shuffle:每个epoch是否将数据打乱
- sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
- batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
- num_workers:使用多进程加载的进程数,0代表不使用多线程
- collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
- pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
- drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
返回:数据加载器
案例:
# 实例化AudioDataset对象 train_set = AudioDataset("./p225", sr=16000) train_loader = DataLoader(train_set, batch_size=8, shuffle=True) for (i, data) in enumerate(train_loader): wav_data, wav_name = data print(wav_data.shape) # torch.Size([8, 8192]) print(i, wav_name) # (\'./p225\\p225_293.wav\', \'./p225\\p225_156.wav\', \'./p225\\p225_277.wav\', \'./p225\\p225_210.wav\', # \'./p225\\p225_126.wav\', \'./p225\\p225_021.wav\', \'./p225\\p225_257.wav\', \'./p225\\p225_192.wav\')
我们来吃几个栗子消化一下:
栗子1
这个例子就是本文一直举例的,栗子1只是合并了一下而已
文件目录结构
- p225
- ***.wav
- ***.wav
- ***.wav
- ...
- dataset.py
目的:读取p225文件夹中的音频数据
import fnmatch import os import librosa import numpy as np from torch.utils.data import Dataset from torch.utils.data import DataLoader class Aduio_DataLoader(Dataset): def __init__(self, data_folder, sr=16000, dimension=8192): self.data_folder = data_folder self.sr = sr self.dim = dimension # 获取音频名列表 self.wav_list = [] for root, dirnames, filenames in os.walk(data_folder): for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表 self.wav_list.append(os.path.join(root, filename)) def __getitem__(self, item): # 读取一个音频文件,返回每个音频数据 filename = self.wav_list[item] print(filename) wb_wav, _ = librosa.load(filename, sr=self.sr) # 取 帧 if len(wb_wav) >= self.dim: max_audio_start = len(wb_wav) - self.dim audio_start = np.random.randint(0, max_audio_start) wb_wav = wb_wav[audio_start: audio_start + self.dim] else: wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant") return wb_wav, filename def __len__(self): # 音频文件的总数 return len(self.wav_list) train_set = Aduio_DataLoader("./p225", sr=16000) train_loader = DataLoader(train_set, batch_size=8, shuffle=True) for (i, data) in enumerate(train_loader): wav_data, wav_name = data print(wav_data.shape) # torch.Size([8, 8192]) print(i, wav_name) # (\'./p225\\p225_293.wav\', \'./p225\\p225_156.wav\', \'./p225\\p225_277.wav\', \'./p225\\p225_210.wav\', # \'./p225\\p225_126.wav\', \'./p225\\p225_021.wav\', \'./p225\\p225_257.wav\', \'./p225\\p225_192.wav\')
注意事项:
- 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
- 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意
栗子2
相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。
我给出以下几种建议:
建议一:
如果你模型需要读取的不是简单的音频,而是经过较复杂特征处理后的数据,特征处理还挺需要时间的,我建议你用这种方法
先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。
具体实现代码:
第一步:创建一个H5_generation脚本,读取语音并进行特征处理,最后将特征转换为h5格式文件。(大家根据自己的研究领域进行相应的特征提取,我这个是语音频带扩展的窄带和宽带特征提取代码,你们能看懂我想要表达的思想就行):
# Author:凌逆战 # -*- coding: utf-8 -*- """ 方法:重采样,高频部分不会恢复,时间维度对不上,因此在重采样之前需要给原音频裁切取整 得到训练数据为8000Hz,Ground True为16kHz。 """ import fnmatch import os import h5py import librosa import argparse import numpy as np parser = argparse.ArgumentParser() parser.add_argument(\'--sr\', type=int, default=16000, help=\'音频采样率\') parser.add_argument(\'--wav_dir\', default="F:/dataset/VCTK-Corpus/wav48/p225", help=\'存放wav文件的路径\') parser.add_argument(\'--h5_dir\', default="./single_speaker225_resample_r=2.h5", help=\'输出 h5存档的路径\') parser.add_argument(\'--scale\', type=int, default=2, help=\'缩放因子\') # 2、4、6 parser.add_argument(\'--dimension\', type=int, default=8192, help=\'patch的维度\') parser.add_argument(\'--stride\', type=int, default=4096, help=\'提取patch时候的步幅\') parser.add_argument(\'--batch_size\', type=int, default=64, help=\'我们产生的 patches 是batch size的倍数\') args = parser.parse_args() # 如果是TIMIT数据集 # train_set_shape:(48576, 8192, 1) # test_set_shape:(17728, 8192, 1) # python data_preprocess_resample.py --wav_dir "F:/dataset/TIMIT/TRAIN" --h5_dir "./TIMIT_resample_train_r=2.h5" # python data_preprocess_resample.py --wav_dir "F:/dataset/TIMIT/TEST" --h5_dir "./TIMIT_resample_test_r=2.h5" def preprocess(args, h5_file, save_wav): # 列出所有要处理的文件 列表 wav_list = [] for root, dirnames, filenames in os.walk(args.wav_dir): for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表 wav_list.append(os.path.join(root, filename)) num_files = len(wav_list) # num_files音频文件的个数 print("音频的个数为:", num_files) # patches to extract and their size / 要提取的补丁及其大小 dim = args.dimension # patch的维度 default=8192 wb_stride = args.stride # 提取patch时候的步幅 default=3200 wb_patches = list() # 宽带音频补丁空列表 nb_patches = list() # 窄带音频补丁空列表 for j, wav_path in enumerate(wav_list): if j % 10 == 0: # 每隔10次打印一下文件的索引和文件路径名 print(\'%d/%d\' % (j, num_files)) wb_wav, _ = librosa.load(wav_path, sr=args.sr) # 加载音频文件 采样率 sr = 16000 # 裁剪,使其与缩放比率一起工作,结果:能被缩放比例整除,因为不能整除的已经被减去了 wav_len = len(wb_wav) wb_wav = wb_wav[: wav_len - (wav_len % args.scale)] # 生成低分辨率版本 nb_wav = librosa.core.resample(wb_wav, args.sr, args.sr / args.scale) # 下采样率 16000-->8000 nb_wav = librosa.core.resample(nb_wav, args.sr / args.scale, args.sr) # 上采样率 8000-->16000,并不恢复高频部分 # 生成补丁 max_i = len(wb_wav) - dim + 1 for i in range(0, max_i, wb_stride): wb_patch = np.array(wb_wav[i: i + dim]) nb_patch = np.array(nb_wav[i: i + dim]) wb_patches.append(wb_patch.reshape((dim, 1))) nb_patches.append(nb_patch.reshape((dim, 1))) # 裁剪补丁,使其成为小批量的倍数 num_wb_patches = len(wb_patches) num_nb_patches = len(nb_patches) print("num_wb_patches", num_wb_patches) # 852 print("num_nb_patches", num_nb_patches) # 852 print(\'batch_size:\', args.batch_size) # batch_size: 64 # num_wb_patches要能够被batch整除,保留能够被整除的,这样才能保证每个样本都能被训练到 num_to_keep_wb = num_wb_patches // args.batch_size * args.batch_size wb_patches = np.array(wb_patches[:num_to_keep_wb]) num_to_keep_nb = num_nb_patches // args.batch_size * args.batch_size nb_patches = np.array(nb_patches[:num_to_keep_nb]) print(\'hr_patches shape:\', wb_patches.shape) # (832, 8192, 1) print(\'lr_patches shape:\', nb_patches.shape) # (832, 8192, 1) # 创建 hdf5 文件 data_set = h5_file.create_dataset(\'data\', nb_patches.shape, np.float32) # lr label_set = h5_file.create_dataset(\'label\', wb_patches.shape, np.float32) # hr data_set[...] = nb_patches # ...代替了前面两个冒号, data_set[...]=data_set[:,:] label_set[...] = wb_patches if save_wav: librosa.output.write_wav(\'resample_train_wb.wav\', wb_patches[40].flatten(), args.sr, norm=False) librosa.output.write_wav(\'resample_train_nb.wav\', nb_patches[40].flatten(), args.sr, norm=False) print(wb_patches[40].shape) # (8192, 1) print(nb_patches[40].shape) # (8192, 1) print(\'保存了两个示例\') if __name__ == \'__main__\': # 创造训练 with h5py.File(args.h5_dir, \'w\') as f: preprocess(args, f, save_wav=True)
第二步:通过Dataset从h5格式文件中读取数据
import numpy as np from torch.utils.data import Dataset from torch.utils.data import DataLoader import h5py def load_h5(h5_path): # load training data with h5py.File(h5_path, \'r\') as hf: print(\'List of arrays in input file:\', hf.keys()) X = np.array(hf.get(\'data\'), dtype=np.float32) Y = np.array(hf.get(\'label\'), dtype=np.float32) return X, Y class AudioDataset(Dataset): """数据加载器""" def __init__(self, data_folder): self.data_folder = data_folder self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1) def __getitem__(self, item): # 返回一个音频数据 X = self.X[item] Y = self.Y[item] return X, Y def __len__(self): return len(self.X) train_set = AudioDataset("./speaker225_resample_train.h5") train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True) for (i, wav_data) in enumerate(train_loader): X, Y = wav_data print(i, X.shape) # 0 torch.Size([64, 8192, 1]) # 1 torch.Size([64, 8192, 1]) # ...
- 优点:我把复杂的操作统一让H5_generation.py文件来执行,模型训练的时候直接读取H5文件就行,不用在训练模型的时候再进行特征提取,一劳永逸,节省时间。
- 缺点:最后能够一步解决就最好了
我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了
建议二:
如果你的模型输入就是语音波形,或者特征处理非常简单,我强烈建议你一步到位,不要去什么生成h5文件,
import os import time import numpy as np from torch.utils.data import Dataset, DataLoader import librosa class AudioData(Dataset): def __init__(self, dimension=8192, stride=4096, fs=16000, scale=2, data_path="./train"): super(AudioData, self).__init__() self.dimension = dimension self.stride = stride self.scale = scale self.fs = fs self.wavs_path = [os.path.join(data_path, wav_name) for wav_name in os.listdir(data_path)] self.wb_list = [] self.split() def get_nb(self, wb_wav): nb_wav = librosa.core.resample(wb_wav, self.fs, self.fs / self.scale) # 下采样率 16000-->8000 nb_wav = librosa.core.resample(nb_wav, self.fs / self.scale, self.fs) # 上采样率 8000-->16000,并不恢复高频部分 return nb_wav def split(self): for wav_path in self.wavs_path: wav, _ = librosa.load(path=wav_path, sr=self.fs) wav_length = len(wav) # 音频长度 if wav_length < self.stride: # 如果语音长度小于4096 continue if wav_length < self.dimension: # 如果语音长度小于8192 diffe = self.dimension - wav_length wb_wav = np.pad(wav, (0, diffe), mode="constant") self.wb_list.append(wb_wav) else: # 如果音频大于 8192 start_index = 0 while True: if start_index + self.dimension > wav_length: break wb_frame = wav[start_index:start_index + self.dimension] self.wb_list.append(wb_frame) start_index += self.stride def __len__(self): return len(self.wb_list) def __getitem__(self, index): return self.wb_list[index], self.get_nb(self.wb_list[index]) if __name__ == "__main__": start_time = time.time() data = AudioData() print(len(data)) # 3420 train_loader = DataLoader(data, batch_size=32, shuffle=True, drop_last=True) end_time = time.time() print("用了%d的时间" % (end_time-start_time)) # 24秒 for wb, nb in train_loader: print("宽带", wb.shape) # torch.Size([32, 8192]) print("窄带", nb.shape) # torch.Size([32, 8192]) break
- 优点:一步到位
- 缺点:每次实例化Dataset都要较长时间,程序允许完后,内存就释放了,下次还需要又要从头开始。
建议二的低效版:
看完了建议二,不看这个版本也行,但是为了让大家思考如果更加高效的
# Author:凌逆战 # -*- coding:utf-8 -*- """ 作用: """ import os import time import numpy as np from torch.utils.data import Dataset, DataLoader import librosa class AudioData(Dataset): def __init__(self, dimension=8192, stride=4096, fs=16000, scale=2, data_path="./train"): super(AudioData, self).__init__() self.dimension = dimension self.stride = stride self.scale = scale self.fs = fs self.wavs_path = [os.path.join(data_path, wav_name) for wav_name in os.listdir(data_path)] self.wb_list = [] self.nb_list = [] self.preprocess() def get_nb(self, wb_wav): nb_wav = librosa.core.resample(wb_wav, self.fs, self.fs / self.scale) # 下采样率 16000-->8000 nb_wav = librosa.core.resample(nb_wav, self.fs / self.scale, self.fs) # 上采样率 8000-->16000,并不恢复高频部分 return nb_wav def preprocess(self): for wav_path in self.wavs_path: wav, _ = librosa.load(path=wav_path, sr=self.fs) wav_length = len(wav) # 音频长度 if wav_length < self.stride: # 如果语音长度小于4096 continue if wav_length < self.dimension: # 如果语音长度小于8192 diffe = self.dimension - wav_length wb_wav = np.pad(wav, (0, diffe), mode="constant") nb_wav = self.get_nb(wb_wav) self.wb_list.append(wb_wav) self.nb_list.append(nb_wav) else: # 如果音频大于 8192 start_index = 0 while True: if start_index + self.dimension > wav_length: break wb_frame = wav[start_index:start_index + self.dimension] nb_frame = self.get_nb(wb_frame) self.wb_list.append(wb_frame) self.nb_list.append(nb_frame) start_index += self.stride def __len__(self): return len(self.wb_list) def __iter__(self): for index in range(len(self.wb_list)): yield self.wb_list[index], self.nb_list[index] def __getitem__(self, index): return self.wb_list[index], self.nb_list[index] if __name__ == "__main__": start_time = time.time() data = AudioData() print(len(data)) # 3420 train_loader = DataLoader(data, batch_size=32, shuffle=True, drop_last=True) end_time = time.time() print("用了%d的时间" % (end_time-start_time)) # 61秒 for wb, nb in train_loader: print("宽带", wb.shape) print("窄带", nb.shape) break
这个方法用了61秒完成数据读取,原因是什么大家可以自己去思考,不建议用这个方法
参考
pytorch学习(四)—自定义数据集(讲的比较详细)