from future import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from net.fcn_parts import *
class FCNet1(nn.Module):
    """Small fully-convolutional segmentation network with one skip connection.

    Encoder: input BatchNorm, then three stages of stacked 3x3 convs +
    BatchNorm + ReLU + 2x2 max-pool (each pool halves H and W, ending at 1/8
    resolution). These are followed by a 512-channel 3x3 conv stage and a
    512-channel 1x1 conv stage (each with BatchNorm, ReLU and dropout), and a
    1x1 classifier producing ``classes`` score maps.

    Decoder: a stride-2 transposed conv upsamples to 1/4 resolution with 64
    channels. The 64-channel feature map saved after the second pool is added
    to it. A stride-4 transposed conv then maps back to ``classes`` maps at
    the input resolution.

    Args:
        channels: Number of input channels.
        classes: Number of output score maps.
        drop_rate: Dropout probability applied after the 512-channel stages.

    NOTE(review): inside each stage the convolutions (conv1_1 -> conv1_2,
    conv4_1 -> conv4_2 -> conv4_3, ...) are applied back to back with no
    activation between them, so each stack is a single linear map before its
    BatchNorm/ReLU. Left unchanged because adding activations would alter the
    outputs of existing checkpoints.
    """

    def __init__(self, channels, classes, drop_rate):
        super().__init__()
        self.channels = channels
        self.classes = classes
        self.drop_rate = drop_rate

        # Normalise the raw input per channel.
        self.input_bn = nn.BatchNorm2d(self.channels)

        # Stage 1: channels -> 32, then /2.
        self.conv1_1 = nn.Conv2d(channels, 32, kernel_size=3, padding=1, bias=True)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64, then /2 (its output is the skip feature).
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 64 -> 128, then /2.
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 4: 128 -> 512 at 1/8 resolution, with dropout.
        self.conv4_1 = nn.Conv2d(128, 512, kernel_size=3, padding=1, bias=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.bn4 = nn.BatchNorm2d(512)
        self.relu4 = nn.ReLU(inplace=True)
        self.drop4 = nn.Dropout(self.drop_rate)

        # Stage 5: 1x1 conv 512 -> 512, with dropout.
        self.conv5 = nn.Conv2d(512, 512, kernel_size=1, padding=0, bias=True)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(inplace=True)
        self.drop5 = nn.Dropout(self.drop_rate)

        # Classifier: 1x1 conv to per-class score maps.
        self.conv6 = nn.Conv2d(512, self.classes, kernel_size=1, padding=0, bias=True)

        # Decoder: x2 upsample to 64 channels (to match the skip), then x4
        # upsample back to the input resolution.
        self.conv_t1 = nn.ConvTranspose2d(self.classes, 64, 2, stride=2)
        self.bn6 = nn.BatchNorm2d(64)
        self.conv_t2 = nn.ConvTranspose2d(64, self.classes, kernel_size=4, stride=4)

    def __initialize_weights(self):
        """Re-initialise Conv2d and BatchNorm2d layers, printing each one.

        Conv2d weights are drawn from N(0, 0.02) and their biases zeroed.
        BatchNorm2d layers are set to weight 1 and bias 0.

        NOTE(review): this method is not called anywhere in the visible
        source. It is name-mangled to ``_FCNet1__initialize_weights``. Its
        isinstance checks do not match ``nn.ConvTranspose2d``, so conv_t1 and
        conv_t2 keep PyTorch's default initialisation even when it runs.
        """
        print("**" * 10, "Initing head_conv weights", "**" * 10)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.02)
                if m.bias is not None:
                    m.bias.data.zero_()
                print("initing {}".format(m))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                print("initing {}".format(m))

    def forward(self, x):
        """Return per-pixel class scores for a batch of images.

        Args:
            x: Tensor of shape (N, channels, H, W). H and W should be
                multiples of 8. The three 2x2 pools floor odd sizes, so for
                other sizes the skip addition can fail with a shape mismatch,
                or the output comes out smaller than the input.

        Returns:
            Tensor of shape (N, classes, H, W) when H and W are multiples of 8.
        """
        x = self.input_bn(x)

        x = self.conv1_1(x)
        x = self.conv1_2(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2_1(x)
        x = self.conv2_2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # Skip feature: (N, 64, H/4, W/4). Nothing below modifies this tensor
        # in place (the in-place ReLUs act on fresh BatchNorm outputs), so it
        # is kept by reference; a clone would only cost an extra copy.
        skip = x

        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)

        x = self.conv4_1(x)
        x = self.conv4_2(x)
        x = self.conv4_3(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.drop4(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)
        x = self.drop5(x)

        x = self.conv6(x)

        # Upsample 1/8 -> 1/4 so the result lines up with the skip feature.
        x = self.conv_t1(x)
        x = self.bn6(x)
        x = x + skip
        # Upsample 1/4 -> full resolution.
        out = self.conv_t2(x)
        return out
def validate(net, val_loader, criterion, device):
    """Return the mean per-batch loss of *net* over *val_loader*.

    The network is switched to eval mode and run under ``torch.no_grad()``.
    Its previous train/eval mode is restored afterwards, even if an error is
    raised.

    Args:
        net: Module to evaluate.
        val_loader: Iterable of ``(image, label)`` pairs. Both are moved to
            ``device`` and cast to float32 before use.
        criterion: Callable ``criterion(pred, label)`` returning a
            single-element tensor.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The sum of the per-batch losses divided by the number of batches, as
        a Python float.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    was_training = net.training
    net.eval()
    total_loss = 0.0  # accumulate from zero
    num_batches = 0
    try:
        # No autograd graph is needed for validation.
        with torch.no_grad():
            for image, label in val_loader:
                image = image.to(device=device, dtype=torch.float32)
                label = label.to(device=device, dtype=torch.float32)
                # Run the network to get predictions.
                pred = net(image)
                # Compute the loss; keep it as a plain float so no tensors
                # accumulate across batches.
                total_loss += criterion(pred, label).item()
                num_batches += 1
    finally:
        net.train(was_training)
    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")
    # Divide by the batch count (not the last loop index).
    return total_loss / num_batches


if __name__ == "__main__":
    net = FCNet1(channels=3, classes=2, drop_rate=0.5)
    print(net)
    # Smoke test: one forward pass on random input whose H and W are
    # multiples of 8, so the output matches the input resolution. A training
    # script can call validate(net, val_loader, criterion, device) each epoch.
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 64, 64))
    print("output shape:", tuple(out.shape))