torch学习笔记

关于tensor

  1. tensor.data是属性,tensor.detach()是方法,返回的数据与原数据是共享的,一者改变另一者也跟着改变;并且默认都是不跟踪梯度,但是tensor.data这样会有风险,因为可能在某个地方数据改变了而我不知道,而detach()设置跟踪梯度之后可以规避风险;我们还可以使用类似NumPy的索引操作来访问Tensor的一部分,需要注意的是:索引出来的结果与原数据共享内存,也即修改一个,另一个会跟着修改
  2. tensor默认不追踪梯度
  3. tensor.clone()函数返回一个克隆,数据不共享但是梯度会叠加,除非设置梯度为false
  4. 要完全隔绝两个tensor,一般用torch.clone().detach()
  5. 用tensor.data是为了阻断梯度传播!!!注意不要改变值即可
  6. data.item()函数用于将tensor转换成一个具体的数字
  7. torch.clamp()与torch.clamp_(), pytorch中,一般来说如果对tensor的一个函数后加上了下划线,则表明这是一个in-place类型。in-place类型是指,当在一个tensor上操作了之后,是直接修改了这个tensor,而不是返回一个新的tensor并不修改旧的tensor
a.requires_grad_(True) # 使用in-place的方式改变tensor是否计算梯度
a.grad # tensor的这个属性保存了当前的梯度,使用a.grad.data.zero_()进行梯度的清空
  1. torch.view()改变了tensor的观察角度,数据并未拷贝,改变之后仍然会变;但是二者的内存地址不一致
y=x.view(-1,5)
x_cp=x.clone().view(15)
  1. tensor与numpy互换
# 1. tensor-->numpy 共享内存 
a=torch.ones(5)
b=a.numpy()

# 2. numpy-->tensor 共享内存
a=np.ones(5)
b=torch.from_numpy(a)

# 3. numpy-->tensor 不共享内存
c=torch.tensor(a) # 直接进行数据拷贝,这里是小写的tensor,dtype取决于传入数据类型

# 4. numpy-->tensor 不共享内存
d=torch.Tensor(a) # 直接进行数据拷贝,区别在于dtype统一格式为float32
  1. tensor 数据初始化
from torch.nn import init

init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0)  # 也可以直接修改bias的data: net[0].bias.data.fill_(0)
  1. 优化器的参数
optimizer =optim.SGD([
                # 如果对某个参数不指定学习率,就使用最外层的默认学习率
                {'params': net.subnet1.parameters()}, # lr=0.03
                {'params': net.subnet2.parameters(), 'lr': 0.01}
            ], lr=0.03)
# 调整学习率
for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.1 # 学习率为之前的0.1倍
  1. 图片数据格式
    普通图片的数据格式是 H ∗ W ∗ C H*W*C HWC,torch支持的图片数据格式是 C ∗ H ∗ W C*H*W CHW,所以一般数据tensor为 ( b a t c h s i z e , C , H , W ) (batchsize,C,H,W) (batchsize,C,H,W),使用如下
pos=torchvision.transforms.toTensor(pos)
//使用torch.utils.data.DataLoader,可以多线程读取数据,因为数据读取可能是瓶颈
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=is_shuffle, num_workers=num_worker, pin_memory=True)
//其中dataset是一个对象,实现def __getitem__(self,index)、def __len__(self)方法,可迭代对象,相加相当于list扩展,getitem函数返回的是个列表,如return img,pos
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值