import torchvision
from torch import nn
import numpy as np
import os
import json
import pickle
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets
import torchvision.models as models
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
# --- Training hyper-parameters ---
epochs = 10
batch_size = 32
lr = 0.03  # NOTE(review): the Adam optimizer below is built with lr=0.01, not this value

# --- File-system locations (absolute Windows paths on the author's machine) ---
image_path = 'D:/datasets'
save_path = 'D:/pycharm_my_data/best_model.pkl'

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
device = (torch.device('cuda:0')
          if torch.cuda.is_available()
          else torch.device('cpu'))
# Training-time preprocessing: random-resized 224x224 crop, random horizontal
# flip, tensor conversion, then per-channel normalisation with the standard
# ImageNet mean/std.
data_transform = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# ImageFolder labels every image by the sub-directory of image_path it lives in.
train_dataset = datasets.ImageFolder(root=image_path,
                                     transform=data_transform['train'])
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
print('using {} images for training.'.format(len(train_dataset)))

# class_to_idx maps folder name -> integer label; invert it so a predicted
# index can be turned back into a name, and persist it for the inference step.
cloth_list = train_dataset.class_to_idx
class_dict = {idx: name for name, idx in cloth_list.items()}
with open('class_dict.pk', 'wb') as f:
    pickle.dump(class_dict, f)
class MyLoss(nn.Module):
    """Mean cross-entropy between raw logits and integer class labels.

    Numerically it matches ``F.cross_entropy(pred, label)`` with the default
    ``'mean'`` reduction; the computation is spelled out explicitly here.
    """

    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, pred, label):
        """Return the batch-mean negative log-probability of the true classes.

        Args:
            pred: 2-D tensor of unnormalised logits, ``pred[sample, class]``.
            label: 1-D integer (int64) tensor holding one class index per sample.

        Returns:
            0-dim tensor ``-mean_i(log_softmax(pred)[i, label[i]])``.
        """
        # log_softmax subtracts each row's max before exponentiating. A plain
        # exp() overflows to inf for float32 logits above ~88, and inf/inf then
        # turns the loss into nan.
        log_probs = F.log_softmax(pred, dim=1)
        # Select each sample's log-probability of its own label. squeeze(-1)
        # (not a bare squeeze()) keeps a 1-element batch 1-D.
        true_log_probs = log_probs.gather(1, label.unsqueeze(-1)).squeeze(-1)
        return -true_log_probs.mean()
# --- Model: ImageNet-pretrained GoogLeNet used as a frozen feature extractor ---
# torchvision's multi-weight API expects a GoogLeNet_Weights member, a weight
# name string, or None for `weights`; a bool raises TypeError. Per the
# torchvision docs, loading pretrained weights without aux_logits=True drops
# the auxiliary heads, so in train mode model(x) returns a plain logits tensor.
model = torchvision.models.googlenet(
    weights=torchvision.models.GoogLeNet_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False
# Replace the 1000-way ImageNet head with one output per dataset class; the
# freshly created Linear layer is trainable.
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))
model = model.to(device)

criterion = MyLoss()
# Only the new head has trainable parameters, so it is all Adam needs to see.
# NOTE(review): the module-level `lr` (0.03) is not used; 0.01 is kept so
# training behaves as before.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.01)

# Start below any achievable accuracy so the first epoch always yields a checkpoint.
best_acc = -1.0
best_model = None
for epoch in range(epochs):
    # NOTE(review): train() also makes the frozen backbone's BatchNorm layers
    # keep updating their running statistics; confirm that is intended.
    model.train()
    running_loss = 0.0
    epoch_acc_count = 0
    train_count = 0
    train_bar = tqdm(train_loader)
    for images, labels in train_bar:
        # Inputs, labels and model must share a device; comparing a GPU
        # prediction with a CPU label tensor raises a device-mismatch error.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 loss.item())
        epoch_acc_count += (output.argmax(dim=1) == labels).sum().item()
        train_count += images.size(0)

    epoch_acc = epoch_acc_count / train_count
    print("【EPOCH: 】%s" % str(epoch + 1))
    print("训练损失为%s" % str(running_loss))
    print("训练精度为%s" % (str(epoch_acc * 100)[:5]) + '%')
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        # state_dict() hands back tensors that share storage with the live
        # parameters, so later optimizer steps would overwrite this "best"
        # snapshot; clone them to freeze it.
        best_model = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}

# Write the best checkpoint once training is done (skipped if no epoch ran).
if best_model is not None:
    torch.save(best_model, save_path)
print('Finished Training')
# --- Inference: classify one image with the checkpoint written during training ---
# class_dict maps an integer class index back to its folder name.
with open('class_dict.pk', 'rb') as f:
    # This file is produced by this same script; never unpickle files from
    # untrusted sources.
    class_dict = pickle.load(f)

# Deterministic evaluation preprocessing: resize, centre 224x224 crop,
# tensor conversion, ImageNet normalisation (same statistics as training).
data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

img_path = "D:/datasets/birddata/Bananaquit/001.jpg"
# Close the file handle promptly, and force 3 channels: grayscale, RGBA or
# palette images would otherwise break the 3-channel Normalize and the model.
with Image.open(img_path) as raw_img:
    img = raw_img.convert('RGB')
img = data_transform(img)

# Undo the normalisation for display so matplotlib shows the real colours
# instead of clipping values outside [0, 1].
_display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
plt.imshow((img * _display_std + _display_mean).clamp(0, 1).permute(1, 2, 0))
plt.show()

# Add the batch dimension and put the input on the same device as the model.
img = torch.unsqueeze(img, dim=0).to(device)
# Predict with the saved checkpoint rather than whatever weights the last
# epoch left in memory.
model.load_state_dict(torch.load(save_path, map_location=device))
# eval() makes BatchNorm use running statistics and disables dropout;
# no_grad() avoids building an autograd graph for a pure forward pass.
model.eval()
with torch.no_grad():
    pred = class_dict[model(img).argmax(dim=1).item()]
print('【预测结果分类】:%s' % pred)
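# (Leftover blog-editor placeholder text, "insert code snippet here"; kept as a comment so the script parses.)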