神经网络-pytorch常用包介绍(一)
pytorch学习笔记:从数据处理、加载、模型建立、定义损失函数、定义反向传播方法、训练、保存模型、加载模型。
首先看一下常用的包:
from __future__ import print_function
import argparse
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
数据预处理
argparse 包
argparse 模块使编写用户友好的命令行界面变得更容易.程序只需定义好它要求的参数,然后argparse将负责如何从sys.argv中解析出这些参数。argparse模块还会自动生成帮助和使用信息并且当用户赋给程序非法的参数时产生错误信息。
主要由三个步骤:
- 创建 ArgumentParser() 对象
- 调用 add_argument() 方法添加参数
- 使用 parse_args() 解析添加的参数
import argparse
parser = argparse.ArgumentParser()#创建ArgumentParser() 对象
parser.add_argument('integer', type=int, help='display an integer')#调用add_argument()方法添加参数
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
args = parser.parse_args()使用 parse_args() #解析添加的参数
print(args.integer)
argparse添加参数主要是由类似于结构体组成。
parser.add_argument('integer', type=int, help='display an integer')
‘integer’ 相当于结构体名称,后面的type等为属性,包括默认值,帮助等。
argparse通过名称来使用它。
例如:
batch_size=args.batch_size
os包
python编程时,经常和文件、目录打交道,这是就离不了os模块。os模块包含普遍的操作系统功能,与具体的平台无关。
在读取文件和对文件进行分类时常用到。比如,读取神经网络输入数据,对数据分为训练集,验证集,测试集等。
1. n:25000(整个数据集的大小)
2. ratio = 0.2(验证集r的比重)
3. organize_datasets(path_to_data=’./train/’,n=n, ratio=ratio)
更详细的os包的使用请参见博客:python的os模块