Pytorch:盲点

本文详细介绍了PyTorch使用过程中的一些关键点,包括conda安装PyTorch-GPU,张量对比与转换,数据处理,nn.Embedding,CrossEntropyLoss的用法及其在不同情况下的形状要求,以及mask的使用。还讨论了torch和numpy中repeat函数的区别,多维张量操作,以及DataLoader和TensorDataset的使用技巧。这些内容对于理解和提升PyTorch编程能力至关重要。
摘要由CSDN通过智能技术生成

1. 用conda安装pytorch-gpu时,用这个命令就够了,网上其他人说的都不好使

conda install pytorch cuda92

注意得是清华源的

2. 比较两个行向量或者列向量,以期求得布尔数组时,必须要保证两边的数据类型一样,并且返回的布尔数组类型和比较的两个向量结构保持一致。另外,所有torch.返回的东西,如果要取得里面的值,必须要加.item()

# !user/bin/python
# -*- coding: UTF-8 -*-

import torch

a = torch.arange(16).view(4, 4)
b = torch.argmax(a, dim = 1)
print([round(x.item(), 5) for x in b])

z = torch.tensor([3, 1, 2, 5], dtype = torch.long) # 类型必须保持一致
z = z.view(-1, 1)
b = b.view(-1, 1)
print(b)
print(z)
print(b == z)
# tensor([[ True],
#         [False],
#         [False],
#         [False]])
print(torch.sum(b == z)) # tensor(1)

3. numpy转tensor,其中,ndarray必须是等长的

x = np.array([[1, 2, 3], [4, 5, 6]]) # 正确
# x = np.array([[1, 2, 3], [4, 5]]) # 错误
print(torch.from_numpy(x))

4. unsqueeze (不改变原有数据)

import torch
import numpy as np

x = torch.tensor([[1, 2], [3, 4]])
print(x)
# tensor([[1, 2],
#         [3, 4]])

# 在第0维的地方插入一维
print(x.unsqueeze(0))
# tensor([[[1, 2],
#          [3, 4]]])
print(x.unsqueeze(0).shape) # torch.Size([1, 2, 2])
print(x.unsqueeze(1))
# tensor([[[1, 2]],

#         [[3, 4]]])
print(x.unsqueeze(1).shape) # torch.Size([2, 1, 2])

5. nn.embedding

# !user/bin/python
# -*- coding: UTF-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

# 看看torch中的torch.nn.embedding
# embedding接收两个参数
# 第一个是num_embeddings,它表示词库的大小,则所有词的下标从0 ~ num_embeddings-1
# 第二个是embedding_dim,表示词嵌入维度
# 词嵌入层有了上面这两个必须有的参数,就形成了类,这个类可以有输入和输出
# 输入的数据结构不限,但是数据结构里面每个单元的元素必须指的是下标,即要对应0 ~ num_embeddings-1
# 输出的数据结构和输入一样,只不过将下标换成对应的词嵌入
# 最开始的时候词嵌入的矩阵是随机初始化的,但是作为嵌入层,会不断的学习参数,所以最后训练完成的参数一定是学习完成的
# embedding层还可以接受一个可选参数padding_idx,这个参数指定的维度,但凡输入的时候有这个维度,输出一律填0

# 下面来看一下吧
embedding =
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值