baseline部分
引入相关库函数,定义rle编解码函数
flatten(order='F'): 按列优先(Fortran 顺序)将数组展平为一维; np.concatenate: 沿已有轴拼接数组, 默认 axis=0 (此处用于在首尾各补一个 0); np.where(cond)[0]: 返回满足条件的元素下标 (此处条件为相邻像素值不同, 即游程的起止位置);
import numpy as np
import pandas as pd
import pathlib, sys, os, random, time
import numba, cv2, gc
from tqdm import tqdm_notebook
import matplotlib. pyplot as plt
% matplotlib inline
import warnings
warnings. filterwarnings( 'ignore' )
from tqdm. notebook import tqdm
import albumentations as A
import rasterio
from rasterio. windows import Window
def rle_encode(im):
    """Run-length encode a binary mask.

    The mask is flattened in column-major (Fortran, ``order='F'``) order and
    written as space-separated ``start length`` pairs, with 1-based starts.

    Args:
        im: numpy array, 1 - mask, 0 - background.

    Returns:
        The run-length encoding as a string; ``''`` when the mask has no 1s.
    """
    flat = im.flatten(order='F')
    # Pad both ends with 0 so runs touching either edge still produce a
    # 0->1 and a 1->0 transition.
    padded = np.concatenate([[0], flat, [0]])
    # 1-based positions where the value changes: alternately run starts and
    # (exclusive) run ends.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every end position into a length (end - start), in place.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))
def rle_decode(mask_rle, shape=(512, 512)):
    """Decode a run-length string into a binary mask.

    Args:
        mask_rle: run-length string of space-separated ``start length`` pairs
            with 1-based starts, in column-major (Fortran) pixel order.
        shape: (height, width) of the array to return.

    Returns:
        A ``uint8`` numpy array of ``shape``: 1 - mask, 0 - background.
        An empty string decodes to an all-zero mask.
    """
    tokens = mask_rle.split()
    # Even tokens are 1-based starts; shift them to 0-based offsets.
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    stops = starts + lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, stop in zip(starts, stops):
        flat[begin:stop] = 1
    # Runs were counted column by column, so un-flatten in Fortran order.
    return flat.reshape(shape, order='F')
引入pytorch相关模块
import torch
import torch. nn as nn
import torch. nn. functional as F
import torch. utils. data as D
import torchvision
from torchvision import transforms as T
定义全局变量(训练轮数 epoch、batch size、图像尺寸、是否使用 GPU 计算)和数据增广方法
# Training hyper-parameters.
EPOCHES = 20
BATCH_SIZE = 32
IMAGE_SIZE = 256

# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Training-time augmentation applied jointly to image and mask: resize to a
# square IMAGE_SIZE x IMAGE_SIZE, random flips, and a random 90-degree rotation.
trfm = A.Compose(
    [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(),
    ]
)
定义dataset类(init, getitem, len)
class TianChiDataset(D.Dataset):
    """Image / RLE-mask dataset for the Tianchi segmentation task.

    Args:
        paths: sequence of image file paths, read with ``cv2.imread``.
        rles: sequence of run-length strings aligned with ``paths``, decoded
            with ``rle_decode``; not read when ``test_mode`` is True.
        transform: callable taking ``image=`` and ``mask=`` keywords and
            returning a dict with ``'image'`` and ``'mask'`` entries.
        test_mode: when True, ``__getitem__`` returns ``(image_tensor, '')``
            without decoding a mask or applying ``transform``.
    """

    def __init__(self, paths, rles, transform, test_mode=False):
        self.paths = paths
        self.rles = rles
        self.transform = transform
        self.test_mode = test_mode
        self.len = len(paths)
        # Array -> PIL image -> resized, normalized float tensor.
        # NOTE(review): cv2.imread yields BGR channel order by default; confirm
        # these per-channel statistics were computed in that order.
        self.as_tensor = T.Compose([
            T.ToPILImage(),
            T.Resize(IMAGE_SIZE),
            T.ToTensor(),
            T.Normalize([0.625, 0.448, 0.688],
                        [0.131, 0.177, 0.101]),
        ])

    def __getitem__(self, index):
        """Return ``(image_tensor, mask)`` or, in test mode, ``(image_tensor, '')``.

        Raises:
            FileNotFoundError: if the image at ``self.paths[index]`` cannot be read.
        """
        path = self.paths[index]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than deep in a transform.
            raise FileNotFoundError(f'could not read image: {path}')
        if self.test_mode:
            return self.as_tensor(img), ''
        mask = rle_decode(self.rles[index])
        augments = self.transform(image=img, mask=mask)
        # [None] prepends a channel axis to the 2-D mask.
        return self.as_tensor(augments['image']), augments['mask'][None]

    def __len__(self):
        """Total number of samples in the dataset."""
        return self.len
