Introduction
学习使用了深度学习Pytorch差不多一年半多了,今天在这里呢,来总结一些不怎么使用的细节方面的代码知识点,为以后方便查找和使用。
1、Pytorch模型调用相关函数
保存模型:
#保存整个模型,model为你的神经网络
torch.save(model,'save.pt')#.pt 或.pth形式
#只保存训练好的权重:
torch.save(model.state_dict(), 'save.pt')
加载模型使用:
model=cnn() #cnn()这里表示为你定义的神经网络模型
model.eval() #这里表示为测试使用
#第一种:加载整个模型
(命名变量) = torch.load("save.pt")
#第二种:加载整个模型参数
model.load_state_dict(torch.load(param_file)) # load模型
若当你的模型或者你调用他人的模型为在gpu上训练的,但你想cpu上载入时:
model.load_state_dict(t.load(params_file, map_location='cpu'))
2.Pytorch模型测试与单一数据使用
当你需要单独地一个数据来测试时,你可以使用下面的代码:
img0 = Image.open(img0_tuple[0])
img0 = img0.convert("L")#图像转换黑白
img0 = transform(img0) #这里需要跟你训练数据中的transform一样
img0= Variable(img0.unsqueeze(0)) #添加维度(eg:3->4)
output1 = model(Variable(img0) #Vaariable可求导变量
在这里,我需要强调以下为什么要添加维度,首先我们要知道由于在训练时,我们喂给神经网络的数据格式时这样的:(batch_size,in_channel,height,width),因此呢,我们单一加载一个数据时,也需要满足四个维度。
3、训练结果数据可视化
#首先定义两组数组
loss_history= [] #用来保存训练过程中的loss值
Accuracy_history= [] #用来保存训练时的准确率值
epoch_number = 0
loss_history.append(loss_contrastive.item())
Accuracy_history.append((accuray.item())
epoch_number += 1
#这里保存的数据值为每一次epoch时保存一次.
#注意:当结束完训练时,跳出训练循环时
show_plot(epoch_number,loss_history,Accuracy_history)
#show_plot为自己定义的函数
其中show_plot函数定义如下:
def show_plot(iteration, loss,Accuracy_history):
plt.plot(iteration, loss,color='red',abel="loss_history")
plt.plot(iteration,Accuracy_history,abel="Accuracy_history")
plt.ylabel('loss')
plt.xlabel('Number of training')
plt.show()
当然了,你要是想用上位机界面显示效果的话,可以采用PyQT,这里你可以参考我的另一篇博客
4、数据转换
numpy中的ndarray转化成pytorch中的tensor : torch.from_numpy()
pytorch中的tensor转化成numpy中的ndarray : numpy()
例如:
将Variable张量转化为numpy:
V_data = V_data.data.numpy()
import numpy as np
import torch
np_arr = np.array([1,2,3,4])
tensor_=torch.from_numpy(np_arr)
ten_numpy=tensor_.numpy()
这里需要注意的是,你定义的numpy数组或者tensor使用了Variable时,想要转换为tensor或者numpy时,例如:
#mydata为Variable(tensor)类型时,转换为numpy
mydata=mydata.data.numpy()
5、图像数据读取、类型与转换
5.1、cv2库:
import cv2
img=cv2.imread('xxx.jpg')#加载图片
cv2.imshow('src',img)#显示图片
#img。shape为获取图像numpy数据格式
#img.size为图像数据像素值个数
#img.dtype为numpy中数据类型
cv2读取图片为一个numpy矩阵,彩色图片维度是(高度,宽度,通道数),数据类型是uint8。若获取的为灰度图,读进来的灰度图的矩阵格式是(高度,宽度)。
cv2中图片矩阵变换:
img = img.transpose(2,0,1)
#若读取格式为(3,124,224,则上面的转换为(224,3,124)
5.2、PIL.Image.open
from PIL import Image
img = Image.open('xxx.jpg')
print(img.format)
print(img.size) #注意,省略了通道 (w,h)
print(img.mode) #L为灰度图,RGB为真彩色,RGBA为加了透明通道
img.show() # 显示图片
Image.open()读取图片返回的是一个对象(JPEG),特别的有些很有意思的执行语句:
import np
#复制图像
img = img.copy()
#灰度图像获取
gray = Image.open('xxx.jpg').convert('L')
#转化为矩阵形式
arr=np.array(img)
#矩阵再转为图像
new_im = Image.fromarray(arr)
#分离合并通道
r, g, b = img.split()
img = Image.merge("RGB", (b, g, r))
5.3、matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import numpy as np
image = plt.imread('xxx.jpg')
plt.imshow(image)
plt.show()
值得注意的是plt.imread读入的也是一个矩阵,但跟cv2不一样的是,彩图读进的是RGB,格式为 (h,w,c),还有就是存储图像时,注意,必须在show之前savefig,否则存储的图片一片空白
6、异常处理
当我们编写代码的时候,或者编写网络时候出现可以忽略的异常IOError,但我们又不想结束程序,这时候我们可以捕捉它,做异常处理
try:
#此处编写原来想要的程序代码块
except IOError:#(IOErro为一个异常,可以替换其他异常种类)
print('fail to do something!')
7、常见的激活函数
无论是在学习也好,找工作笔试也好,这些都是最基本的问题,下面跟着我的思路来走。
(1)question one: 什么是激活函数
(2)question two: 为什么要用激活函数
(3)question three: 常见的三大类激活函数是什么
8、学习率的下降方式
9、减少过拟合的方法
续
有待补充。。。。