数据增强
导入包
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
# Batch size used when building the DataBunch objects below.
bs = 64
# Fetch the PETS dataset via untar_data and point at its 'images' sub-folder.
path = untar_data(URLs.PETS)/'images'
数据增强
- p_affine：仿射变换发生的概率
- p_lighting：光照变换发生的概率
# Augmentation settings: limits for rotation, zoom, lighting and warp.
# p_affine=1. and p_lighting=1. set the probability of the affine and
# lighting transforms to 1, i.e. they are applied every time.
tfms = get_transforms(max_rotate=20, max_zoom=1.3, max_lighting=0.4, max_warp=0.4,
p_affine=1., p_lighting=1.)
# Show the documentation for get_transforms.
doc(get_transforms)
# Collect the images under `path` and split off a random 20% with a fixed
# seed (seed=2) so the split is the same on every run.
src = ImageList.from_folder(path).split_by_rand_pct(0.2, seed=2)
def get_data(size, bs, padding_mode='reflection'):
    """Build a normalized DataBunch from the module-level `src` split.

    Args:
        size: Image size passed to `.transform`.
        bs: Batch size passed to `.databunch`.
        padding_mode: Padding mode passed to `.transform`
            (defaults to 'reflection').

    Returns:
        The result of `.databunch(bs=bs).normalize(imagenet_stats)`.
    """
    # The label is the part of the file name before the trailing
    # "_<digits>.jpg". The dot is escaped so only a literal ".jpg"
    # extension matches (an unescaped "." would match any character).
    labelled = src.label_from_re(r'([^/]+)_\d+\.jpg$')
    augmented = labelled.transform(tfms, size=size, padding_mode=padding_mode)
    return augmented.databunch(bs=bs).normalize(imagenet_stats)
# Size-224 data with padding_mode='zeros' instead of get_data's default 'reflection'.
data = get_data(224, bs, 'zeros')
def _plot(i,j,ax):
    """Draw training item 3 of the global `data` on `ax`, titled with its label.

    `i` and `j` are accepted but not used. The item is fetched again on
    every call (presumably so each call shows a freshly augmented version
    of the same image -- confirm against the dataset's transform behavior).
    """
    img, label = data.train_ds[3]
    img.show(ax, y=label)