数据集读取和rle解码
apply: 对 Series 的每个元素(或 DataFrame 的每行/每列)调用任意函数, 是 pandas 中最灵活的方法; loc: 按行标签索引数据; iloc: 按整数位置(行号)索引数据;
# Load the tab-separated index of (image file name, RLE mask) rows.
train_mask = pd.read_csv(
    '数据集/train_mask.csv',
    sep='\t',
    names=['name', 'mask'],
)
# Prefix each file name with its directory so it can be opened directly.
train_mask['name'] = train_mask['name'].map(lambda fname: '数据集/train/' + fname)

# Sanity check: decoding and re-encoding the first mask should reproduce
# the original RLE string (prints True).
img = cv2.imread(train_mask.iloc[0]['name'])
mask = rle_decode(train_mask.iloc[0]['mask'])
print(rle_encode(mask) == train_mask.iloc[0]['mask'])
实例化 TianChiDataset 数据集类
# Training dataset. Missing masks (NaN) become '' -- an empty RLE string,
# which rle_decode turns into an all-background mask.
dataset = TianChiDataset(
    paths=train_mask['name'].values,
    rles=train_mask['mask'].fillna('').values,
    transform=trfm,
    test_mode=False,
)
图片绘制
image, mask = dataset[0]
plt.figure(figsize=(16, 8))
# Left panel: the mask's single channel.
plt.subplot(1, 2, 1)
plt.imshow(mask[0], cmap='gray')
# Right panel: first channel of the normalized image tensor.
plt.subplot(1, 2, 2)
plt.imshow(image[0]);
将数据集划分为 train 和 valid 两部分 (下标 % 7 == 0 的样本作为验证集, % 7 == 1 的样本作为训练集, 其余样本不参与训练)
# Deterministic subsample: every 7th sample starting at 0 goes to validation,
# every 7th sample starting at 1 goes to training; the other 5/7 are unused.
valid_idx = list(range(0, len(dataset), 7))
train_idx = list(range(1, len(dataset), 7))

train_ds = D.Subset(dataset, train_idx)
valid_ds = D.Subset(dataset, valid_idx)

# Shuffle only the training batches; validation order stays fixed.
loader = D.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
vloader = D.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
定义模型构建函数 get_model 和验证函数 validation
def get_model():
    """Build an FCN-ResNet50 segmentation model with a one-channel output.

    The positional ``True`` is torchvision's legacy ``pretrained`` flag.
    The last layer of the main classifier head (index 4) is replaced by a
    512 -> 1 1x1 convolution, so ``model(x)['out']`` has a single channel.
    NOTE(review): recent torchvision releases deprecate the ``pretrained``
    argument in favour of ``weights=``; confirm against the installed version.
    """
    net = torchvision.models.segmentation.fcn_resnet50(True)
    head = net.classifier
    head[4] = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
    return net
