PyTorch的数据读取及transforms运行机制
PyTorch的数据读取及transforms运行机制
PyTorch的数据读取核心就是DataLoader。它分为两个子模块,Sample和DataSet。Sample的功能是生成索引——Index即样本的序号;DataSet则是根据索引去读取图片等数据以及所属标签。
PyTorch的数据读取及transforms运行机制
一、数据读取
1.1 torch.utils.data.DataLoader
torch.utils.data.DataLoader(dataset, batch_size=1,shuffle=False,sample=None, batch_sample=None, num_workers=0,collate_fn=None, pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)
功能:构建可迭代的数据读取器。
dataset:Dataset类,决定数据从哪里读取以及如何读取。
batch_size:批大小
num_workers:是否多进程读取数据
shuffle:每个epoch是否乱序
drop_last:当样本数不能被batch_zise整除时,是否舍弃最后一批数据。
Epoch、Iteration、Batchsize 之间的关系
Epoch:所有训练样本都已输入到模型中,称为一个Epoch
Iteration:一批样本输入到模型中,称之为一个Iteration
Batchsize:批大小,决定一个Epoch有多少个Iteration
二、torch.utils.data.Dataset
torch.utils.data.Dataset(object)
功能:Dataset抽象类,所有定义的Dataset需要继承它,并且复写__getitem__()
getitem: 接受一个索引,返回一个样本
PyTorch 的数据读取流程:
二、Transforms运行机制
torchvision:计算机视觉工具包
2.1 torchvision.transforms
常用的图像预处理方法
数据中心化
数据标准化
缩放
裁剪
旋转
翻转
填充
噪声添加
灰度变换
线性变换
亮度、饱和度及对比度变换
通过上面的DataLoader以及Dataset将数据读取进来,然后通过transform.compose()将数据按compose定义的缩放、标准化、裁剪、翻转、填充等顺序进行变换,最后得出可训练的数据图片及标签。
transforms.Normalize(mean, std, inplace=False)
功能:逐channel的对图像进行标准化
output = (input - mean) / std
mean:各通道的均值
sts : 各通道的标准差
inplace:是否原地操作, 替换原始数据
2.2 torchvision.dataset
常用数据集的dataset实现,有MNIST、CIFAR-10,ImageNet等
2.3 torchvision.model
常用的模型预训练,AlexNet、VGG、ResNet、GoogLeNet等