这个是pytorch出来没多久的时候写的了,现在看是非常傻逼的方法,羞耻感十足。
推荐学习项目【pix2pix】的代码,优雅!
–作者 2018.1.30
U-Net 的实现现在github上非常多了吧!用dense-net大概也随随便便吊打了吧!不要用我这个啦~批判性参考一下pytorch咋用还差不多~!
–作者 2018.4.09
大概不支持pytorch 0.4以及以上版本
pytorch是一个很好用的工具,作为一个python的深度学习包,其接口调用起来很方便,具备自动求导功能,适合快速实现构思,且代码可读性强,比如前阵子的WGAN1
好了回到Unet。
原文 arXiv:1505.04597 [cs.CV]
主页 U-Net: Convolutional Networks for Biomedical Image Segmentation
该文章实现了生物图像分割的一个网络,2015年的模型,好像是该领域的冠军。模型长得像个巨大的U,故取名Unet,之前很火的动漫线稿自动上色2就是用的这个模型。当然,该模型也许比不上现在的各种生成式模型了,不过拿来在pytorch里练练手,当做boundary提取,还是可以的。注意这个网络的输出size与输入size不一致,所以应用起来需要额外的处理。
模型长这个鬼样:
参考pytorch的tutorial代码,实现如下:
#unet.py:
from __future__ import division
import torch.nn as nn
import torch.nn.functional as F
import torch
from numpy.linalg import svd
from numpy.random import normal
from math import sqrt
class UNet(nn.Module):
def __init__(self,colordim =1):
super(UNet, self).__init__()
self.conv1_1 = nn.Conv2d(colordim, 64, 3) # input of (n,n,1), output of (n-2,n-2,64)
self.conv1_2 = nn.Conv2d(64, 64, 3)
self.bn1 = nn.BatchNorm2d(64)
self.conv2_1 = nn.Conv2d(64, 128, 3)
self.conv2_2 = nn.Conv2d(128, 128, 3)
self.bn2 = nn.BatchNorm2d(128)
self.conv3_1 = nn.Conv2d(128, 256, 3)
self.conv3_2 = nn.Conv2d(256, 256, 3)
self.bn3 = nn.BatchNorm2d(256)
self.conv4_1 = nn.Conv2d(256, 512, 3)
self.conv4_2 = nn.Conv2d(512, 512, 3)
self.bn4 = nn.BatchNorm2d(512)
self.conv5_1 = nn.Conv2d(512, 1024, 3)
self.conv5_2 = nn.Conv2d(1024, 1024, 3)
self.upconv5 = nn.Conv2d(1024, 512, 1)
self.bn5 = nn.BatchNorm2d(512)
self.bn5_out = nn.BatchNorm2d(1024)
self.conv6_1 = nn.Conv2d(1024, 512, 3)
self.conv6_2 = nn.Conv2d(512, 512, 3)
self.upconv6 = nn.Conv2d(512, 256, 1)
self.bn6 = nn.BatchNorm2d(256)
self.bn6_out = nn.BatchNorm2d(512)
self.conv7_1 = nn.Conv2d(512, 256, 3)
self.conv7_2 = nn.Conv2d(256, 256, 3)
self.upconv7 = nn.Conv2d(256, 128, 1)
self.bn7 = nn.BatchNorm2d(128)
self.bn7_out = nn.BatchNorm2d(256)
self.conv8_1 = nn.Conv2d(256, 128, 3)
self.conv8_2 = nn.Conv2d(128, 128, 3)
self.upconv8 = nn.Conv2d(128, 64, 1)
self.bn8 = nn.BatchNorm2d(64)
self.bn8_out = nn.BatchNorm2d(128)
self.conv9_1 = nn.Conv2d(128, 64, 3)
self.conv9_2 = nn.Conv2d(64, 64, 3)
self.conv9_3 = nn.Conv2d(64, colordim, 1)
self.bn9 = n