@torch.no_grad()
def validation(model, loader, loss_fn):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    Runs without gradient tracking and switches ``model`` to eval mode
    (it is left in eval mode afterwards). Each batch contributes equally to
    the mean, regardless of its size.
    """
    model.eval()
    batch_losses = []
    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.float().to(DEVICE)
        logits = model(inputs)['out']
        batch_losses.append(loss_fn(logits, targets).item())
    return np.array(batch_losses).mean()
设置优化器和损失函数
# Build the model, move it to the selected device, and create its optimizer.
# (`model` must be bound here: the training loop, validation, and checkpoint
# loading below all use this name.)
model = get_model()
model.to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
class SoftDiceLoss(nn.Module):
    """Soft Dice loss, returned as ``1 - mean(dice)``.

    For soft predictions ``x`` and targets ``y``, TP/FP/FN are summed over
    ``dims``, and each remaining slice gets
    ``dice = (2*TP + smooth) / (2*TP + FP + FN + smooth)``. The per-slice
    scores are then averaged.

    Args:
        smooth: additive constant that keeps the ratio defined (and equal to 1)
            when both prediction and target are empty.
        dims: dimensions to sum over; defaults to the last two (presumably
            the spatial H and W axes -- confirm against callers).
    """

    def __init__(self, smooth=1., dims=(-2, -1)):
        super().__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        overlap = (x * y).sum(self.dims)
        false_pos = (x * (1 - y)).sum(self.dims)
        false_neg = ((1 - x) * y).sum(self.dims)
        numerator = 2 * overlap + self.smooth
        denominator = 2 * overlap + false_pos + false_neg + self.smooth
        return 1 - (numerator / denominator).mean()
# Loss components: BCE computed directly on logits, soft Dice on probabilities.
bce_fn = nn.BCEWithLogitsLoss()
dice_fn = SoftDiceLoss()


def loss_fn(y_pred, y_true):
    """Return ``0.8 * BCE + 0.2 * soft Dice`` for logits ``y_pred``.

    BCEWithLogitsLoss applies the sigmoid internally, so it receives the raw
    logits; the Dice term is computed on ``sigmoid(y_pred)``.
    """
    probs = torch.sigmoid(y_pred)
    return 0.8 * bce_fn(y_pred, y_true) + 0.2 * dice_fn(probs, y_true)
训练循环,打印反馈信息
# Per-epoch report: epoch, mean train loss, mean validation loss, minutes.
header = r'''
Train | Valid
Epoch | Loss | Loss | Time, m
'''
raw_line = '{:6d}' + '\u2502{:7.3f}' * 2 + '\u2502{:6.2f}'
print(header)

# NOTE: overrides the EPOCHES = 20 set with the other hyper-parameters.
EPOCHES = 5
# Start at +inf so the first epoch always writes a checkpoint. A fixed
# threshold could skip saving entirely, and loading 'model_best.pth'
# afterwards would then fail.
best_loss = float('inf')
for epoch in range(1, EPOCHES + 1):
    losses = []
    start_time = time.time()
    model.train()  # validation() leaves the model in eval mode
    for image, target in tqdm(loader):
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        optimizer.zero_grad()
        output = model(image)['out']
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    vloss = validation(model, vloader, loss_fn)
    elapsed_min = (time.time() - start_time) / 60
    print(raw_line.format(epoch, np.array(losses).mean(), vloss, elapsed_min))

    # Keep only the weights with the lowest validation loss seen so far.
    if vloss < best_loss:
        best_loss = vloss
        torch.save(model.state_dict(), 'model_best.pth')
测试阶段的图像预处理(不做数据增广)和最优模型加载
# Test-time preprocessing (no augmentation): same resize and normalization
# statistics as TianChiDataset.as_tensor.
# NOTE: rebinds `trfm`, which earlier held the albumentations training pipeline.
trfm = T.Compose([
    T.ToPILImage(),
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize([0.625, 0.448, 0.688],
                [0.131, 0.177, 0.101]),
])

subm = []
# map_location remaps stored tensors onto DEVICE, so a checkpoint written on
# a GPU can also be restored on a CPU-only host.
model.load_state_dict(torch.load('./model_best.pth', map_location=DEVICE))
model.eval()