[Experiment Details] Understanding the Functions Used to Implement Nucleus Sampling
For background, see: [Paper Reading] THE CURIOUS CASE OF NEURAL TEXT DEGENERATION
torch.cumsum
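Computes the cumulative (running) sum of a tensor's elements along dimension dim: each output element is the sum of all input elements up to and including that position.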
import torch
x = torch.linspace(0, 5, 6)
print(x)
y = torch.cumsum(x, dim=0)
print(y)
import torch
x = torch.linspace(0, 5, 6).view(2, 3)
print(x)
y = torch.cumsum(x, dim=0)
print(y)
z = torch.cumsum(x, dim=1)
print(z)
When p = 3:
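In nucleus sampling, torch.cumsum is applied to token probabilities along the vocabulary dimension, just like the dim=1 case above, to get the running probability mass. A minimal sketch, assuming a made-up batch of 2 logit rows over a 4-token vocabulary that are already sorted in descending order (the values are illustrative only):

```python
import torch

# Hypothetical next-token logits: batch of 2, vocabulary of 4, already sorted descending
logits = torch.tensor([[3.0, 1.0, 0.5, 0.1],
                       [1.0, 0.9, 0.8, 0.7]])

# Turn each row into a probability distribution
probs = torch.softmax(logits, dim=-1)

# Running probability mass along the vocabulary dimension (dim=-1, i.e. dim=1 here)
cumulative_probs = torch.cumsum(probs, dim=-1)
print(cumulative_probs)  # each row increases and ends at (approximately) 1.0
```

Because the rows are sorted, the position where a row's cumulative value first exceeds a threshold p marks the edge of the "nucleus".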
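torch.sort
Sorts a tensor along a dimension (the last one by default) and returns two tensors: the sorted values and the indices those values had in the original tensor.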
import torch
x = torch.linspace(0, 5, 6)
print(x)
# Returns the values in descending order plus their original positions in x
sorted_x, sorted_indices = torch.sort(x, descending=True)
print(sorted_x, sorted_indices)
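The indices returned by torch.sort matter in nucleus sampling: filtering happens in sorted order, and the result then has to be mapped back to the original token positions. A minimal sketch with made-up values, using indexing to go from original to sorted order and Tensor.scatter_ to go back:

```python
import torch

x = torch.tensor([0.1, 0.5, 0.15, 0.25])

# sorted_indices[i] is the position in x of the i-th largest value
sorted_x, sorted_indices = torch.sort(x, descending=True)
print(sorted_x)        # tensor([0.5000, 0.2500, 0.1500, 0.1000])
print(sorted_indices)  # tensor([1, 3, 2, 0])

# Indexing x with sorted_indices reproduces the sorted values
assert torch.equal(x[sorted_indices], sorted_x)

# Scattering the sorted values to sorted_indices restores the original order:
# restored[sorted_indices[i]] = sorted_x[i]
restored = torch.empty_like(x).scatter_(0, sorted_indices, sorted_x)
print(restored)        # tensor([0.1000, 0.5000, 0.1500, 0.2500])
```

The same scatter step works for a boolean mask computed in sorted order, which is how a "tokens to remove" mask gets back to vocabulary order.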
.clone()
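import torch
x = torch.tensor([2, 2, 5, 3, 1, 4], dtype=torch.float64)
# Shift every element one position to the right; the source slice is cloned first
x[..., 1:] = x[..., :-1].clone()
print(x)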
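When a single operation reads from and writes to overlapping parts of the same tensor, as in the shift above, clone the source slice first so that it lives in separate memory. Without .clone(), the two slices share storage; depending on the PyTorch version, the copy may read values that have already been overwritten, or it may be rejected with a RuntimeError asking you to clone() the tensor first.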
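In nucleus sampling, these three functions typically work together. Below is a minimal sketch of a top-p filter, assuming a logits tensor whose last dimension is the vocabulary; `top_p_filter` and `filter_value` are hypothetical names for illustration, not from the original post. Tokens are sorted by probability, the cumulative mass is computed with torch.cumsum, and the removal mask is shifted right by one (with .clone(), since source and destination overlap) so that the token which first pushes the mass past top_p is still kept.

```python
import torch
import torch.nn.functional as F


def top_p_filter(logits: torch.Tensor, top_p: float,
                 filter_value: float = -float("inf")) -> torch.Tensor:
    """Hypothetical helper: mask logits outside the nucleus (sketch, last dim = vocab)."""
    # 1. Sort from most to least likely, remembering the original positions
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    # 2. Cumulative probability mass over the sorted tokens
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # 3. Mark tokens past the threshold, then shift right by one so the
    #    crossing token is kept; the slices overlap, hence .clone()
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False  # always keep the most likely token
    # 4. Map the mask from sorted order back to vocabulary order
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, filter_value)


torch.manual_seed(0)
# Made-up distribution over a 5-token vocabulary, expressed as logits
logits = torch.log(torch.tensor([[0.05, 0.5, 0.15, 0.25, 0.05]]))
filtered = top_p_filter(logits, top_p=0.8)
probs = F.softmax(filtered, dim=-1)               # masked tokens get probability 0
next_token = torch.multinomial(probs, num_samples=1)
print(filtered)
print(next_token)
```

With these made-up probabilities, the cumulative mass in sorted order is 0.5, 0.75, 0.9, 0.95, 1.0. With top_p=0.8 the unshifted mask would already drop the 0.15 token; after the shift it is kept, and only the two 0.05 tokens (positions 0 and 4) are masked, so sampling picks among positions 1, 2 and 3.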