一、这是个Pytorch学习案例,可以根据这个案例写自己的模型
二、代码
1、导入相关模块
import torch
from torch import nn
import torchvision
import numpy as np
import cv2
%matplotlib inline
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import random
import copy
import torch.optim as optim
2、定义网络模型,这是一个回归模型,用于过滤高斯分布的噪声,从而复原分布,这个模型定义的比较简单
class myNet(nn.Module):
def __init__(self,):
super(myNet, self).__init__()
self.conv1=nn.Conv2d(1,3,kernel_size=3,padding=1)
self.conv2=nn.Conv2d(3,1,kernel_size=3,padding=1)
self.relu=nn.ReLU(inplace=True)
def forward(self,x):
x=self.conv1(x)
x=self.relu(x)
x=self.conv2(x)
return x
3、数据生成与测试
class Datagen(Dataset):
def __init__(self,size=12,transform=None,sigma=3):
self.size=size
self.transform=transform
self.db=[]
self.sigma=sigma
for i in range(10):
x=np.arange(0,self.size,1,np.float32)
y=x[:,np.newaxis]
#template=np.zeros((15,15))
template=np.exp(-((x-random.randint(0,self.size))**2+(y-random.randint(0,self.size))**2)/(2*self.sigma))
data_noisy = template + 0.2*np.random.normal(size=template.shape)
#self.db.append([data_noisy,np.exp(-((x-self.size//2)**2+(y-self.size//2)**2)/(2*self.sigma))])
self.db.append([data_noisy,template[None,:,:]])
def __len__(self,):
return len(self.db)
def __getitem__(self, idx):
db_rec = copy.deepcopy(self.db[idx])
db_rec[0]=self.transform(db_rec[0]).float()
#data=torch.from_numpy(db_rec)
return db_rec
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
#这是pytorch自带的数据转换为tensor的函数,这个库中还包含了对于图像的数据增广函数,很方便
#数据可视化
trainData[1][0].size()
plt.imshow(trainData[2][0][0,:])
plt.show()
4、定义loss以及数据初始化、因为是回归模型,通常会采用l2作为loss
#定义loss
criterion = nn.MSELoss(size_average=True).cuda()
#训练数据初始化
trainData=Datagen(size=166,transform=transform,sigma=15)
train_loader = torch.utils.data.DataLoader(
trainData,
batch_size=1,
shuffle=True,
num_workers=2,
pin_memory=True
)
5、定义优化器以及初始化网络模型
#网络初始化
net=myNet()
model=net.cuda()
#定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001 )
6、训练迭代器用for循环即可,这里两层循环分别是,epoch和iteration
#开始迭代训练
epoch=50
for i in range(epoch):
sum_loss=0
for i, data in enumerate(train_loader):
input_data,target=data
input_data=input_data.cuda(non_blocking=True)
target=target.cuda(non_blocking=True)
output = model(input_data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
sum_loss+=loss.item()
msg='Loss:{loss:.4f}'.format(loss=sum_loss/10.)
print(msg)
7、模型的保存
torch.save(model.state_dict(), PATH)
8、编辑测试案例
#写个测试例子
testData=Datagen(size=166,transform=transform,sigma=15)
test_loader = torch.utils.data.DataLoader(
testData,
batch_size=1,
shuffle=False,
num_workers=2,
pin_memory=True
)
model.eval()
for i,tData in enumerate(train_loader):
if i >=1:
break
test_input=tData[0].cuda(non_blocking=True)
test_target=tData[1].cuda(non_blocking=True)
pred_test=model(test_input)
plt.imshow(test_input[0,0,:,:].cpu().detach().numpy())
plt.show()
plt.imshow(pred_test[0,0,:,:].cpu().detach().numpy())
plt.show()
plt.imshow(test_target[0,0,:,:].cpu().detach().numpy())
plt.show()