python torch库_Python revtorch包_程序模块 - PyPI - Python中文网

旋转手电筒

用pytorch建立(部分)可逆神经网络的框架

本文介绍并解释了revtorch,

在MICCAI 2019接受陈述。

如果您发现此代码对您的研究有帮助,请引用以下文章:@article{PartiallyRevUnet2019Bruegger,

author={Br{\"u}gger, Robin and Baumgartner, Christian F.

and Konukoglu, Ender},

title={A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation},

journal={arXiv:1906.06148},

year={2019},

安装

使用pip安装revtorch:$ pip install revtorch

revtorch需要pytorch。但是,pytorch不包含在依赖项中,因为所需的pytorch版本依赖于您的系统。请按照PyTorch website上的说明安装pytorch。

用法

这个例子展示了如何使用revtorch框架。importtorchimporttorchvisionimporttorchvision.transformsastransformsimporttorch.nnasnnimporttorch.nn.functionalasFimporttorch.optimasoptimimportrevtorchasrvdeftrain():trainset=torchvision.datasets.CIFAR10(root="./data",train=True,download=True,transform=transforms.ToTensor())trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True)net=PartiallyReversibleNet()criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(net.parameters())forepochinrange(2):running_loss=0.0fori,datainenumerate(trainloader,0):inputs,labels=dataoptimizer.zero_grad()outputs=net(inputs)loss=criterion(outputs,labels)loss.backward()optimizer.step()#logging stuffrunning_loss+=loss.item()LOG_INTERVAL=200ifi%LOG_INTERVAL==(LOG_INTERVAL-1):# print every 2000 mini-batchesprint('[%d,%5d] loss:%.3f'%(epoch+1,i+1,running_loss/LOG_INTERVAL))running_loss=0.0classPartiallyReversibleNet(nn.Module):def__init__(self):super(PartiallyReversibleNet,self).__init__()#initial non-reversible convolution to get to 32 channelsself.conv1=nn.Conv2d(3,32,3)#construct reversible sequencce with 4 reversible blocksblocks=[]foriinrange(4):#f and g must both be a nn.Module whos output has the same shape as its inputf_func=nn.Sequential(nn.ReLU(),nn.Conv2d(16,16,3,padding=1))g_func=nn.Sequential(nn.ReLU(),nn.Conv2d(16,16,3,padding=1))#we construct a reversible block with our F and G functionsblocks.append(rv.ReversibleBlock(f_func,g_func))#pack all reversible blocks into a reversible sequenceself.sequence=rv.ReversibleSequence(nn.ModuleList(blocks))#non-reversible convolution to get to 10 channels (one for each label)self.conv2=nn.Conv2d(32,10,3)defforward(self,x):x=self.conv1(x)#the reversible sequence can be used like any other nn.Module. Memory-saving backpropagation is used automaticallyx=self.sequence(x)x=self.conv2(F.relu(x))x=F.avg_pool2d(x,(x.shape[2],x.shape[3]))x=x.view(x.shape[0],x.shape[1])returnxif__name__=="__main__":train()

python版本

使用python 3.6和pytorch 1.1.0进行测试。应该适用于任何版本的python 3。

欢迎加入QQ群-->: 979659372

推荐PyPI第三方库

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值