pytorch常见问题汇总

一. pytorch 指定GPU (转自http://www.cnblogs.com/darkknightzh/p/6836568.html

PyTorch默认使用从0开始的GPU,如果GPU0正在运行程序,需要指定其他GPU。
有如下两种方法来指定需要使用的GPU。

1.类似tensorflow指定GPU的方式,使用CUDA_VISIBLE_DEVICES
1.1 直接终端中设定:

CUDA_VISIBLE_DEVICES=1 python my_script.py

另一种使用方法:
$ export CUDA_VISIBLE_DEVICES=1
$ python train.py

1.2 python代码中设定:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

见网址:http://www.cnblogs.com/darkknightzh/p/6591923.html

2.使用函数 set_device

import torch
torch.cuda.set_device(id)

该函数见 pytorch-master\torch\cuda_init_.py。

不过官方建议使用CUDA_VISIBLE_DEVICES,不建议使用 set_device 函数。

二. Torch Numpy Variable PILimage 之间转换
1. Tensor 与 numpy之间转换

# tensor to numpy
a = torch.FloatTensor(3,3)
a = a.numpy()
# numpy to tensor
a = np.ones(5)
a = torch.from_numpy(a)

2. Variable 转 numpy,Tensor [pytorch0.4 之后不再使用variable]

# to tensor
a = Variable(torch.FloatTensor(3,3))
a = a.data
# to numpy
a = Variable(torch.FloatTensor(3,3))
a = a.data.numpy()

3. numpy, tensor 转 Variable

# tensor to variable
a = Variable(tensor)
# numpy to variable
a = np.ones(5)
a = Variable(torch.from_numpy(a))

5. tensor 与 Pil image 之间转换

#Tensor to PIL Image
toPIL = transforms.ToPILImage()
image = tensor.cpu()
image = image.squeeze(0)		#保存为灰度图
image = toPIL(image)
#PIL Image to Tensor
toTensor = transfroms.ToTensor()
image = Image.open(image_name).convert('RGB')
image = toTensor(image)

6. numpy 与 Pil image 之间转换

# Pil image to numpy
image = Image.open(image_name).convert('RGB')
np.array(image)
# numpy to Pil image
Image.fromarray(numpy.ndarray)

#numpy.ndarray 需要转换成np.uint8型:numpy.astype(np.uint8),像素值[0,255]。
#同时灰度图像保证numpy.shape为(H,W),不能出现channels,可能需要np.squeeze()。彩色图象保证numpy.shape为(H,W,3),可能需要函数 tranpose()

三. 打印网络参数
1.

import utils
VGG = networks.VGG19('vgg19.pth', feature_mode=True)
VGG.to(device)
VGG.eval()
print('---------- Networks initialized -------------')
utils.print_network(VGG)
print('-----------------------------------------------')

打印出来的是网络卷积层、池化层、激活层等内的参数信息,同时会打印网络总参数,如下,VGG19 分为两个块:features和classifier,使用时可直接使用其名字,self.features/self.classifier:

2.

for name, param in VGG.named_parameters():
	print(name, '      ', param.size())

打印的是模块名字.序号.权重名(注意此处不回打印relu,pool不需要back的层,打印结果:
在这里插入图片描述
如果直接打印param, 即 print(name,param), 打印结果:打印出来的详细参数
在这里插入图片描述
3.

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值