自定义数据集
自定义数据集
在使用pytorch进行深度学习训练时,难以避免的使用自己的私有数据集。这时就需要进行手动封装,一个良好的数据集是深度学习完美的开始
定义数据集方法
自定义数据需要使用torch.utils.data.Dataset
类,这是自定义数据集的基础
自定义数据无非就三要数,路径,标签处理及返回和数据大小返回
def __init__(self,root_dir,label_dir):
通常进行初始换,图像路径调用def __getitem__(self, item):
按照自己的需要进行指定要,在下面的例子中,就自定义数据返回了图像本身和标签(标签是文件夹) 【这部分是自定义数据的关键】def __len__(self):
定义len方法,一般不需要特殊定义1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34import torch
import torchvision
from torch.utils.data import Dataset
import os
from PIL import Image
train_root = 'data/train'
train_ants_root = 'ants'
train_bess_root = 'bees'
class Mydata(Dataset):
def __init__(self,root_dir,label_dir):
#在这里进行初始化,一般是初始化文件路径或文件列表
super().__init__()
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, item):
# 1. 按照index,读取文件中对应的数据 (读取一个数据!!!!我们常读取的数据是图片,一般我们送入模型的数据成批的,但在这里只是读取一张图片,成批后面会说到)
# 2. 对读取到的数据进行数据增强 (数据增强是深度学习中经常用到的,可以提高模型的泛化能力)(torchvision.transforms.Compose)
# 3. 返回数据对 (一般我们要返回 图片,对应的标签)
self.img = os.path.join(self.root_dir,self.root_dir,self.img_path[item])
img = Image.open(self.img)
img = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
if self.label_dir == 'ants':
label = torch.tensor(1.0)
else:
label = torch.tensor(0)
return img, label
def __len__(self):
return len(self.img_path)
后记
在许多书中和教程中很少提到自定义数据集这件事,导致产生了自定义数据这件事情比较难的刻板印象。但是实际系统的看了发现这点并非很难,定义复杂的数据集只要动手也不是问题。只要多看代码多动手一切困难都会解决的。
参考
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 楚天!