PyTorch学习(一)

1. Tensor(张量)
pytorch中的张量与numpy的ndarray相对应,都表示一个多维的矩阵,张量与ndarray可以相互转化。
Tensor有五种数据类型:32位浮点型torch.FloatTensor、64位浮点型torch.DoubleTensor、16位整型torch.ShortTensor、32位整型torch.IntTensor、64位整型torch.LongTensor
可以通过以下方式定义张量:

a = torch.Tensor([[1,2],[2,3],[3,4]])#默认是32位浮点型

b = torch.LongTensor([[1,2],[2,3],[3,4]])#长整型

c = torch.zeros((3,2))#3*2的零矩阵

d = torch.randn((3,3))#3*3的随机矩阵

Tensor可以像numpy一样用索引找到元素,也可以进行转化。

a[0,1] = 100 #设置第一行第二列的值为100

numpy_b = b.numpy()#Tensor转化为numpy

e = numpy.array([[1,1],[2,2])
torch_e = torch.from_numpy(e)#numpy转化为Tensor
f_torch_e =  torch_e.float()#改变数据类别

2. Variable(变量)
变量是神经网络计算图中的一个概念,Variable提供了一个自动求导功能,变量与张量本质上没有什么不同,但变量会被放入一个计算图中,然后进行自动求导。
Variable在torch.autograd.Variable中,一个Tensor转化为Variable可以直接使用Variable()函数。
Variable有三个属性datagradgrad_fn
data可以将Variable中的tensor数值取出,grad_fn代表着得到此变量的操作,grad_fn是这个Variable的反向传播梯度。
Variable反向传播,获得梯度函数如下(ps:看了几遍没看懂,懂了再回来改吧)

x = torch.randn([1])
x = Variable(x, requires_gard=True)
x.backward()

3. Dataset(数据集)
数据集是任何机器学习都需要的,PyTorch中可以定义通过继承抽象类torch.utils.data.Dataset来实现自己的数据集。其中必须要定义__len__()和__getitem__()两个函数。

class myDataset(Dataset):
	def __init__(self, csv_file, txt_file, root_dir, other_file):
		self.csv_data = pd.read_csv(csv_file)
		with open(txt_file, 'r') as f:
			datalist = f.readline()
		self.txt_data = datalist
		self.root_dir = root_dir
	
	def __len__(self):#返回数据集的大小
		return len(self.csv_data)
		
	def __getitem__(self, idx):#根据下标返回数据
		data = (self.csv_data[]idx], self.txt_data[idx])
		return data

PyTorch还可以使用loader来进行多线程的读取数据,通过torch.utils.data.DataLoader来定义。

4. nn.Module(模组)
PyTorch中建立神经网络,所有的层结构与损失函数都来自于torch.nn,所有的模型都要继承这个基类nn.Module。

class net_name(nn.Module):
	def __init__(self, other_arguments):
		super(net_name,self).__init__()
		self.convl = nn.Conv2d(in_channels, out_channels, kernel_size)
	
	def forward(self, x ):
		x = self.convl(x)
		return x

5. 模型的保存和加载
保存模型的方式有两种:1.整个模型保存,2.仅仅保存模型的参数。

torch.save(model, './model.pth')#整个保存
torch.save(modle, './model_state.pth')#保存参数

模型的加载也是这两种。

load_model = torch.load('model.pth')

mode.load_state_dic(torch.load('model_state.pth'))
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值