【简单的Pytorch回归模型案例】CNN去除随机噪声--修复2d高斯分布【pytorch demo】

 

一、这是个Pytorch学习案例,可以根据这个案例写自己的模型

二、代码

1、导入相关模块

import torch 
from torch import nn
import torchvision
import numpy as np
import cv2
%matplotlib inline
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset
import random
import copy
import torch.optim as optim

2、定义网络模型,这是一个回归模型,用于过滤高斯分布的噪声,从而复原分布,这个模型定义的比较简单

class myNet(nn.Module):
    def __init__(self,):
        super(myNet, self).__init__()
        self.conv1=nn.Conv2d(1,3,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(3,1,kernel_size=3,padding=1)
        self.relu=nn.ReLU(inplace=True)
    def forward(self,x):
        x=self.conv1(x)
        x=self.relu(x)
        x=self.conv2(x)
        
        return x

3、数据生成与测试

class Datagen(Dataset):
    def __init__(self,size=12,transform=None,sigma=3):
        self.size=size
        self.transform=transform
        self.db=[]
        self.sigma=sigma
        for i in range(10):
            x=np.arange(0,self.size,1,np.float32)
            y=x[:,np.newaxis]
            #template=np.zeros((15,15))
            template=np.exp(-((x-random.randint(0,self.size))**2+(y-random.randint(0,self.size))**2)/(2*self.sigma))
            data_noisy = template + 0.2*np.random.normal(size=template.shape)
            #self.db.append([data_noisy,np.exp(-((x-self.size//2)**2+(y-self.size//2)**2)/(2*self.sigma))])
            self.db.append([data_noisy,template[None,:,:]])
            
    def __len__(self,):
        return len(self.db)
    
    def __getitem__(self, idx):
        db_rec = copy.deepcopy(self.db[idx])
        db_rec[0]=self.transform(db_rec[0]).float()
        #data=torch.from_numpy(db_rec)
        
        return db_rec
    
transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
#这是pytorch自带的数据转换为tensor的函数,这个库中还包含了对于图像的数据增广函数,很方便
#数据可视化
trainData[1][0].size()
plt.imshow(trainData[2][0][0,:])
plt.show()

4、定义loss以及数据初始化、因为是回归模型,通常会采用l2作为loss

#定义loss
criterion = nn.MSELoss(size_average=True).cuda()

#训练数据初始化
trainData=Datagen(size=166,transform=transform,sigma=15)
train_loader = torch.utils.data.DataLoader(
        trainData,
        batch_size=1,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

5、定义优化器以及初始化网络模型

#网络初始化
net=myNet()
model=net.cuda()
#定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001 )

6、训练迭代器用for循环即可,这里两层循环分别是,epoch和iteration

#开始迭代训练
epoch=50
for i in range(epoch):
    sum_loss=0
    for i, data in enumerate(train_loader):
        input_data,target=data
        
        input_data=input_data.cuda(non_blocking=True)
        
        target=target.cuda(non_blocking=True)
    
    
        output = model(input_data)
    
    
    
        loss = criterion(output, target)
    
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        sum_loss+=loss.item()
    msg='Loss:{loss:.4f}'.format(loss=sum_loss/10.)
    print(msg)

7、模型的保存

torch.save(model.state_dict(), PATH)

8、编辑测试案例

#写个测试例子
testData=Datagen(size=166,transform=transform,sigma=15)
test_loader = torch.utils.data.DataLoader(
        testData,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
model.eval()
for i,tData in enumerate(train_loader):
    if i >=1:
        break
    test_input=tData[0].cuda(non_blocking=True)
    test_target=tData[1].cuda(non_blocking=True)
    pred_test=model(test_input)
    
    plt.imshow(test_input[0,0,:,:].cpu().detach().numpy())
    plt.show()
    
    plt.imshow(pred_test[0,0,:,:].cpu().detach().numpy())
    plt.show()
    
    plt.imshow(test_target[0,0,:,:].cpu().detach().numpy())
    plt.show()
    

  • 0
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值