自定义数据集

在使用pytorch进行深度学习训练时,难以避免的使用自己的私有数据集。这时就需要进行手动封装,一个良好的数据集是深度学习完美的开始

定义数据集方法

自定义数据需要使用torch.utils.data.Dataset类,这是自定义数据集的基础
自定义数据无非就三要数,路径,标签处理及返回和数据大小返回

  1. def __init__(self,root_dir,label_dir): 通常进行初始换,图像路径调用
  2. def __getitem__(self, item): 按照自己的需要进行指定要,在下面的例子中,就自定义数据返回了图像本身和标签(标签是文件夹) 【这部分是自定义数据的关键】
  3. def __len__(self):定义len方法,一般不需要特殊定义
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    import torch
    import torchvision
    from torch.utils.data import Dataset
    import os
    from PIL import Image
    train_root = 'data/train'
    train_ants_root = 'ants'
    train_bess_root = 'bees'

    class Mydata(Dataset):
    def __init__(self,root_dir,label_dir):
    #在这里进行初始化,一般是初始化文件路径或文件列表
    super().__init__()
    self.root_dir = root_dir
    self.label_dir = label_dir
    self.path = os.path.join(self.root_dir,self.label_dir)
    self.img_path = os.listdir(self.path)
    def __getitem__(self, item):
    # 1. 按照index,读取文件中对应的数据 (读取一个数据!!!!我们常读取的数据是图片,一般我们送入模型的数据成批的,但在这里只是读取一张图片,成批后面会说到)
    # 2. 对读取到的数据进行数据增强 (数据增强是深度学习中经常用到的,可以提高模型的泛化能力)(torchvision.transforms.Compose)
    # 3. 返回数据对 (一般我们要返回 图片,对应的标签)
    self.img = os.path.join(self.root_dir,self.root_dir,self.img_path[item])
    img = Image.open(self.img)
    img = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
    ])
    if self.label_dir == 'ants':
    label = torch.tensor(1.0)
    else:
    label = torch.tensor(0)
    return img, label
    def __len__(self):
    return len(self.img_path)

后记

在许多书中和教程中很少提到自定义数据集这件事,导致产生了自定义数据这件事情比较难的刻板印象。但是实际系统的看了发现这点并非很难,定义复杂的数据集只要动手也不是问题。只要多看代码多动手一切困难都会解决的。

参考

  1. pytorch技巧 五: 自定义数据集 torch.utils.data.DataLoader 及Dataset的使用
  2. 基于Pytorch的自制蚂蚁蜜蜂数据集的训练与识别