torch.multinomial()理解

本文详细解析了PyTorch中的multinomial函数,介绍了如何使用此函数进行有放回或无放回的采样,并通过实例展示了不同参数设置下的采样结果。
部署运行你感兴趣的模型镜像

torch.multinomial(input, num_samples,replacement=False, out=None) → LongTensor

作用是对input的每一行做n_samples次取值,输出的张量是每一次取值时input张量对应行的下标。

输入是一个input张量,一个取样数量,和一个布尔值replacement。

input张量可以看成一个权重张量,每一个元素代表其在该行中的权重。如果有元素为0,那么在其他不为0的元素

被取干净之前,这个元素是不会被取到的。

n_samples是每一行的取值次数,该值不能大于每一样的元素数,否则会报错。

replacement指的是取样时是否是有放回的取样,True是有放回,False无放回。

看官方给的例子:
>>> weights = torch.Tensor([0, 10, 3, 0]) # create a Tensor of weights
>>> torch.multinomial(weights, 4)

 1
 2
 0
 0
[torch.LongTensor of size 4]

>>> torch.multinomial(weights, 4, replacement=True)

 1
 2
 1
 2
[torch.LongTensor of size 4]

输入是[0,10,3,0],也就是说第0个元素和第3个元素权重都是0,在其他元素被取完之前是不会被取到的。

所以第一个multinomial取4次,可以试试重复运行这条命令,发现只会有2种结果:[1 2 0 0]以及[2 1 0 0],以[1 2 0 0]这种情况居多。这其实很好理解,第1个元素权重比第2个元素权重要大,所以先取第1个元素的概率就会大。在第1和2个元素取完之后,剩下了2个没有权重的元素,它们才会被取到。但实际上权重为0的元素被取到时也不会显示正确的下标,关于0的下标问题我还没有想到很合理的解释,先行略过。

而第二个multinomial取4次,发现就只会出现1和2这两个元素了。这是因为replacement为真,所以有放回,就永远也不会取到权重为0的元素了。

再试试输入二维张量,则返回的也会成为一个二维张量,行数为输入的行数,列数为n_samples,即每一行都取了n_samples次,取法和一维张量相同。

您可能感兴趣的与本文相关的镜像

PyTorch 2.9

PyTorch 2.9

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True): input_eval = torch.Tensor([char2idx[char] for char in start_string]).to(dtype=torch.int64, device=device).reshape(1, -1) #bacth_size=1, seq_len长度是多少都可以 (1,5) hidden = None text_generated = [] #用来保存生成的文本 model.eval() pbar = tqdm(range(max_len)) # 进度条 print(start_string, end="") # no_grad是一个上下文管理器,用于指定在其中的代码块中不需要计算梯度。在这个区域内,不会记录梯度信息,用于在生成文本时不影响模型权重。 with torch.no_grad(): for i in pbar:#控制进度条 logits, hidden = model(input_eval, hidden=hidden) # 温度采样,较高的温度会增加预测结果的多样性,较低的温度则更加保守。 #取-1的目的是只要最后,拼到原有的输入上 logits = logits[0, -1, :] / temperature #logits变为1维的 # using multinomial to sampling probs = F.softmax(logits, dim=-1) #算为概率分布 idx = torch.multinomial(probs, 1).item() #从概率分布中抽取一个样本,取概率较大的那些 input_eval = torch.Tensor([idx]).to(dtype=torch.int64, device=device).reshape(1, -1) #把idx转为tensor text_generated.append(idx) if stream: print(idx2char[idx], end="", flush=True) return "".join([idx2char[i] for i in text_generated]) # load checkpoints model.load_state_dict(torch.load("checkpoints/text_generation/best.ckpt", weights_only=True,map_location="cpu")) start_string = "All: " #这里就是开头,什么都可以 res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)这段代码有什么用
03-11
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值