数据样本处理的代码可能会变得杂乱且难以维护,因此理想状态下我们应该将模型训练的代码和数据集代码分开封装,以获得更好的代码可读性和模块化代码。
(资料图片)
PyTorch 提供了两个基本方法 torch.utils.data.DataLoader
和torch.utils.data.Dataset
可以让你预加载数据集或者你的数据。
Dataset
存储样本及其相关的标签, DataLoader
封装了关于 Dataset
的迭代器,让我们可以方便地读取样本。
PyTorch库中也提供了一些常用的数据集可以方便用户做预加载可以通过torch.utils.data.Dataset
调用,还提供了一些对应数据集的方法。它们可以用于模型的原型和基准测试。
详细可以戳这里:
Image Datasets,Text Datasets,Audio Datasets。接下来我们看一下怎么从TorchVision加载Fashion-MNIST数据集。
Fashion-MNIST是Zalando的一个数据集,包含6万个训练样例和1万个测试样例。
每个样例由两部分组成,一个28×28灰度图像和一个十分类标签中的某一个标签。
我们要加载 FashionMNIST Dataset需要用到以下几个参数:
root
数据集的存储地址train
指定你要取训练集还是测试集download=True
如果你指定的 root
中没有数据集,会自动从网上下载数据集transform
、 target_transform
指定特征和标签转换下边这段代码是取FashionMNIST的训练集和测试集,root设置了一个data文件,运行下边这段代码以后你可以看到当前目录下边应该多了一个data文件夹,里边就是FashionMNIST数据集文件了。
import torchfrom torch.utils.data import Datasetfrom torchvision import datasetsfrom torchvision.transforms import ToTensorimport matplotlib.pyplot as plttraining_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor())test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor())复制代码
我们可以像列表索引一样查看Datasets
。可以使用matplotlib
可视化我们的数据集。
其他代码解析看注释。
至于画子图有两个方法,二者的区别仅在于一个面向方法,一个面向对象,别的完全一样。
subplotfigure = plt.figure() cols, rows = 3, 3 for i in range(1, cols * rows + 1): plt.subplot(rows, cols, i) plt.show()复制代码add_subplot
figure = plt.figure()cols, rows = 3, 3for i in range(1, cols * rows + 1): figure.subplot(rows, cols, i)plt.show()复制代码
labels_map = { 0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3for i in range(1, cols * rows + 1): sample_idx = torch.randint(len(training_data), size=(1,)).item() # 从数据集中随机采样 img, label = training_data[sample_idx] # 取得数据集的图和标签 figure.add_subplot(rows, cols, i) # 画子图,也可以plt.subplot(rows, cols, i) plt.title(labels_map[label]) plt.axis("off") plt.imshow(img.squeeze(), cmap="gray") # 是黑白图,这里做一个维度压缩,把1通道的1压缩掉plt.show()复制代码
最后随机采样的结果大概是这样的:
Dataset
可以检索我们数据集中一个样本的特征和标签。但是在训练模型的时候,我们通常希望数据以小批量(minibatch)的方式作为输入,在每个epoch中重新调整数据以防止过拟合,并且还能使用Python的multiprocessing
加速数据检索。
DataLoader
是一个迭代器,将刚才提到的复杂方法抽象成简单的API。
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)复制代码
我们已经将数据集加载到DataLoader
中,并可以根据需要迭代数据集。
下面的每次迭代返回一个批量数据的train_features
和train_labels
(分别包含batch_size=64
个特征和标签)。
因为我们指定了shuffle=True
,在遍历所有批量之后,数据会被打乱(要对数据加载顺序进行更细粒度的控制,戳这里pytorch.org/docs/stable… 。
# Display image and label.train_features, train_labels = next(iter(train_dataloader))print(f"Feature batch shape: {train_features.size()}")print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"Label: {label}")复制代码
自定义Dataset类必须实现三个函数:__init__
, __len__
和__getitem__
。看看这个FashionMNIST图像存储在img_dir目录中,它们的标签单独存储在CSV文件annotations_file中。在下一节我们详细分析一下每个函数中发生的事情。
import osimport pandas as pdfrom torchvision.io import read_imageclass CustomImageDataset(Dataset): def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform def __len__(self): return len(self.img_labels) def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image) if self.target_transform: label = self.target_transform(label) return image, label复制代码
__init__
函数在实例化Dataset对象时运行一次,帮我们初始化一个目录,其中包含图像、注释文件和两个变换(下一节将详细介绍)。
The labels.csv file looks like:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform复制代码
__len__
方法返回我们数据集中的样本数量。
def __len__(self): return len(self.img_labels)复制代码
__getitem__
函数当你给定一个索引idx
的时候,用于加载并返回样本。
基于索引,该函数去寻找图像在磁盘上的位置,使用read_image
将其转换为一个张量,从self
中的csv数据中检索相应的标签img_labels
,调用它们上的变换函数(如果适用),并返回一个元组,元组中是图像的张量和对应的标签。
def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image) if self.target_transform: label = self.target_transform(label) return image, label