# 2020-09-21

from __future__ import division

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from net.fcn_parts import *

class FCNet1(nn.Module):
    """Fully convolutional network with one additive skip connection.

    The encoder applies three 2x2 max-pools (overall stride 8). The decoder
    upsamples by 2 (``conv_t1``), adds the feature map saved after the second
    pool, then upsamples by 4 (``conv_t2``), so the output has the same
    spatial size as the input.

    Input height and width should be divisible by 8; otherwise the shapes of
    the two tensors added at the skip connection may not match.

    Args:
        channels: Number of channels of the input tensor.
        classes: Number of channels produced by the final layer.
        drop_rate: Dropout probability used by ``drop4`` and ``drop5``.
    """

    def __init__(self, channels, classes, drop_rate):
        super().__init__()
        self.channels = channels
        self.classes = classes
        self.drop_rate = drop_rate

        # Stage 1: channels -> 32, spatial size / 2.
        self.input_bn = nn.BatchNorm2d(self.channels)
        self.conv1_1 = nn.Conv2d(channels, 32, kernel_size=3, padding=1, bias=True)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64, spatial size / 4. Its output feeds the skip.
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 64 -> 128, spatial size / 8.
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 4: 128 -> 512, no pooling.
        self.conv4_1 = nn.Conv2d(128, 512, kernel_size=3, padding=1, bias=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.bn4 = nn.BatchNorm2d(512)
        self.relu4 = nn.ReLU(inplace=True)
        self.drop4 = nn.Dropout(self.drop_rate)

        # 1x1 convolution head.
        self.conv5 = nn.Conv2d(512, 512, kernel_size=1, padding=0, bias=True)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(inplace=True)
        self.drop5 = nn.Dropout(self.drop_rate)

        # Per-pixel scores, then x2 upsampling back to the skip resolution
        # with 64 channels so it can be added to the stage-2 output.
        self.conv6 = nn.Conv2d(512, self.classes, kernel_size=1, padding=0, bias=True)
        self.conv_t1 = nn.ConvTranspose2d(self.classes, 64, 2, stride=2)

        self.bn6 = nn.BatchNorm2d(64)
        # Final x4 upsampling to the input resolution.
        self.conv_t2 = nn.ConvTranspose2d(64, self.classes, kernel_size=4, stride=4)

    def __initialize_weights(self):
        """Re-initialise every ``Conv2d`` and ``BatchNorm2d`` submodule.

        Conv2d weights are drawn from N(0, 0.02) and biases are zeroed;
        BatchNorm2d weights are set to 1 and biases to 0. The transposed
        convolutions are not ``nn.Conv2d`` instances and keep their default
        initialisation.

        Nothing in this class calls this method, and because of the leading
        double underscore it is name-mangled to
        ``_FCNet1__initialize_weights`` for callers outside the class.
        """
        print("**" * 10, "Initing head_conv weights", "**" * 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

                print("initing {}".format(m))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

                print("initing {}".format(m))

    def forward(self, x):
        """Run the network on ``x`` and return the upsampled score map.

        Args:
            x: Input tensor with ``self.channels`` channels in dimension 1.

        Returns:
            Tensor with ``self.classes`` channels in dimension 1.
        """
        x = self.input_bn(x)
        x = self.conv1_1(x)
        x = self.conv1_2(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2_1(x)
        x = self.conv2_2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # No copy needed: the pooled tensor is never modified in place
        # (the in-place ReLUs below act on fresh batch-norm outputs).
        skip = x

        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)

        x = self.conv4_1(x)
        x = self.conv4_2(x)
        x = self.conv4_3(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.drop4(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)
        x = self.drop5(x)

        x = self.conv6(x)
        x = self.conv_t1(x)
        x = self.bn6(x)

        # Fuse the upsampled scores with the stage-2 feature map.
        x = torch.add(x, skip)
        out = self.conv_t2(x)

        return out

# Smoke test: build the network and print its layer summary when this file
# is executed directly rather than imported.
if __name__ == '__main__':
    net = FCNet1(channels=3, classes=2, drop_rate=0.5)
    print(net)

# Validation pass: average the criterion over every batch of val_loader.
# NOTE(review): this fragment relies on `val_loader`, `device`, `criterion`,
# `epoch` and `epochs`, none of which are defined in this file; it looks
# like an excerpt from a training loop. Confirm where it is meant to live.
net.eval()
val_loss = 0.0  # running sum; must start at 0, not inf, or the mean is inf
num_batches = 0
# Gradients are not needed for evaluation; skipping them saves memory.
with torch.no_grad():
    for image, label in val_loader:
        image = image.to(device=device, dtype=torch.float32)
        label = label.to(device=device, dtype=torch.float32)
        # Run the network with its current parameters to get predictions.
        pred = net(image)
        # Compute the loss for this batch.
        loss1 = criterion(pred, label)
        # Accumulate a Python float so no tensors are kept alive.
        val_loss += loss1.item()
        num_batches += 1

# Divide by the number of batches (not the last index). With no batches
# there is nothing to average, so report inf instead of dividing by zero.
val_loss = val_loss / num_batches if num_batches else float('inf')
print(f'Epoch {epoch + 1}/{epochs}, "val_loss:", {val_loss}')
net.train()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值