pytorch源码解读之torchvision.transforms
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。这3个子包的具体介绍可以参考官网:http://pytorch.org/docs/master/torchvision/index.html。具体代码可以参考github:https://github.com/pytorch/vision/tree/master/torchvision。
这篇博客介绍torchvision.transformas。torchvision.transforms这个包中包含resize、crop等常见的data
augmentation操作,基本上PyTorch中的data
augmentation操作都可以通过该接口实现。该包主要包含两个脚本:transformas.py和functional.py,前者定义了各种data
augmentation的类,在每个类中通过调用functional.py中对应的函数完成data
augmentation操作。
使用例子:
import torchvision import torch train_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(256), torchvision.transforms.RandomCrop(224), torchvision.transofrms.RandomHorizontalFlip(), torchvision.transforms.ToTensor(), torch vision.Normalize([0.485, 0.456, -.406],[0.229, 0.224, 0.225]) ]) Class custom_dataread(torch.utils.data.Dataset): def __init__(): ... def __getitem__(): # use self.transform for input image def __len__(): ... train_loader = torch.utils.data.DataLoader( custom_dataread(transform=train_augmentation), batch_size = batch_size, shuffle = True, num_workers = workers, pin_memory = True)
这里定义了resize、crop、normalize等数据预处理操作,并最终作为数据读取类custom_dataread的一个参数传入,可以在内部方法getitem中实现数据增强操作。