# Call _plot over a 3 x 3 grid in an 8 x 8 figure.
plot_multi(_plot, 3, 3, figsize=(8,8))
![output_11_0.png](https://i-blog.csdnimg.cn/blog_migrate/73f88a8c5881b27469e36091b5be1016.png)
# Same size, but with get_data's default padding_mode ('reflection').
data = get_data(224,bs)
# Plot the same 3 x 3 grid again for comparison.
plot_multi(_plot, 3, 3, figsize=(8,8))
![output_13_0.png](https://i-blog.csdnimg.cn/blog_migrate/3cfcb37cdd37c0f74b7211c7fa7e3f8d.png)
训练模型
# Run a garbage-collection pass before creating the learner.
gc.collect()
# ResNet-34 learner on `data`, reporting error_rate, created with bn_final=True.
learn = cnn_learner(data, models.resnet34, metrics=error_rate, bn_final=True)
# Train for 3 epochs with the one-cycle policy (slice(1e-2), pct_start=0.8).
learn.fit_one_cycle(3, slice(1e-2), pct_start=0.8)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 1.580376 | 0.498507 | 0.094723 | 01:44 |
1 | 1.144880 | 0.336172 | 0.084574 | 01:39 |
2 | 0.935206 | 0.227192 | 0.069012 | 01:37 |
# Unfreeze the learner, then train 2 more epochs with max_lr=slice(1e-6,1e-3).
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-3), pct_start=0.8)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.859577 | 0.233298 | 0.069012 | 02:10 |
1 | 0.785061 | 0.215938 | 0.056834 | 02:14 |
# Rebuild the data at size 352 and swap it into the existing learner.
data = get_data(352,bs)
learn.data = data
# Train 2 more epochs on the larger images with max_lr=slice(1e-6,1e-4).
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.402720 | 0.192638 | 0.046685 | 01:53 |
1 | 0.377641 | 0.191480 | 0.048038 | 01:48 |
# Save the learner under the name '352'; it is reloaded later with .load('352').
learn.save('352')
卷积核
# Size-352 data again, this time with a batch size of 16.
data = get_data(352,16)
# Recreate the learner with the same settings and load the '352' checkpoint.
learn = cnn_learner(data, models.resnet34, metrics=error_rate, bn_final=True).load('352')
# Take the first validation item as the sample image `x` with label `y`.
idx=0
x,y = data.valid_ds[idx]
# Display the image and echo its label.
x.show()
data.valid_ds.y[idx]
Category american_pit_bull_terrier
![output_23_1.png](https://i-blog.csdnimg.cn/blog_migrate/eaf6c5d1f67529ef9f253876bccbeba8.png)
# Hand-written 3 x 3 kernel, expanded to shape (1, 3, 3, 3) so the same
# 3 x 3 weights are used for each of the 3 input channels, then divided by 6.
k = tensor([
[0. ,-5/3,1],
[-5/3,-5/3,1],
[1. ,1 ,1],
]).expand(1,3,3,3)/6
# Echo the kernel values.
k
tensor([[[[ 0.0000, -0.2778, 0.1667],
[-0.2778, -0.2778, 0.1667],
[ 0.1667, 0.1667, 0.1667]],
[[ 0.0000, -0.2778, 0.1667],
[-0.2778, -0.2778, 0.1667],
[ 0.1667, 0.1667, 0.1667]],
[[ 0.0000, -0.2778, 0.1667],
[-0.2778, -0.2778, 0.1667],
[ 0.1667, 0.1667, 0.1667]]]])
# Check the kernel's shape.
k.shape
torch.Size([1, 3, 3, 3])
# Underlying tensor of the first validation image, and its shape.
t = data.valid_ds[0][0].data; t.shape
torch.Size([3, 352, 352])
# Indexing with None adds a leading dimension of size 1 (a batch of one).
t[None].shape
torch.Size([1, 3, 352, 352])
# Convolve the single-image batch with the hand-written kernel `k`.
edge = F.conv2d(t[None], k)
# Show the first (and only) result of the batch.
show_image(edge[0], figsize=(5,5));
![output_30_0.png](https://i-blog.csdnimg.cn/blog_migrate/11787c891bcbbfb535cb4a2c5e49da0b.png)
# Echo data.c (presumably the number of classes; it matches the final layer's out_features).
data.c
37
# Echo the model to inspect its layer structure.
learn.model
Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(7): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(1): Sequential(
(0): AdaptiveConcatPool2d(
(ap): AdaptiveAvgPool2d(output_size=1)
(mp): AdaptiveMaxPool2d(output_size=1)
)
(1): Flatten()
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Dropout(p=0.25, inplace=False)
(4): Linear(in_features=1024, out_features=512, bias=True)
(5): ReLU(inplace=True)
(6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
(8): Linear(in_features=512, out_features=37, bias=True)
(9): BatchNorm1d(37, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
# Print the per-layer summary (output shape, parameter count, trainable flag).
print(learn.summary())
Sequential
======================================================================
Layer (type) Output Shape Param # Trainable
======================================================================
Conv2d [64, 176, 176] 9,408 False
______________________________________________________________________
BatchNorm2d [64, 176, 176] 128 True
______________________________________________________________________
ReLU [64, 176, 176] 0 False
______________________________________________________________________
MaxPool2d [64, 88, 88] 0 False
______________________________________________________________________
Conv2d [64, 88, 88] 36,864 False
______________________________________________________________________
BatchNorm2d [64, 88, 88] 128 True
______________________________________________________________________
ReLU [64, 88, 88] 0 False
______________________________________________________________________
Conv2d [64, 88, 88] 36,864 False
______________________________________________________________________
BatchNorm2d [64, 88, 88] 128 True
______________________________________________________________________
Conv2d [64, 88, 88] 36,864 False
______________________________________________________________________
BatchNorm2d [64, 88, 88] 128 True
______________________________________________________________________
ReLU [64, 88, 88] 0 False
______________________________________________________________________
Conv2d [64, 88, 88] 36,864 False
______________________________________________________________________
BatchNorm2d [64, 88, 88] 128 True
______________________________________________________________________
Conv2d [64, 88, 88] 36,864 False
______________________________________________________________________
BatchNorm2d [64, 88, 88] 128 True
______________________________________________________________________
ReLU [64, 88, 88] 0 False
______________________________________________________________________
Conv2d [64, 88, 88] 36,864 False
______________________________________________________________________
BatchNorm2d [64, 88, 88] 128 True
______________________________________________________________________
Conv2d [128, 44, 44] 73,728 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
ReLU [128, 44, 44] 0 False
______________________________________________________________________
Conv2d [128, 44, 44] 147,456 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
Conv2d [128, 44, 44] 8,192 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
Conv2d [128, 44, 44] 147,456 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
ReLU [128, 44, 44] 0 False
______________________________________________________________________
Conv2d [128, 44, 44] 147,456 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
Conv2d [128, 44, 44] 147,456 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
ReLU [128, 44, 44] 0 False
______________________________________________________________________
Conv2d [128, 44, 44] 147,456 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
Conv2d [128, 44, 44] 147,456 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
ReLU [128, 44, 44] 0 False
______________________________________________________________________
Conv2d [128, 44, 44] 147,456 False
______________________________________________________________________
BatchNorm2d [128, 44, 44] 256 True
______________________________________________________________________
Conv2d [256, 22, 22] 294,912 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
ReLU [256, 22, 22] 0 False
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
Conv2d [256, 22, 22] 32,768 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
ReLU [256, 22, 22] 0 False
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
ReLU [256, 22, 22] 0 False
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
ReLU [256, 22, 22] 0 False
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
ReLU [256, 22, 22] 0 False
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
ReLU [256, 22, 22] 0 False
______________________________________________________________________
Conv2d [256, 22, 22] 589,824 False
______________________________________________________________________
BatchNorm2d [256, 22, 22] 512 True
______________________________________________________________________
Conv2d [512, 11, 11] 1,179,648 False
______________________________________________________________________
BatchNorm2d [512, 11, 11] 1,024 True
______________________________________________________________________
ReLU [512, 11, 11] 0 False
______________________________________________________________________
Conv2d [512, 11, 11] 2,359,296 False
______________________________________________________________________
BatchNorm2d [512, 11, 11] 1,024 True
______________________________________________________________________
Conv2d [512, 11, 11] 131,072 False
______________________________________________________________________
BatchNorm2d [512, 11, 11] 1,024 True
______________________________________________________________________
Conv2d [512, 11, 11] 2,359,296 False
______________________________________________________________________
BatchNorm2d [512, 11, 11] 1,024 True
______________________________________________________________________
ReLU [512, 11, 11] 0 False
______________________________________________________________________
Conv2d [512, 11, 11] 2,359,296 False
______________________________________________________________________
BatchNorm2d [512, 11, 11] 1,024 True
______________________________________________________________________
Conv2d [512, 11, 11] 2,359,296 False
______________________________________________________________________
BatchNorm2d [512, 11, 11] 1,024 True
______________________________________________________________________
ReLU [512, 11, 11] 0 False
______________________________________________________________________
Conv2d [512, 11, 11] 2,359,296 False
______________________________________________________________________
BatchNorm2d [512, 11, 11] 1,024 True
______________________________________________________________________
AdaptiveAvgPool2d [512, 1, 1] 0 False
______________________________________________________________________
AdaptiveMaxPool2d [512, 1, 1] 0 False
______________________________________________________________________
Flatten [1024] 0 False
______________________________________________________________________
BatchNorm1d [1024] 2,048 True
______________________________________________________________________
Dropout [1024] 0 False
______________________________________________________________________
Linear [512] 524,800 True
______________________________________________________________________
ReLU [512] 0 False
______________________________________________________________________
BatchNorm1d [512] 1,024 True
______________________________________________________________________
Dropout [512] 0 False
______________________________________________________________________
Linear [37] 18,981 True
______________________________________________________________________
BatchNorm1d [37] 74 True
______________________________________________________________________
Total params: 21,831,599
Total trainable params: 563,951
Total non-trainable params: 21,267,648
Optimized with 'torch.optim.adam.Adam', betas=(0.9, 0.99)
Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/
Loss function : FlattenedLoss
======================================================================
Callbacks functions applied
生成热力图
# Put the model in evaluation mode; `m` is indexed and called as the model below.
m = learn.model.eval();
# Turn the sample image `x` into a batch via data.one_item (second return value ignored).
xb,_ = data.one_item(x)
# Pass the batch through data.denorm and keep its first item as an Image for display.
xb_im = Image(data.denorm(xb)[0])
# Move the batch to the GPU (requires CUDA to be available).
xb = xb.cuda()
from fastai.callbacks.hooks import *
def hooked_backward(cat=y):
    """Run `xb` through `m` and back-propagate one prediction, hooking `m[0]`.

    Two hooks are placed on `m[0]`: a plain `hook_output` and one created
    with `grad=True`. Both are active during the forward pass and during
    the backward pass started from `preds[0, int(cat)]`.

    Args:
        cat: Index (anything `int()` accepts) of the prediction to
            back-propagate. The default is the value the global `y` had
            when this function was defined, not when it is called.

    Returns:
        Tuple `(hook_a, hook_g)`: the plain hook and the `grad=True` hook.
    """
    body = m[0]
    with hook_output(body) as hook_a, hook_output(body, grad=True) as hook_g:
        scores = m(xb)
        scores[0, int(cat)].backward()
    return hook_a, hook_g
# Run the hooked forward/backward pass with the default category.
hook_a,hook_g = hooked_backward()
# First item stored by the plain hook, moved to the CPU.
acts = hook_a.stored[0].cpu()
# Check its shape.
acts.shape
torch.Size([512, 11, 11])
def show_heatmap(hm):
    """Overlay the heat-map `hm` on the global image `xb_im`.

    The image is drawn on a new axes; `hm` is then drawn on top of it,
    semi-transparent (alpha=0.6), bilinearly interpolated, using the
    'magma' colormap. The overlay extent is hard-coded to 0..352 on both
    axes, so it lines up only with a 352 x 352 image.

    Args:
        hm: 2-D array-like of values accepted by `ax.imshow`.
    """
    axes = plt.subplots()[1]
    xb_im.show(axes)
    overlay_opts = dict(
        alpha=0.6,
        extent=(0, 352, 352, 0),
        interpolation='bilinear',
        cmap='magma',
    )
    axes.imshow(hm, **overlay_opts)
# Average over dimension 0 (the 512 channels), leaving one 11 x 11 map.
avg_acts = acts.mean(0)
# Check its shape.
avg_acts.shape
torch.Size([11, 11])
# Overlay the channel-averaged activation map on the input image.
show_heatmap(avg_acts)
![output_43_0.png](https://i-blog.csdnimg.cn/blog_migrate/807dd474cfc8848753a3cadc3768f04a.png)