pytorch学习

torch数据类型转换

torch.from_numpy

作用:
获取一个大小与numpy数组相同的张量
用法:

npa = numpy.array([1, 0, 1])
torch_tensor = torch.from_numpy(npa)

torch变量变换数据类型

类型转换示例,tensor为torch张量。

tensor = tensor.int()
tensor = tensor.to(torch.float32)

torch.tensor操作

torch.where()

利用torch.where可根据条件对tensor中的数据进行赋值,如下面语句就是以0.5为阈值,将tensor中大于0.5的值赋值为1,小于等于0.5的赋值为0

tensor = torch.where(tensor > 0.5, 1, 0)

也可以是使用两个tensor对其进行赋值,下面语句中的a和b皆为tensor,其语句的意思为,如果a的元素大于0,则result对应位置的值用a的值代替,否则用b的值代替。

result = torch.where(a > 0, a, b)

torch.sum()

对tensor中的所有数据进行求和

sum = torch.sum(tensor)
sum = tensor.sum()

提取指定维度数据:

提取指定一列数据,i为需要提取的列数据

temp_data = data[:, i]

torch.utils.data:

TensorDataset

TensorDataset实现对tensor的打包,通过tensor的第一维度进行索引,因此输入的两个tensor的第一维度必须相等,我一般用来他将data和label组装成dataset,方便后续按batch_size进行划分用于模型训练。

dataset = TensorDataset(data, label)

ConcatDataset合并数据集

new_dataset = ConcatDataset([dataset_a, dataset_b])

random_split()

封装dataset后,使用下述语句对dataset进行划分。

train_dataset, test_dataset = random_split(
        dataset=dataset,
        lengths=[len_train, len(dataset) - len_train],
        generator=torch.Generator().manual_seed(0)
    )

dataset为输入的已用TesorDataset封装好的dataset,lengths控制训练集和测试集的长度,固定genreator有助于复现。一般建议先使用random_split划分为train_dataset, test_dataset,再对其使用Dataloader进行batch_size的划分,以避免出来先划分Dataloader后训练时出现无法遍历train_dataloader的问题。

DataLoader

常用形式:

data_loader = DataLoader(dataset=data, batch_szie=batch_size, shuffle=True, drop_last=True)

上述代码表示使用data作为划分的数据集,batch_size大小设为bacth_size,对数据进行打乱,其打乱方式为先打乱,后划分batch,drop_last表示是否样本数无法被batch_size整除时,是否将无法被整除的删除。

torch模型操作

torch保存&加载模型

模型保存和加载方法有以下两种:

 # 保存模型参数,速度快
 torch.save(model.state_dict(), path) 
 # 加载
 model = model.load_state_dict(torch.load(path))
 # 保存模型整个网络
 torch.save(net, path)
 # 加载
 model = torch.load(ptah)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

White Jiang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值