pytorch框架
pytorch框架相关的知识
Vincent123Mei
这个作者很懒,什么都没留下…
展开
-
批数据训练
import torchimport torch.utils.data as Data # 处理数据的工具包# hyper parameterBATAC_SIZE = 5x = torch.linspace(1, 10, 10)y = torch.linspace(10, 1, 10)# print(x.type()) # torch.FloatTensor# 先转换成 torch 能识别的 Datasettorch_dataset = Data.TensorDataset(x, y原创 2020-07-21 21:28:06 · 210 阅读 · 1 评论 -
模型的保存与提取
保存torch.save(net1, 'net1.pkl') # 保存整个网络torch.save(net1.state(), 'net_params.pkl') # 只保存网络中的参数提取提取整个神经网络def restore_net(): net2 = torch.load('net1.pkl') # 提取网络 prediction = net2(x)提取网络参数def restore_params(): # 新建 net3 要求与原网络保持相同的结构 n原创 2020-07-21 20:36:52 · 300 阅读 · 0 评论 -
Numpy和Torch相互转换及其计算上的区别
Torch和Numpy十分相似将一个普通变量转换为tensordata = [[1, 2], [3, 4]]tensor_data = torch.FloatTensor(data)将tensor转换为numpytensor2numpy = tensor_data.numpy()将numpy转换为tensortensor_data = torch.from_numpu(np.data)numpy 中矩阵乘法numpy.matmul(data, data) #矩阵乘法data.dot原创 2020-07-19 20:43:14 · 556 阅读 · 0 评论