import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34, resnet101
import matplotlib.pyplot as plt
import numpy as np
from osgeo import gdal
from torchvision.transforms import functional as F
# from torch.utils.tensorboard import SummaryWriter
def main():
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    # device = torch.device("cuda:0")
    print("using {} device.".format(device))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.3971, 0.4091, 0.3681], [0.2169, 0.1943, 0.1917])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.3971, 0.4091, 0.3681], [0.2169, 0.1943, 0.1917])]),
        # evaluation transforms should be deterministic, so the test pipeline
        # mirrors "val" instead of using random crops and flips
        "test": transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.3971, 0.4091, 0.3681], [0.2169, 0.1943, 0.1917])]),
    }
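    # NOTE: the Normalize mean/std above are per-channel statistics of this
    # particular dataset, not the ImageNet defaults. A minimal sketch for
    # recomputing them (assuming a DataLoader `loader` that only applies ToTensor):
    #   mean, std = torch.zeros(3), torch.zeros(3)
    #   for imgs, _ in loader:
    #       mean += imgs.mean(dim=(0, 2, 3))
    #       std += imgs.std(dim=(0, 2, 3))
    #   mean /= len(loader); std /= len(loader)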
    data_root = os.path.abspath(os.path.join(os.getcwd(), "G:/splitdata"))  # get data root path
    image_path = os.path.join(data_root, "data")  # data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    class_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in class_list.items())
    # write the index-to-class mapping into a json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    batch_size = 16
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
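    # nw caps the worker count at min(cpu_count, batch_size, 8) so that data
    # loading does not oversubscribe the CPU for small batches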
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    test_dataset = datasets.ImageFolder(root=os.path.join(image_path, "test"),
                                        transform=data_transform["test"])
    test_num = len(test_dataset)
    test_loader = torch.utils.data.DataLoader(test_dataset,  # was train_dataset; the test loader must wrap the test set
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=nw)
    print("using {} images for training, {} images for validation, {} images for testing".format(train_num, val_num,
                                                                                                  test_num))
    arry_train = []
    arry_test = []

    # helper for plotting the per-step training loss (not called in this script)
    def plot_loss(arry_train):
        line1, = plt.plot(range(0, len(arry_train)), arry_train, 'r.-')
        plt_title = 'BATCH_SIZE = 16; EPOCH = 5'
        plt.title(plt_title)
        plt.legend(handles=[line1], labels=["train_loss"], loc="upper right", fontsize=7)
        plt.xlabel('STEP')
        plt.ylabel('LOSS')
        plt.show()
    # net = ResNet34(classes_num=10)
    net = resnet34()
    model_weight_path = "./resnet34-333f7ec4.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    for param in net.parameters():
        param.requires_grad = False
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 30)
    net.to(device)
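    # transfer-learning setup: every pretrained parameter above was frozen
    # (requires_grad = False) before the fc layer was replaced, so the new
    # 30-class head created by nn.Linear is the only trainable module left.
    # The optimizer below therefore only receives the head's parameters.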
    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.001)

    epochs = 20
    best_acc = 0.0
    save_path = './best.pth'
    train_steps = len(train_loader)
    total_test_step = 0
    Loss_list = []
    Accuracy_list = []
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # store the scalar value, not the tensor, so the computation graph is not kept alive
            arry_train.append(loss.item())
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        Loss_list.append(running_loss / train_steps)
        Accuracy_list.append(val_accurate)
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished')
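    # The test split is loaded above but never evaluated in this script. A
    # minimal sketch of a final test pass (assuming the best checkpoint saved
    # to save_path should be restored first) could look like:
    #   net.load_state_dict(torch.load(save_path, map_location=device))
    #   net.eval()
    #   test_acc = 0.0
    #   with torch.no_grad():
    #       for test_images, test_labels in test_loader:
    #           test_outputs = net(test_images.to(device))
    #           test_acc += torch.eq(torch.max(test_outputs, dim=1)[1],
    #                                test_labels.to(device)).sum().item()
    #   print('test_accuracy: %.3f' % (test_acc / test_num))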
    # plot validation accuracy and training loss per epoch
    x1 = range(1, len(Accuracy_list) + 1)
    x2 = range(1, len(Loss_list) + 1)
    y1 = Accuracy_list
    y2 = Loss_list
    plt.subplot(2, 1, 1)
    plt.plot(x1, y1, 'o-')
    plt.title('val accuracy')
    plt.ylabel('val accuracy')
    plt.subplot(2, 1, 2)
    plt.plot(x2, y2, '.-')
    plt.xlabel('epoch')
    plt.ylabel('training loss')
    plt.savefig("accuracy_loss.jpg")  # save before show(), otherwise an empty figure is written
    plt.show()
if __name__ == '__main__':
main()