The PyTorch framework has a very important and handy package, torchvision, which consists mainly of three subpackages:

  1. torchvision.datasets
  2. torchvision.models
  3. torchvision.transforms

The official documentation (http://pytorch.org/docs/master/torchvision/index.html) describes the three subpackages in detail, and the source code is on GitHub (https://github.com/pytorch/vision/tree/master/torchvision).

torchvision.transforms

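This package contains common data augmentation operations such as resize and crop; most data augmentation in PyTorch can be done through this interface. It consists mainly of two scripts, transforms.py and functional.py. The former defines a class for each augmentation, and each class performs its operation by calling the corresponding function in functional.py.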

# Available augmentation ops

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
           "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine",
           "Grayscale", "RandomGrayscale"]

Usage example

import torch
import torchvision

# Resize and the crops/flips work on PIL images, so they come before ToTensor;
# Normalize works on tensors, so it comes after.
train_augmentation = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.RandomCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


class custom_dataread(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.transform = transform
        ...

    def __getitem__(self, index):
        # read the image for `index`, then apply self.transform to it
        ...

    def __len__(self):
        ...


# batch_size and workers are assumed to be defined elsewhere
train_loader = torch.utils.data.DataLoader(
    custom_dataread(transform=train_augmentation),
    batch_size=batch_size, shuffle=True,
    num_workers=workers, pin_memory=True)

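To load data, subclass torch.utils.data.Dataset.

Then implement __len__ and __getitem__. __len__ returns the number of samples in the dataset (the DataLoader derives the number of iterations per epoch from this and the batch size), and __getitem__ reads one sample and applies augmentation.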
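The relationship between __len__, __getitem__, and the number of steps per epoch can be sketched without importing torch. ToyDataset and steps_per_epoch below are hypothetical names invented for this illustration; the class only mimics the protocol that torch.utils.data.Dataset expects and is not the real base class.

```python
import math


class ToyDataset:
    """Hypothetical stand-in following the Dataset protocol (__len__ / __getitem__)."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __getitem__(self, index):
        # Read one sample, then apply the optional transform to it.
        sample = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        # Number of samples, not number of iterations.
        return len(self.samples)


def steps_per_epoch(dataset, batch_size, drop_last=False):
    # Iterations per epoch follow from the dataset length and the batch size.
    n = len(dataset)
    return n // batch_size if drop_last else math.ceil(n / batch_size)


dataset = ToyDataset(list(range(10)), transform=lambda x: x * 2)
print(dataset[3])
print(steps_per_epoch(dataset, batch_size=4))
```

With 10 samples and a batch size of 4, the last batch is smaller than the others unless it is dropped.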

The main code lives in transforms.py; only the common data augmentation operations are covered here.

Compose

Takes a list of transforms: [Transforms…]

class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list): a list of transform objects; the __call__ of
            each one is invoked in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img
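Because Compose only requires that each element be callable, the chaining can be tried with toy transforms. In this sketch, add_one and Scale are hypothetical stand-ins for real transforms, and the Compose class is a copy of the logic shown above so that the example runs without torchvision.

```python
class Compose(object):
    """Same chaining logic as the class above, repeated so the example is self-contained."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


def add_one(x):
    # A plain function works as a transform.
    return x + 1


class Scale:
    # So does any object that defines __call__.
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


pipeline = Compose([add_one, Scale(10)])
print(pipeline(1))  # add_one runs first, then Scale
```

Order matters: swapping the two transforms gives a different result, which is why the position of ToTensor and Normalize in a real pipeline is significant.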

ToTensor

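PIL is commonly used to read image data in PyTorch. ToTensor converts a numpy.ndarray or PIL.Image to a torch.Tensor and changes the layout from HWC to CHW; the network ultimately receives input of shape [b, c, h, w].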

return img.float().div(255)

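The values are scaled into [0, 1], so they must be rescaled to the original range when converting back. Note the ordering: Normalize takes a Tensor, so ToTensor must come before it, whereas the resize and crop functions shown in this article take a PIL.Image, so they come before ToTensor (as in the usage example above).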

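# Part of the implementation (from functional.py)
import numpy as np
import torch


def to_tensor(pic):
    if isinstance(pic, np.ndarray):
        # handle numpy array: give 2-D (grayscale) input a channel axis
        if pic.ndim == 2:
            pic = pic[:, :, None]

        # HWC -> CHW
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility: uint8 data is scaled to [0, 1]
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img
    # ...... (rest omitted)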
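The two things ToTensor does to uint8 input, the HWC to CHW transpose and the division by 255, can be illustrated on nested lists. This is only a pure-Python sketch under that simplification; to_chw_unit_range is a hypothetical helper, not a torchvision function.

```python
def to_chw_unit_range(hwc_pixels):
    """Convert rows of (r, g, b) tuples with values 0..255 into CHW floats in [0, 1]."""
    height = len(hwc_pixels)
    width = len(hwc_pixels[0])
    channels = len(hwc_pixels[0][0])
    # The outer loop runs over channels, so the result is indexed [c][y][x].
    return [
        [[hwc_pixels[y][x][c] / 255 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 1x2 image: one red pixel and one green pixel.
image_hwc = [[(255, 0, 0), (0, 255, 0)]]
image_chw = to_chw_unit_range(image_hwc)
print(image_chw)
```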
            

ToPILImage

As the name suggests, this converts a Tensor to a PIL.Image; it is the inverse of the ToTensor class above.
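The reverse direction can be sketched the same way: scale [0, 1] floats back to 0..255 and move the channel axis last. to_hwc_uint8 is a hypothetical helper operating on nested lists, and it rounds to the nearest integer, which is an assumption of this sketch rather than a statement about how the library converts.

```python
def to_hwc_uint8(chw):
    """Convert CHW floats in [0, 1] into rows of integer pixel tuples (HWC)."""
    channels = len(chw)
    height = len(chw[0])
    width = len(chw[0][0])
    return [
        [
            # Rounding to nearest is a choice made for this sketch.
            tuple(int(round(chw[c][y][x] * 255)) for c in range(channels))
            for x in range(width)
        ]
        for y in range(height)
    ]


image_chw = [[[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 0.0]]]
print(to_hwc_uint8(image_chw))
```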

Normalize

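The Normalize class standardizes the data and is applied to the input in most pipelines. For each channel it computes output = (input - mean) / std. As mentioned earlier, Normalize requires a Tensor as input, which is also visible from the input of its __call__ method.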


# Example call: torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# mean and std are usually the per-channel mean and standard deviation of the input tensor

def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')

    if not inplace:
        tensor = tensor.clone()

    mean = torch.tensor(mean, dtype=torch.float32)
    std = torch.tensor(std, dtype=torch.float32)
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    return tensor
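The per-channel arithmetic, and the inverse needed to get back to the [0, 1] range, can be sketched on nested lists. normalize_chw and denormalize_chw are hypothetical helpers for illustration; the mean and std values are the ones used in the usage example earlier.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_chw(chw, mean, std):
    # (value - mean) / std, with one mean/std pair per channel.
    return [
        [[(v - m) / s for v in row] for row in channel]
        for channel, m, s in zip(chw, mean, std)
    ]


def denormalize_chw(chw, mean, std):
    # Inverse operation: value * std + mean.
    return [
        [[v * s + m for v in row] for row in channel]
        for channel, m, s in zip(chw, mean, std)
    ]


# A 1x1 image whose pixel equals the channel means normalizes to zero.
pixel_at_mean = [[[0.485]], [[0.456]], [[0.406]]]
print(normalize_chw(pixel_at_mean, MEAN, STD))
```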

Resize

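The size argument can be an int or an (int, int) pair. With an int, the shorter side of the input image is resized to that value and the longer side is scaled proportionally (size * h / w when the width is the shorter side), so the aspect ratio is preserved. With a sequence (h, w) of two ints, the image is resized directly to (h, w). This is a forced resize, so the aspect ratio usually changes and the image content appears stretched or squashed.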

# Resize
class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        return F.resize(img, self.size, self.interpolation)


def resize(img, size, interpolation=Image.BILINEAR):
    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)
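To see what sizes this logic produces without loading an image, the size computation can be pulled out on its own. resize_output_size is a hypothetical helper that mirrors the branches of the resize function above and returns the (width, height) pair that would be handed to img.resize.

```python
def resize_output_size(img_size, size):
    """img_size is (w, h) as PIL reports it; returns the (ow, oh) target."""
    w, h = img_size
    if isinstance(size, int):
        # Shorter side already matches: nothing to do.
        if (w <= h and w == size) or (h <= w and h == size):
            return (w, h)
        if w < h:
            return (size, int(size * h / w))
        return (int(size * w / h), size)
    # A (h, w) sequence is reversed because PIL expects (w, h).
    return tuple(size[::-1])


print(resize_output_size((640, 480), 256))         # landscape, int size
print(resize_output_size((480, 640), 256))         # portrait, int size
print(resize_output_size((500, 300), (100, 200)))  # forced (h, w) resize
```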

Crop

  1. CenterCrop
  2. RandomCrop

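CenterCrop takes a target size (th, tw) and crops a region of that size around the center of the input image; repeated calls on the same image return the same result.
The crop function takes the top-left coordinate and the height and width of the crop region, (i, j, h, w). CenterCrop computes the top-left coordinate as i = int(round((h - th) / 2.)) and j = int(round((w - tw) / 2.)), which centers the crop.

RandomCrop, in contrast to CenterCrop, crops at a random position: i = random.randint(0, h - th) and j = random.randint(0, w - tw).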


def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    w, h = img.size
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(img, i, j, th, tw)


def crop(img, i, j, h, w):
    """Crop the given PIL Image.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))
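The difference between the two crops is only in how the top-left corner (i, j) is chosen. The sketch below computes the (left, upper, right, lower) box that would be passed to img.crop in each case; center_crop_box and random_crop_box are hypothetical helper names, and the rng parameter is added here only to make the random case reproducible.

```python
import random


def center_crop_box(img_size, output_size):
    w, h = img_size          # PIL order: (width, height)
    th, tw = output_size     # target (height, width)
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return (j, i, j + tw, i + th)


def random_crop_box(img_size, output_size, rng=random):
    w, h = img_size
    th, tw = output_size
    # randint is inclusive on both ends, so the box always fits.
    i = rng.randint(0, h - th)
    j = rng.randint(0, w - tw)
    return (j, i, j + tw, i + th)


print(center_crop_box((256, 256), (224, 224)))
print(random_crop_box((256, 256), (224, 224), rng=random.Random(0)))
```

The center box is the same on every call, while the random box changes from call to call but always stays inside the image.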

RandomHorizontalFlip

RandomHorizontalFlip flips the image horizontally at random, i.e. it mirrors the image left to right. As the class's __call__ method shows, the flip probability is 0.5.

  1. RandomHorizontalFlip: horizontal flip, img.transpose(Image.FLIP_LEFT_RIGHT)
  2. RandomVerticalFlip: vertical flip, img.transpose(Image.FLIP_TOP_BOTTOM)

def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)
    
Image.transpose also handles rotation in 90-degree steps (counter-clockwise in PIL). Its method parameter is documented as:

      :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
      :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
      :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
      :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
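What the two flips do to pixel positions, and how the probability gate works, can be shown on an image stored as a list of rows. The functions below are pure-Python illustrations with hypothetical names (hflip_rows, vflip_rows, random_hflip_rows); they do not call PIL.

```python
import random


def hflip_rows(rows):
    # Mirror left to right: reverse every row.
    return [row[::-1] for row in rows]


def vflip_rows(rows):
    # Mirror top to bottom: reverse the order of the rows.
    return rows[::-1]


def random_hflip_rows(rows, p=0.5, rng=random):
    # Flip with probability p, otherwise return the input unchanged.
    if rng.random() < p:
        return hflip_rows(rows)
    return rows


image = [[1, 2, 3],
         [4, 5, 6]]
print(hflip_rows(image))
print(vflip_rows(image))
```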

RandomResizedCrop

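RandomResizedCrop is also widely used. CenterCrop and RandomCrop both crop at a fixed size, whereas this class crops at a random size: it crops first and then resizes.

  1. The class takes three main parameters: size, scale, and ratio. In short, it crops first (using scale and ratio) and then resizes to the given size (using size).
  2. The crop's top-left coordinate and its height and width come from the get_params method, which uses scale and ratio. It first draws a random number within the scale range and multiplies it by the area of the input image to get the crop area; it then draws a random number within the ratio range as the aspect ratio. These two values determine the crop's width and height.
  3. The crop's top-left coordinate is generated at random, in the same way as in RandomCrop.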
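The parameter sampling described in the list above can be sketched as follows. This is a simplified illustration, not the library's get_params: the function name random_resized_crop_params, the retry count of 10, the centered-square fallback, and the rng parameter are assumptions of this sketch, and the default scale and ratio ranges are written out explicitly.

```python
import math
import random


def random_resized_crop_params(img_size, scale=(0.08, 1.0),
                               ratio=(3. / 4., 4. / 3.), rng=random):
    """Return (i, j, h, w): top-left corner and size of the crop."""
    width, height = img_size
    area = width * height
    for _ in range(10):  # retry count is an assumption of this sketch
        # Crop area is a random fraction of the image area.
        target_area = rng.uniform(*scale) * area
        # Aspect ratio (w / h) is drawn from the ratio range.
        aspect_ratio = rng.uniform(*ratio)
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        if 0 < w <= width and 0 < h <= height:
            # Random top-left corner, as in RandomCrop.
            i = rng.randint(0, height - h)
            j = rng.randint(0, width - w)
            return i, j, h, w
    # Fallback (assumed here): a centered square crop.
    side = min(width, height)
    return (height - side) // 2, (width - side) // 2, side, side


print(random_resized_crop_params((640, 480), rng=random.Random(0)))
```

The returned box would then be cropped and resized to size, so the output resolution is fixed even though the crop itself varies.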

ColorJitter

ColorJitter adjusts four properties of the input image: brightness, contrast, saturation, and hue. Set these four parameters according to the docstring.
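When the parameters are given as single numbers, the factor for brightness, contrast, and saturation is drawn uniformly from [max(0, 1 - value), 1 + value], and the hue shift from [-hue, hue]. The sketch below shows only that sampling step plus a simplified brightness adjustment on one channel value; sample_jitter_factors and adjust_brightness_value are hypothetical names, and the per-pixel arithmetic is an illustration rather than the library's implementation.

```python
import random


def sample_jitter_factors(brightness=0, contrast=0, saturation=0, hue=0, rng=random):
    """Draw one random factor per enabled property."""
    factors = {}
    for name, value in (("brightness", brightness),
                        ("contrast", contrast),
                        ("saturation", saturation)):
        if value > 0:
            factors[name] = rng.uniform(max(0, 1 - value), 1 + value)
    if hue > 0:
        factors["hue"] = rng.uniform(-hue, hue)
    return factors


def adjust_brightness_value(value, factor):
    # Simplified: scale one 0..255 channel value and clamp the result.
    return max(0, min(255, int(value * factor)))


print(sample_jitter_factors(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1,
                            rng=random.Random(0)))
print(adjust_brightness_value(100, 1.5))
```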

RandomRotation


# angle: rotation angle in degrees (counter-clockwise)
# resample: resampling filter
# expand: whether to enlarge the output so the rotated image is not cropped
# center: center of rotation, defaults to the image center
def rotate(img, angle, resample=False, expand=False, center=None):
    """Rotate the image by angle.

    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): Rotation angle in degrees, counter-clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.rotate(angle, resample, expand, center)

# expand=True disables cropping: the output is enlarged to hold the whole rotated image
img.rotate(45, expand=True)
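Two pieces of RandomRotation can be sketched without PIL: drawing the angle, and estimating how large the canvas must be when expand=True. When degrees is a single number, the angle is drawn uniformly from (-degrees, degrees). The expanded size below is the bounding box of the rotated rectangle, returned as floats; PIL rounds to whole pixels, so treat these values as approximate. sample_angle and expanded_size are hypothetical names.

```python
import math
import random


def sample_angle(degrees, rng=random):
    # A single number means the range (-degrees, +degrees).
    return rng.uniform(-degrees, degrees)


def expanded_size(img_size, angle):
    """Approximate (width, height) needed to hold the whole rotated image."""
    w, h = img_size
    theta = math.radians(angle)
    c, s = abs(math.cos(theta)), abs(math.sin(theta))
    return (w * c + h * s, w * s + h * c)


print(sample_angle(30, rng=random.Random(0)))
print(expanded_size((100, 50), 45))
```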

Grayscale

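Grayscale converts the input image to grayscale. There are two conversion modes depending on num_output_channels: 1 returns a single-channel image, and 3 returns a three-channel image with r == g == b.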
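The two modes can be illustrated on a single pixel. The sketch uses the ITU-R 601-2 luma weights (0.299, 0.587, 0.114) that PIL applies when converting RGB to mode "L"; the integer arithmetic and the function name to_grayscale_pixel are simplifications for this example, so results may differ from PIL by a rounding step.

```python
def to_grayscale_pixel(rgb, num_output_channels=1):
    r, g, b = rgb
    # Weighted sum of the channels, kept in integer arithmetic.
    luma = (r * 299 + g * 587 + b * 114) // 1000
    if num_output_channels == 1:
        return (luma,)
    if num_output_channels == 3:
        # Three identical channels: still gray, but RGB-shaped.
        return (luma, luma, luma)
    raise ValueError("num_output_channels should be either 1 or 3")


print(to_grayscale_pixel((255, 0, 0), num_output_channels=1))
print(to_grayscale_pixel((255, 0, 0), num_output_channels=3))
```

The three-channel mode is useful when the rest of the pipeline, for example a Normalize with three means, expects RGB-shaped input.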

(To be continued)