Pytorch框架之tensor.equal和tensor.eq--判断相等

本文介绍了PyTorch中Tensor的eq()和equal()函数的使用,通过实例展示了它们在行人重识别Relation-Aware Global Attention模型中的应用。eq()函数用于逐元素比较张量是否相等,返回一个布尔张量,而equal()函数则判断整个张量是否相等。此外,还探讨了这两个方法在比较张量时的区别和用途。
摘要由CSDN通过智能技术生成

行人重识别Relation-Aware Global Attention的部分代码。

	N = dist_mat.size(0)  # 8

	is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
	#print(is_pos.shape)
	is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
	#print(is_neg.shape)

一、测试eq()函数:

import torch

lab = torch.tensor([389, 389, 389, 389, 628, 628, 628, 628])
print(lab.expand(8, 8))
print("****")
print(lab.expand(8, 8).t())
print("****")
print(lab.expand(8, 8).eq(lab.expand(8, 8).t()))

输出结果:

tensor([[389, 389, 389, 389, 628, 628, 628, 628],
        [389, 389, 389, 389, 628, 628, 628, 628],
        [389, 389, 389, 389, 628, 628, 628, 628],
        [389, 389, 389, 389, 628, 628, 628, 628],
        [389, 389, 389, 389, 628, 628, 628, 628],
        [389, 389, 389, 389, 628, 628, 628, 628],
        [389, 389, 389, 389, 628, 628, 628, 628],
        [389, 389, 389, 389, 628, 628, 628, 628]])
****
tensor([[389, 389, 389, 389, 389, 389, 389, 389],
        [389, 389, 389, 389, 389, 389, 389, 389],
        [389, 389, 389, 389, 389, 389, 389, 389],
        [389, 389, 389, 389, 389, 389, 389, 389],
        [628, 628, 628, 628, 628, 628, 628, 628],
        [628, 628, 628, 628, 628, 628, 628, 628],
        [628, 628, 628, 628, 628, 628, 628, 628],
        [628, 628, 628, 628, 628, 628, 628, 628]])
****
tensor([[ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True]])

二、扩展:

扩展内容来自一篇博客:https://blog.csdn.net/weixin_44604887/article/details/109385651.

2.1、equal – 张量比较

  • 原型:equal(other)
  • 比较两个张量是否相等–相等返回:True; 否则返回:False
'''
		tensor调用equal方法与torch.equal是一致的
		都是比较两个张量是否相等
'''
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 1, 3])
print(x.equal(y))
print(x.equal(y) == torch.equal(x, y))

在这里插入图片描述

2.2、eq – 逐元素判断

  • 原型:eq(other)
  • 比较两个张量tensor中,每一个对应位置上元素是否相等–对应位置相等,就返回一个True;否则返回一个False.
'''
		逐元素进行判断是否相等
'''
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 1, 3])
print(x.eq(y))

在这里插入图片描述

2.3、eq_ – 将判断结果返回并替换原tensor

  • 原型:eq_(other)
  • 等价于tensor = tensor.eq(other); 即:将比较后的结果替换原张量的值
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 1, 3])
print(x.eq_(y))
print(x)

在这里插入图片描述

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
由于没有提供具体的tensorflow1代码,下面是一个简单的示例,将其转换为使用pytorch框架的代码: Tensorflow1代码: ``` import tensorflow as tf # 定义输入和输出 x = tf.placeholder(tf.float32, shape=[None, 784]) y_ = tf.placeholder(tf.float32, shape=[None, 10]) # 定义模型 W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W) + b) # 定义损失函数和优化器 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 训练模型 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # 测试模型 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) ``` Pytorch代码: ``` import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms # 定义模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(784, 10) def forward(self, x): x = self.fc1(x) return nn.functional.softmax(x, dim=1) model = Net() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.5) # 加载数据 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False) # 训练模型 for epoch in range(10): for batch_idx, (data, target) in enumerate(train_loader): data = data.view(-1, 784) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() # 测试模型 correct = 0 total = 0 with torch.no_grad(): for data, target in test_loader: data = data.view(-1, 784) output = model(data) _, predicted = torch.max(output.data, 1) total += target.size(0) correct += (predicted == target).sum().item() print('Accuracy: %f' % (correct/total)) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值