02人脸表情识别
简介:使用VGG模型进行人脸表情识别
参考链接:深度学习100例 | 第2例:人脸表情识别 - PyTorch实现_pytorch 表情识别-CSDN博客
代码链接:02人脸表情识别 (github.com)
数据集链接:facial emotion recognition (kaggle.com)
1.数据预处理
下载数据集
直接用kaggle的api下载数据集
#copy api command
!kaggle datasets download -d chiragsoni/ferdata
#解压数据
!unzip ferdata.zip -d /content/ferdata
数据读取与数据预处理
train_data='/content/ferdata/train'
test_data='/content/ferdata/test'
import torch
from torchvision import transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
train_transforms=transforms.Compose([
transforms.Resize([48,48]),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485,456,0.406],
std=[0.229,0.224,0.225]
)
])
test_transforms=transforms.Compose([
transforms.Resize([48,48]),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485,456,0.406],
std=[0.229,0.224,0.225]
)
])
train_data=ImageFolder(train_data,transform=train_transforms)
test_data=ImageFolder(test_data,transform=test_transforms)
train_loader=DataLoader(train_data,batch_size=16,shuffle=True,num_workers=1)
test_loader=DataLoader(test_data,batch_size=16,shuffle=True,num_workers=1)
print("The number of images in a training set is: ", len(train_loader)*16)
print("The number of images in a test set is: ", len(test_loader)*16)
print("The number of batches per epoch is: ", len(train_loader))
classes = ('Angry', 'Fear', 'Happy', 'Surprise')
一般需要了解数据集的哪些信息?如何查看?
步骤:处理数据-读取数据-包装数据
- transforms.compose
- datasets.ImageFolder
- torch.utils.data.DataLoader
dataloader:数据加载器
作用:从数据集随机加载数据,拼接成一个batch,实现迭代器,可以让使用时,迭代获取数据内容
2.VGG-16模型
from torch.nn import functional as F
device="cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device!".format(device))
from torchvision import models
#直接调用官方封装好的vgg模型
model=models.vgg16(pretrained=True)
model
3.定义损失函数和优化器
from torch import nn
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001,weight_decay=0.0001)
4.定义训练、测试函数
def train(dataloader,model,loss_fn,optimizer):
model=model.to(device)
model.train()
for i,(images,labels) in enumerate(dataloader,0):
images=images.to(device)
labels=labels.to(device)
outputs=model(images)
loss=loss_fn(outputs,labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i%1000==0 :
print('[%5d] loss :%.3f' % (i,loss))
def test(dataloader,model,loss_fn):
size=len(dataloader.dataset)
num_batches=len(dataloader)
model.eval()
test_loss,correct=0,0
with torch.no_grad():
for X,y in dataloader:
X,y=X.to(device),y.to(device)
pred=model(X)
test_loss+=loss_fn(pred,y).item()
correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
test_loss/=num_batches
correct/=size
print(f"Test Error:\n Accuracy:{(100*correct):>0.1f}%,Avg loss:{test_loss:>8f} \n")
return correct,test_loss
如何理解for i,(images,labels) in enumerate(dataloader)?
for…in enumerate() 会调用DataLoader类的__iter__方法,将一个batch的数据以数组的形式返回,若总数居为6,batch=2,则要迭代3次,i=0,1,2
https://blog.csdn.net/qq_32938525/article/details/115588656
为什么是enumerate(dataloader,0)而不是enumerate(dataloader)?
==>enumerate(dataloader,start=0),默认从0开始
这里的images和labels的tensor类型为啥要转化为variable类型?
语法解释:
- 在Python中,
print(f" ")
是格式化字符串(f-string)的语法,它允许你在字符串中嵌入表达式,这些表达式在运行时会被其值所替换。f 或 F 前缀表示这是一个格式化字符串字面量。
在 f’’ 或 F’’ 中的大括号 {} 内,你可以放入任何有效的Python表达式。当 print 函数执行时,这些表达式会被求值,并且其结果会被插入到字符串的相应位置。
如:
# 基本用法
acc=0.97
print(f"accuracy:{acc}")
>>accuracy:0.97
#格式化数字(一般可以用于保留浮点数的小数点后几位。)
pi = 3.141592653589793
print(f"The value of pi is approximately {pi:.2f}.")
>>The value of pi is approximately 3.14.
# :.2f 是一个格式说明符,它告诉Python将浮点数 pi 格式化为带有两位小数的字符串。
{(100*correct):>0.1f}
:表示将100*correct
这个表达式的结果格式化为浮点数,并且保留一位小数。>
表示右对齐,0
表示用0填充空白位置。{test_loss:>8f}
:表示将test_loss
这个变量格式化为浮点数,并且保留8位字符的宽度,右对齐。
5.训练
test_acc_list = []
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_loader,model,loss_fn,optimizer)
test_acc,test_loss = test(test_loader,model,loss_fn)
test_acc_list.append(test_acc)
print("Done!")
#可视化
import numpy as np
import matplotlib.pyplot as plt
x = [i for i in range(1,11)]
plt.plot(x, test_acc_list, label="line ACC", alpha=0.8)
plt.xlabel("epoch")
plt.ylabel("acc")
plt.legend()
plt.show()
acc=0.89
loss=1.821639
print(f"Accuracy:{(acc*100):>.1f}%,Loss:{loss:>.3f}")