Pytorch使用总结
设备的确定
os.environ[‘环境变量名称’]=‘环境变量值’
if cuda_used:
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
device = torch.device('cuda')
else:
device = torch.device('cpu')
np.log10()
取以10为底的对数
Python的魔法方法__getitem__
Python的魔法方法__getitem__ 可以让对象实现迭代功能,这样就可以使用for…in… 来迭代该对象了。
class Animal:
def __init__(self, animal_list):
self.animals_name = animal_list
def __getitem__(self, index):
return self.animals_name[index]
animals = Animal(["dog","cat","fish"])
for animal in animals:
print(animal)
dog
cat
fish
DataLoader()
pytorch对数据的传入经常使用DataLoader类,通过给DataLoader传入迭代器和batch_size的大小来实现。实现迭代器可以在数据类中实现__getitem__魔术方法,使其成为一个迭代器。
joblib.dump()与joblib.load()
非常好使,加载模型+保存模型
Counter()
Counter 是 dict 字典的子类,Counter 拥有类似字典的 key 键和 value 值,只不过 Counter 中的键为待计数的元素,而 value 值为对应元素出现的次数 count,为了方便介绍统一使用元素和 count 计数来表示。虽然 Counter 中的 count 表示的是计数,但是 Counter 允许 count 的值为 0 或者负值。
elements()方法
elements()方法返回一个迭代器,可以通过 list 或者其它方法将迭代器中的元素输出,输出的结果为对应出现次数的元素。
在torch中实现sequence_mask()功能
我们知道在tensorflow中有实现的函数sequence_mask,这个函数在nlp任务中经常用到,但是在pytorch中没有这个函数,这个函数可以按如下方式实现:
def sequence_mask(lengths, max_len=None):
lengths_shape = lengths.shape # torch.size() is a tuple
lengths = lengths.reshape(-1)
batch_size = lengths.numel()
max_len = max_len or int(lengths.max())
lengths_shape += (max_len,)
return (torch.arange(0, max_len, device=lengths.device)
.type_as(lengths)
.unsqueeze(0).expand(batch_size, max_len)
.lt(lengths.unsqueeze(1))).reshape(lengths_shape)