学习Pytorch
Pytorch框架学习笔记
rosefunR
每次都多付出一点. 欢迎关注公众号《机器学习与算法之道》
展开
-
pytorch版本查看和升级到相应的版本
1. 版本查看import torchprint(torch.__version__)结果:'1.0.0'2. 升级到特定的版本$ pip install --upgrade torch==1.4.0结果:Installing collected packages: torch Found existing installation: torch 1.0.0 Uninstalling torch-1.0.0: Successfully uninstalled t原创 2020-10-22 16:39:38 · 18718 阅读 · 0 评论 -
Pytorch数据读取加速方法
1. 方法一:使用prefetcherclass data_prefetcher(): def __init__(self, loader): self.loader = iter(loader) self.stream = torch.cuda.Stream() self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)原创 2020-08-20 10:38:23 · 22303 阅读 · 4 评论 -
Pytorch 模型查看参数
def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)print("模型的参数量", count_parameters(model))print("model",model)from torchsummary import summaryprint(summary(model, (input_data.shape[1])))...原创 2020-08-18 17:34:40 · 1534 阅读 · 0 评论 -
Pytorch cannot allocate memory 解决内存不足问题
del X_train, y_train, X_valid, y_validgc.collect()torch.cuda.empty_cache()# torch.cuda.clear_memory_allocated() # entirely clear all allocated memory原创 2020-08-08 13:20:45 · 5792 阅读 · 0 评论 -
Pytorch clamp修剪数据
torch.clamp(input, min, max, out=None) → TensorClamp all elements in input into the range [ min, max ] and return a resulting tensor:参考:torch clamp原创 2020-01-02 15:17:24 · 626 阅读 · 0 评论 -
Pytorch probability distributions
1.OneHotCategoricaltorch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)根据给定的概率probs, 创建一个 one-hot 的类别分布.m = OneHotCategorical(torch.tensor([ 0.1, 0...原创 2020-01-02 15:04:38 · 984 阅读 · 0 评论 -
Pytorch argmax
1.argmaxtorch.argmax(input, dim, keepdim=False) → LongTensor返回指定维度的最大值的索引。2.示例b = torch.randn(4, 5)torch.argmax(b, dim=0)Output:# btensor([[-2.2276, 0.9573, -1.9554, 0.8877, 0.7543], ...原创 2020-01-02 14:47:41 · 6350 阅读 · 0 评论 -
Pytorch stop_gradient
1.var.detach()detach():返回一个新的Tensor,并且不需要计算梯度。Returns a new Tensor, detached from the current graph.这里的 var 是指一个tensor。2. var.requires_grad_(requires_grad=False)指定tensor是否需要进行梯度更新。这里的var是指一个te...原创 2020-01-02 14:30:16 · 5200 阅读 · 0 评论 -
Pytorch 实现tf.gather()
1. 实现tf.gather在pytorch中,实现 tf.gather 很简单,只需要使用 select。select(dim, index) → Tensor比如,import numpy as npa = np.array([[1],[2],[3],[4],[5]])b = torch.from_numpy(a)indices = [ 1, 2, 0]b[indices]...原创 2020-01-02 11:39:13 · 2607 阅读 · 2 评论 -
pytorch模型训练
1.模型训练for epoch in range(2): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): # get the inputs; data is a list of [inputs, labels]...原创 2020-01-01 22:32:42 · 282 阅读 · 0 评论 -
numpy实现两层神经网络
1 理论损失函数的反向传播就是层层链式求导,得到每个矩阵或者向量的增量,然后更新。2 实践import numpy as npx = np.array([[0,0,1], [0,1,1],[1,0,1],[1,1,1]])y = np.array([[0,1,1,0]]).Tw0 = 2*np.random.random((3,4))-1w1 = 2*np.random.rando...原创 2019-03-31 15:37:45 · 9553 阅读 · 0 评论 -
浅谈Attention UNet
1 理论其中,g就是解码部分的矩阵,xl是编码(左边)的矩阵,x经过乘于系数(完成Attention)和g一起concat,进入下一层解码。数学公式:2 实践Pytorch Attention Unet:class Attention_block(nn.Module): def __init__(self,F_g,F_l,F_int): super(Atte...原创 2019-03-28 14:33:30 · 23003 阅读 · 16 评论 -
精确快速目标检测接收场块网络(RFBNet)
1.模型使用11,33,55卷积以及33(rate = 1,3,5)的空洞卷积核。各种模型对比图:RFB结构:RFB-Net300:2.实现class BasicRFB(nn.Module): def __init__(self, in_planes, out_planes, stride=1, scale = 0.1,map_reduce=8): s...原创 2019-03-18 19:55:55 · 9570 阅读 · 0 评论 -
Pytorch保存模型saveModel及使用tensorboardX
参考:1 save and load models Kaggle原创 2019-03-11 11:46:27 · 1959 阅读 · 0 评论 -
Pytorch 的损失函数Loss function
0 损失函数损失函数,又叫目标函数,是编译一个神经网络模型必须的两个参数之一。另一个必不可少的参数是优化器。损失函数是指用于计算标签值和预测值之间差异的函数,在机器学习过程中,有多种损失函数可供选择,典型的有距离向量,绝对值向量等。我们先定义两个二维数组,然后用不同的损失函数计算其损失值。import torchfrom torch.autograd import Variableim...原创 2019-03-01 15:22:05 · 28369 阅读 · 3 评论 -
更高效率读取图片的方法
1.方法一用list 来存储,最后转换为array。for fpath in fpaths: img = cv2.imread(fpath, cv2.CV_LOAD_IMAGE_COLOR) image_list.append(img)data = np.vstack(image_list)data = data.reshape(-1, 256, 256, 3)data = d...原创 2019-02-27 13:48:30 · 1173 阅读 · 0 评论 -
Pytorch读取数据Dataset,DataLoader及流式读取文件
简介最近都是看图像里边的语义分割部分内容,比较有趣,同时入门Pytorch。Pytorch的主要特点是基本上所有操作都是用类来进行封装,本身自带很多类,而且你也可以根据官方的类进行修改。1 数据导入数据导入,本来Pytorch就有好几个类进行实现,分别是 DataSet, DataLoader, DataLoaderIter等。以下是我用的一种方法。首先我的数据是存在data_dir里边...原创 2019-02-26 20:27:52 · 8001 阅读 · 0 评论 -
Pytorch入门及常用函数查询
背景适合现在pytorch比较流行,趁有空学学。首先,,pytorch是介入TensorFlow和Keras之间,比较灵活,有不那么难懂。1.常见问题原创 2019-01-26 15:52:38 · 2037 阅读 · 0 评论