Pytorch使用总结

设备的确定

os.environ[‘环境变量名称’]=‘环境变量值’

if cuda_used:
    os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

np.log10()

取以10为底的对数

Python的魔法方法__getitem__

Python的魔法方法__getitem__ 可以让对象实现迭代功能,这样就可以使用for…in… 来迭代该对象了。

class Animal:
    def __init__(self, animal_list):
        self.animals_name = animal_list

    def __getitem__(self, index):
        return self.animals_name[index]

animals = Animal(["dog","cat","fish"])
for animal in animals:
    print(animal)
dog
cat
fish

DataLoader()

pytorch对数据的传入经常使用DataLoader类,通过给DataLoader传入迭代器和batch_size的大小来实现。实现迭代器可以在数据类中实现__getitem__魔术方法,使其成为一个迭代器。

joblib.dump()与joblib.load()

非常好使,加载模型+保存模型

Counter()

Counter 是 dict 字典的子类,Counter 拥有类似字典的 key 键和 value 值,只不过 Counter 中的键为待计数的元素,而 value 值为对应元素出现的次数 count,为了方便介绍统一使用元素和 count 计数来表示。虽然 Counter 中的 count 表示的是计数,但是 Counter 允许 count 的值为 0 或者负值。

elements()方法

elements()方法返回一个迭代器,可以通过 list 或者其它方法将迭代器中的元素输出,输出的结果为对应出现次数的元素。

在torch中实现sequence_mask()功能

我们知道在tensorflow中有实现的函数sequence_mask,这个函数在nlp任务中经常用到,但是在pytorch中没有这个函数,这个函数可以按如下方式实现:

def sequence_mask(lengths, max_len=None):
    lengths_shape = lengths.shape  # torch.size() is a tuple
    lengths = lengths.reshape(-1)

    batch_size = lengths.numel()
    max_len = max_len or int(lengths.max())
    lengths_shape += (max_len,)

    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .unsqueeze(0).expand(batch_size, max_len)
            .lt(lengths.unsqueeze(1))).reshape(lengths_shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值