机器学习变量--函数使用

本文介绍了PyTorch库中的三个关键函数:torch.gather用于根据索引从张量中选择元素,topK返回最大或最小值及其索引,mean计算张量的平均值。通过示例展示了如何在实际场景中使用这些功能。
摘要由CSDN通过智能技术生成

torch.gather()

import torch

a = torch.tensor([
        [ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])

b = torch.tensor(
[[1, 0, 0, 0, 0],
 [0, 0, 1, 0, 0],
 [0, 0, 0, 0, 0]])

c = a.gather(0, b) # dim=0
d = a.gather(1, b) # dim=1
print(c)
print(d)
'''
tensor([[5, 1, 2, 3, 4],
        [0, 1, 7, 3, 4],
        [0, 1, 2, 3, 4]])
tensor([[ 1,  0,  0,  0,  0],
        [ 5,  5,  6,  5,  5],
        [10, 10, 10, 10, 10]])
'''

在这里插入图片描述

topK()

import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader

####################准备一个数组#########################
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
             [3,4,5,1,1,1,1,1,1,1,1],
             [7,8,9,1,1,1,1,1,1,1,1],
             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)

print(torch.topk(tensor1,k=3,dim=1,largest=True))
print('-'*40)
print(torch.topk(tensor1,k=3,dim=0,largest=True))
'''
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[10,  0,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  7]]))
----------------------------------------
torch.return_types.topk(
values=tensor([[10.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 7.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]]),
indices=tensor([[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0],
        [2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 3]]))
'''

mean()

import torch
x = torch.tensor([
    [[1,2,3,4],[5,6,7,8],[9,10,11,12]],
    [[13,14,15,16],[17,18,19,20],[21,22,23,24]]
]).float() # [[[]], [[]]] torch.size=(2,3,4)

print(x.size())

print(x.mean(dim=0,keepdim=True))   #(1,3,4)
print(x.mean(dim=1,keepdim=True))  #(2,1,4)
print(x.mean(dim=2,keepdim=True)) # (2,3,1)最后一个维度
'''
torch.Size([2, 3, 4])
tensor([[[ 7.,  8.,  9., 10.],
         [11., 12., 13., 14.],
         [15., 16., 17., 18.]]])
tensor([[[ 5.,  6.,  7.,  8.]],

        [[17., 18., 19., 20.]]])
tensor([[[ 2.5000],
         [ 6.5000],
         [10.5000]],

        [[14.5000],
         [18.5000],
         [22.5000]]])
'''
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值