利用 PyTorch 进行图像增广
图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模。
图像增广的另一种解释是,随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力。
简单来说,就是通过一些技巧让图像数据变多:
图像增广基于现有训练数据生成随机变化的图像,从而应对过拟合。
1 2 3 4 5 6 7 8 9 10
   | import sys from IPython import display import matplotlib.pyplot as plt %matplotlib inline import time import torch from torch import nn, optim from torch.utils.data import Dataset, DataLoader import torchvision from PIL import Image
   | 
 
1 2
   | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(device)
  | 
 
cuda
1 2 3 4
   | def set_figsize(figsize=(3.5, 2.5)):     use_svg_display()          plt.rcParams['figure.figsize'] = figsize
   | 
 
1 2 3
   | def use_svg_display():     """Use svg format to display plot in jupyter"""     display.set_matplotlib_formats('svg')
   | 
 
1 2
   | img=Image.open('test.jpg') plt.imshow(img)
  | 
 
<matplotlib.image.AxesImage at 0x7fb7abf34d30>

1 2 3 4 5 6 7 8 9
   | def show_images(imgs, num_rows, num_cols, scale=2):     figsize = (num_cols * scale, num_rows * scale)     _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)     for i in range(num_rows):         for j in range(num_cols):             axes[i][j].imshow(imgs[i * num_cols + j])             axes[i][j].axes.get_xaxis().set_visible(False)             axes[i][j].axes.get_yaxis().set_visible(False)     return axes
   | 
 
1 2 3
   | def apply(img,aug,num_rows=2,num_cols=4,scale=1.5):     Y=[aug(img) for _ in range(num_rows*num_cols)]     show_images(Y,num_rows,num_cols,scale)
   | 
 
1
   | apply(img,torchvision.transforms.RandomHorizontalFlip())
   | 
 

1
   | apply(img,torchvision.transforms.RandomVerticalFlip())
   | 
 

1 2
   | shape_aug=torchvision.transforms.RandomResizedCrop(200,scale=(0.1,1))  apply(img,shape_aug)
   | 
 

1
   | apply(img,torchvision.transforms.ColorJitter(brightness=0.5)) 
   | 
 

1
   | apply(img,torchvision.transforms.ColorJitter(hue=0.5)) 
   | 
 

1
   | apply(img,torchvision.transforms.ColorJitter(contrast=0.5))  
   | 
 

1
   | apply(img,torchvision.transforms.ColorJitter(saturation=0.5))  
   | 
 

1 2 3 4 5
   | color_aug=torchvision.transforms.ColorJitter(brightness=0.5,                                              contrast=0.5,                                              saturation=0.5,                                              hue=0.5) apply(img,color_aug)
   | 
 

1 2 3
   | augs = torchvision.transforms.Compose([     torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug]) apply(img, augs)
   | 
 
