PyTorch
vwenyu-L
pytorch-torchvision.models
1. vgg16 = models.vgg16(pretrained=True)
Original post, 2017-05-29 16:27:37
pytorch-tensorboardX
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
writer = SummaryWriter()
for i, (x, y) in enumerate(dataloader):
    xxx
    writer.add_scalar('/train/loss', loss, i)  # third argument: global_step
    wr...
Original post, 2018-03-12 20:27:16
pytorch-basic tutorial
Sections: CONTACT WEIBO / REFERENCES / TENSOR / CLASS MODULE / DESIGN MODULE / CONTAINER / MULTI GPU / DATASET / SAVE RESTORE / OPTIMIZER / PARAM INIT / ERRORS / TENSORBOARDX / PROFILER
''' by L'''
import torch
import...
Original post, 2017-12-06 17:39:19
pytorch-fine-tune the network and adjust the learning rate
1. ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.SGD([
    {'params': base_params},...
Original post, 2017-05-29 20:09:56
pytorch-multi-gpu
1. nn.DataParallel
model = nn.DataParallel(model.cuda(1), device_ids=[1,2,3,4,5])
criteria = nn.Loss()  # placeholder for a concrete loss such as nn.CrossEntropyLoss()
# i. .cuda(1): 20G-21G  ii. .cuda(): 18.5G-12.7G  iii. nothing: 16.5G-12.7G
# these all use almost same...
Original post, 2017-06-30 10:46:19
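A minimal sketch of wrapping a model in `nn.DataParallel`, assuming a toy `nn.Sequential` model and `nn.CrossEntropyLoss` in place of the placeholder loss. It falls back to a single GPU or to the CPU when fewer than two GPUs are present, so it also runs on a machine without GPUs.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
criterion = nn.CrossEntropyLoss()

if torch.cuda.device_count() > 1:
    device_ids = list(range(torch.cuda.device_count()))
    # The module must live on device_ids[0]; DataParallel splits each batch
    # along dim 0 across the listed GPUs and gathers outputs on device_ids[0].
    model = nn.DataParallel(model.cuda(device_ids[0]), device_ids=device_ids)
    device = torch.device('cuda', device_ids[0])
elif torch.cuda.is_available():
    model = model.cuda()
    device = torch.device('cuda')
else:
    device = torch.device('cpu')  # no GPU: plain single-device model

x = torch.randn(64, 16, device=device)
target = torch.randint(0, 4, (64,), device=device)
loss = criterion(model(x), target)
loss.backward()
print(loss.item())
```

With `DataParallel` the original model is reachable as `model.module`, so its `state_dict()` keys carry a `module.` prefix; saving `model.module.state_dict()` avoids that prefix.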
pytorch-custom dataset
1. class CustomDataset(data.Dataset):
    def __init__(self): ...
    def __len__(self): ...
    def __getitem__(self, index): ...
Original post, 2017-06-02 11:18:24
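A minimal, self-contained `Dataset` following the three-method skeleton; the random in-memory tensors are an assumption standing in for real files or annotations, and `DataLoader` takes care of batching and shuffling.

```python
import torch
from torch.utils import data


class CustomDataset(data.Dataset):
    """Toy dataset: sample i is a (features, label) pair from in-memory tensors."""

    def __init__(self, n_samples=100, n_features=8):
        # A real dataset would load file paths / annotations here.
        self.x = torch.randn(n_samples, n_features)
        self.y = torch.randint(0, 2, (n_samples,))

    def __len__(self):
        # DataLoader uses this to know which indices it may sample.
        return len(self.x)

    def __getitem__(self, index):
        # Return one (input, target) pair; DataLoader stacks them into batches.
        return self.x[index], self.y[index]


dataset = CustomDataset()
loader = data.DataLoader(dataset, batch_size=32, shuffle=True)
xb, yb = next(iter(loader))
print(len(dataset), xb.shape, yb.shape)  # 100 torch.Size([32, 8]) torch.Size([32])
```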
pytorch-torchvision transforms
1.
import torch
import torchvision as vision
rc = vision.transforms.RandomCrop([224, 224])
toPIL = vision.transforms.ToPILImage()
toTensor = vision.transforms.ToTensor()
input = torch.randn(3, 256, ...
Original post, 2017-06-02 10:17:05
pytorch-design NEW Function and Module
1. https://www.qcloud.com/community/article/831497
Original post, 2017-06-12 22:31:25
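The linked article is not reproduced here. As a hedged sketch of the same idea in the current `torch.autograd.Function` API (static `forward`/`backward` methods that receive a `ctx` object), here is a hand-written ReLU wrapped in an `nn.Module`; the names `MyReLUFunction` and `MyReLU` are illustrative.

```python
import torch
import torch.nn as nn


class MyReLUFunction(torch.autograd.Function):
    """Hand-written ReLU with an explicit backward pass."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)  # keep the input for the backward pass
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0     # no gradient where the input was negative
        return grad_input


class MyReLU(nn.Module):
    """Module wrapper so the Function can be used like any other layer."""

    def forward(self, x):
        return MyReLUFunction.apply(x)


x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
y = MyReLU()(x).sum()
y.backward()
print(x.grad)  # tensor([0., 1., 1.])
```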
pytorch-errors
0. RuntimeError: save_for_backward can only save input or output tensors, but argument 0 doesn't satisfy this condition
When writing a custom Function & Module and the module needs backward, the input should ...
Original post, 2017-06-12 10:50:05
pytorch-hook
1. register_forward_hook(hook)
2. register_backward_hook(hook)
-------------------------- reference --------------------------
Original post, 2017-06-03 21:42:29
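A short sketch of both hook types on a toy model. It uses `register_full_backward_hook`, which newer PyTorch releases recommend over the older `register_backward_hook`; the `captured` dict and the hook function names are illustrative.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured = {}


def forward_hook(module, inputs, output):
    # Runs after module.forward(); inputs is a tuple, output is the result.
    captured['relu_out'] = output.detach()


def backward_hook(module, grad_input, grad_output):
    # Runs during backward; both arguments are tuples of gradients.
    captured['relu_grad_out'] = grad_output[0].detach()


relu = model[1]
h1 = relu.register_forward_hook(forward_hook)
h2 = relu.register_full_backward_hook(backward_hook)

x = torch.randn(3, 4)
model(x).sum().backward()

print(captured['relu_out'].shape, captured['relu_grad_out'].shape)
# torch.Size([3, 8]) torch.Size([3, 8])

h1.remove()  # hooks stay active until their handles are removed
h2.remove()
```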
pytorch-loss function
----------------- reference ---------------
1. http://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medium=referral
Original post, 2017-06-11 15:28:34
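A quick numerical check of two common losses under the default `reduction='mean'`: `nn.CrossEntropyLoss` on raw logits equals `nn.NLLLoss` applied to `log_softmax` outputs, and `nn.MSELoss` averages the squared differences. The tensors are made-up values for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 1.5, 0.3]])
target = torch.tensor([0, 1])  # class indices, not one-hot vectors

ce = nn.CrossEntropyLoss()(logits, target)
# CrossEntropyLoss = LogSoftmax followed by NLLLoss.
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

# Squared errors are 0.25 and 0.25, so the mean is 0.25.
mse = nn.MSELoss()(torch.tensor([1.0, 2.0]), torch.tensor([1.5, 2.5]))

print(ce.item(), nll.item(), mse.item())
```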
pytorch-obtain feature maps from network
1.
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
import time

class toyNet(nn.Module):
    def __init__(self, pretrained_model, layers ...
Original post, 2017-05-29 21:04:35
pytorch-containers
In [90]: class MM(nn.Module):
    ...:     def __init__(self):
    ...:         super(MM, self).__init__()
    ...:         #1.
    ...:         #self.dicx = OrderedDict() ...
Original post, 2017-06-22 11:15:54
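A sketch of why the container classes matter, using a hypothetical `MM` that mixes a plain dict with `nn.ModuleList`, `nn.ModuleDict` and `nn.Sequential`. Submodules stored in a plain Python dict attribute (an `OrderedDict` attribute included) are invisible to `nn.Module`, so they are missing from `parameters()` and `state_dict()`.

```python
import torch
import torch.nn as nn


class MM(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain dict: this Linear is NOT registered (not in parameters(),
        # not moved by .cuda(), not saved in state_dict()).
        self.plain = {'fc': nn.Linear(4, 4)}
        # Container modules register their children properly.
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
        self.heads = nn.ModuleDict({'cls': nn.Linear(4, 2)})
        self.seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.heads['cls'](self.seq(x))


m = MM()
out = m(torch.randn(3, 4))
print(len(list(m.parameters())))  # 8: weight + bias of the 4 registered Linears
```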
pytorch-tensor data type
1. torch.FloatTensor(2,3).normal_(0,1)
torch.LongTensor(2,3).random_(0,4)
torch.from_numpy(arr).float()
data.float()
model.float()
Original post, 2017-06-21 12:10:58
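The same operations written with the dtype-first constructors of current PyTorch, as a sketch: `torch.empty(..., dtype=...)` replaces the legacy `torch.FloatTensor`/`torch.LongTensor` constructors, and the NumPy array `arr` is made up for the example. Note that `random_(0, 4)` draws integers from 0 to 3 inclusive.

```python
import numpy as np
import torch
import torch.nn as nn

a = torch.empty(2, 3).normal_(0, 1)                    # float32, N(0, 1)
b = torch.empty(2, 3, dtype=torch.long).random_(0, 4)  # int64, values 0..3
arr = np.arange(6, dtype=np.float64).reshape(2, 3)
c = torch.from_numpy(arr).float()                       # float64 -> float32 copy

model = nn.Linear(3, 1).double()  # cast all parameters to float64
model = model.float()             # and back to float32

print(a.dtype, b.dtype, c.dtype, next(model.parameters()).dtype)
# torch.float32 torch.int64 torch.float32 torch.float32
```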
pytorch-save and load models
1. torch.save(model_name.state_dict(), PATH)
net = ModelClass(args)
net.load_state_dict(torch.load(PATH))
2. torch.save(model, PATH)
net = torch.load(PATH)
Original post, 2017-05-29 17:12:34
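A round-trip sketch of the `state_dict` approach, with a hypothetical `ModelClass` and a temporary file as `PATH`. For the whole-model approach, note that recent PyTorch releases (from 2.6) default `torch.load` to `weights_only=True`, so loading a pickled model object needs `weights_only=False` and should only be done with trusted files.

```python
import os
import tempfile

import torch
import torch.nn as nn


class ModelClass(nn.Module):  # hypothetical model standing in for your own
    def __init__(self, n_in=4, n_out=2):
        super().__init__()
        self.fc = nn.Linear(n_in, n_out)

    def forward(self, x):
        return self.fc(x)


model = ModelClass()
PATH = os.path.join(tempfile.mkdtemp(), 'model_state.pt')

# Save only the state_dict (parameters and buffers).
torch.save(model.state_dict(), PATH)

net = ModelClass()  # rebuild the architecture first, then load the weights
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.eval()

x = torch.randn(3, 4)
print(torch.equal(model(x), net(x)))  # True: identical weights
```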
pytorch-parameter initialization
torch.nn.init
weight.data.fill_(1)
bias.data.fill_(0)
weight.data.uniform_(-stdv, stdv)
1. params = list(net.parameters())
2. conv2params = list(net.conv2.parameters())
kernels: conv2params[0]
bias: conv2para...
Original post, 2017-05-29 16:53:18
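A sketch combining both styles from the notes: `torch.nn.init` functions applied to every conv layer through `net.apply`, and the manual `uniform_(-stdv, stdv)` / `fill_(0)` calls on one layer, done under `torch.no_grad()` rather than through `.data`. The toy network and the choice `stdv = 1/sqrt(fan_in)` are assumptions for illustration.

```python
import math

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))


def init_weights(m):
    # net.apply() calls this on every submodule (children first, then net).
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        nn.init.zeros_(m.bias)


net.apply(init_weights)

# Manual style on the second conv layer: uniform in [-stdv, stdv], zero bias.
conv2 = net[2]
fan_in = conv2.in_channels * conv2.kernel_size[0] * conv2.kernel_size[1]
stdv = 1.0 / math.sqrt(fan_in)
with torch.no_grad():
    conv2.weight.uniform_(-stdv, stdv)
    conv2.bias.fill_(0)

kernels, bias = list(conv2.parameters())
print(kernels.shape, bias.shape)  # torch.Size([16, 8, 3, 3]) torch.Size([16])
```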
pytorch-class nn.Module
1. .children() returns a generator object
len(list(net.children()))
nn.Sequential(*list(net.children())[:10])
where net is an instance of the model class
2. register_backward_hook()
Original post, 2017-05-29 16:18:36
pytorch-distributed training
1. Single-machine, multi-GPU data parallelism
```
model = xxx
losses = torch.nn.parallel.data_parallel(model, inputs=(), device_ids=[], dim=x)  # functional style
```
2. pytorch-1.0 distributed
2.1 Single machine, multiple GPUs
2.2 Multiple machines, multiple GPUs
----...
Original post, 2019-01-28 15:25:48
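For the distributed setup listed under 2.1 and 2.2, a minimal sketch with `DistributedDataParallel` follows. It assumes the `gloo` backend and runs as a single CPU process (`world_size=1`) so it can be executed anywhere; a real multi-GPU or multi-machine job would start one process per GPU with a launcher such as `torchrun`, which sets `MASTER_ADDR`, `MASTER_PORT`, `RANK` and `WORLD_SIZE`. The `free_port` helper is a hypothetical convenience.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def free_port():
    # Ask the OS for an unused TCP port for the rendezvous store.
    with socket.socket() as s:
        s.bind(('127.0.0.1', 0))
        return s.getsockname()[1]


# A launcher would normally set these; one local process stands in for a job.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', str(free_port()))
dist.init_process_group(backend='gloo', rank=0, world_size=1)

model = DDP(nn.Linear(8, 2))  # on GPUs: DDP(model.cuda(rank), device_ids=[rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(16, 8), torch.randn(16, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()   # DDP all-reduces (averages) gradients across processes here
optimizer.step()

print(dist.get_world_size(), loss.item())
dist.destroy_process_group()
```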