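# Source: https://www.kaggle.com/leighplt/pytorch-tta-flip-left-right
# TTA (test-time augmentation) is common; this note records a neat Python trick from that
# kernel: packaging TTA as a mixin class that a model can simply inherit from.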
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
import torchvision
from torchvision import models
import cv2
from pathlib import Path
import glob
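#============= tta ===================
# This TTA class is pluggable: inherit from it together with nn.Module (see UNet11 below) and
# call model.tta(x). It can be used both while training (e.g. for validation passes) and at
# prediction time; the forward passes run in eval mode without gradients.
# hflip/vflip are @staticmethods: they take no `self`, so they can be called either on the
# class (TTAFunction.hflip(x)) or on an instance (model.hflip(x)).
# tta itself is a regular instance method: it uses self.eval() and the model's forward pass,
# so it must be called on a model instance, e.g. model = UNet11(); model.tta(x).
# TTAFunction.tta(x) without an instance does NOT work.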
class TTAFunction:
    """Mixin that adds horizontal-flip test-time augmentation (TTA) to a model.

    Meant to be combined with ``nn.Module``, e.g. ``class UNet11(TTAFunction, nn.Module)``:
    ``tta`` relies on the host class providing ``training``, ``eval()``, ``train(mode)``
    and being callable (``nn.Module.__call__``).
    """

    @staticmethod
    def hflip(x):
        """Return ``x`` flipped along dim 3 (the width axis for NCHW tensors)."""
        return x.flip(3)

    @staticmethod
    def vflip(x):
        """Return ``x`` flipped along dim 2 (the height axis for NCHW tensors)."""
        return x.flip(2)

    def tta(self, x):
        """Average the prediction on ``x`` with the prediction on its horizontal flip.

        The prediction on the flipped input is flipped back before averaging, so both
        terms are aligned with ``x``. Both forward passes run in eval mode without
        gradient tracking. The model's previous train/eval mode is restored
        afterwards, so calling this inside a training loop does not leave the model
        stuck in eval mode.

        :param x: input batch; assumed NCHW, since the flip is applied to dim 3.
        :return: averaged prediction, same shape as ``self(x)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Call the module itself (not .forward) so registered hooks still run.
                plain = self(x)
                flipped_back = self.hflip(self(self.hflip(x)))
                # Out-of-place sum: never mutates a tensor the forward pass might alias.
                result = 0.5 * (plain + flipped_back)
        finally:
            # Restore whatever mode the caller had (train() during training, eval() at inference).
            self.train(was_training)
        return result
#============= model ===================
def conv3x3(in_, out):
    """Build a 3x3 ``nn.Conv2d`` mapping ``in_`` channels to ``out`` channels.

    With the default stride of 1, ``padding=1`` keeps the spatial size unchanged.
    """
    kernel_size = 3
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=kernel_size, padding=1)
class ConvRelu(nn.Module):
    """A ``conv3x3`` convolution followed by an in-place ReLU."""

    def __init__(self, in_, out):
        super().__init__()
        # The attribute names become state_dict keys (conv.weight, conv.bias),
        # so they must stay stable for checkpoint loading.
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.conv(x))
class DecoderBlock(nn.Module):
    """Decoder stage: ConvRelu, then a stride-2 transposed conv, then ReLU.

    With kernel 3, stride 2, padding 1 and output_padding 1, the transposed
    convolution produces exactly twice the input height and width.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # Layers are created in the same order as before so that seeded weight
        # initialisation is unchanged; they live in one nn.Sequential named
        # ``block`` so state_dict keys (block.0.*, block.1.*) stay the same.
        layers = [ConvRelu(in_channels, middle_channels)]
        layers.append(
            nn.ConvTranspose2d(
                middle_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            )
        )
        layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        upsampled = self.block(x)
        return upsampled
class UNet11(TTAFunction, nn.Module):  # use our class with TTA function
    """U-Net with a VGG11 convolutional encoder, producing a 1-channel sigmoid map.

    Mixes in ``TTAFunction``, so ``model.tta(x)`` is available on instances.

    Input: ``(N, 3, H, W)``; the first VGG11 conv takes 3 channels. H and W must be
    divisible by 32: the encoder pools five times, and each decoder stage doubles the
    size, which must line up exactly with the skip connection it is concatenated with.
    Output: ``(N, 1, H, W)`` values in (0, 1).
    """

    def __init__(self, num_filters=32):
        """
        :param num_filters: base width of the decoder; decoder channel counts are
            multiples of it. The encoder (VGG11) widths are fixed and read from the
            layers below, so values other than the default 32 also give consistent
            channel counts.
        """
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        # Convolutions are from VGG11. models.vgg11() is called without weights, so the
        # encoder starts randomly initialised unless a checkpoint is loaded afterwards.
        self.encoder = models.vgg11().features
        # encoder[1] is VGG's first ReLU module. It has no parameters, so this single
        # instance is reused after every encoder convolution in forward().
        self.relu = self.encoder[1]
        self.conv1 = self.encoder[0]
        self.conv2 = self.encoder[3]
        self.conv3s = self.encoder[6]
        self.conv3 = self.encoder[8]
        self.conv4s = self.encoder[11]
        self.conv4 = self.encoder[13]
        self.conv5s = self.encoder[16]
        self.conv5 = self.encoder[18]

        # Widths of the encoder feature maps used as skip connections. These are taken
        # from VGG11 itself instead of being hard-coded as multiples of num_filters,
        # which only matched VGG (64/128/256/512/512) when num_filters == 32.
        enc1 = self.conv1.out_channels
        enc2 = self.conv2.out_channels
        enc3 = self.conv3.out_channels
        enc4 = self.conv4.out_channels
        enc5 = self.conv5.out_channels

        # Each decoder stage takes [previous decoder output, skip connection] concatenated
        # on the channel axis (see forward). Module creation order is unchanged, so seeded
        # initialisation and state_dict keys are identical to the original.
        self.center = DecoderBlock(enc5, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlock(num_filters * 8 + enc5, num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(num_filters * 8 + enc4, num_filters * 8 * 2, num_filters * 4)
        self.dec3 = DecoderBlock(num_filters * 4 + enc3, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(num_filters * 2 + enc2, num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(num_filters + enc1, num_filters)
        self.final = nn.Conv2d(num_filters, 1, kernel_size=1)

    def forward(self, x):
        # Encoder: VGG11 convolutions, max-pooling (halving H and W) between stages.
        conv1 = self.relu(self.conv1(x))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3s = self.relu(self.conv3s(self.pool(conv2)))
        conv3 = self.relu(self.conv3(conv3s))
        conv4s = self.relu(self.conv4s(self.pool(conv3)))
        conv4 = self.relu(self.conv4(conv4s))
        conv5s = self.relu(self.conv5s(self.pool(conv4)))
        conv5 = self.relu(self.conv5(conv5s))
        center = self.center(self.pool(conv5))
        # Deconvolutions with copies of VGG11 layers of corresponding size
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
        return torch.sigmoid(self.final(dec1))
def unet11(**kwargs):
    """Factory for ``UNet11``; every keyword argument is forwarded to its constructor."""
    return UNet11(**kwargs)
def get_model(seed=717):
    """Build a seeded ``UNet11`` in train mode and move it to ``device``.

    :param seed: seed for the NumPy and PyTorch RNGs, so the randomly initialised
        weights are reproducible. Defaults to 717, the value previously hard-coded.
    :return: the model on ``device``. ``device`` is a module-level name that is not
        defined in this snippet — TODO confirm it is set before calling.
    """
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    model = unet11()
    model.train()
    return model.to(device)
#============= predict with TTA ===================
model = get_model()
# map_location makes a checkpoint saved on a GPU loadable on a CPU-only host (and vice versa).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoint files.
checkpoint = torch.load(model_pth, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
# TGSSaltDataset is defined in the linked Kaggle kernel, not in this snippet.
test_dataset = TGSSaltDataset(test_path, test_file_list, is_test=True)
all_predictions = []
for batch in data.DataLoader(test_dataset, batch_size=30):
    # presumably the test dataset yields 1-tuples (image,), so batch[0] is the image batch
    # — confirm against TGSSaltDataset.
    images = batch[0].type(torch.FloatTensor).to(device)
    # tta() averages the plain and horizontally flipped predictions without gradients.
    y_pred = model.tta(images).detach().cpu().numpy()
    all_predictions.append(y_pred)
# Concatenate batches along axis 0 and drop the single output channel -> (num_images, H, W).
all_predictions_stacked = np.vstack(all_predictions)[:, 0, :, :]