1.网络结构
![](https://img-blog.csdnimg.cn/270ddb50026e43868121a0fc17a0cd8d.png)
2.网络搭建
import torch.nn as nn
import torch
class AlexNet(nn.Module):
    """AlexNet-style image classifier.

    ``features`` is a stack of five convolutions (with ReLU activations and
    three max-pools); for a 224x224 RGB input it produces a ``[128, 6, 6]``
    feature map per sample. ``classifier`` is a three-layer MLP with dropout
    that turns the flattened features into ``num_classes`` logits.

    Args:
        num_classes: size of the final Linear layer (number of logits).
        init_weights: when True, re-initialise Conv2d/Linear parameters with
            ``_initialize_weights`` instead of keeping PyTorch's defaults.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()
        # The order of these layers fixes the integer indices inside each
        # nn.Sequential and therefore the state_dict keys
        # (features.0.weight, classifier.1.weight, ...). Keep it stable so
        # previously saved checkpoints stay loadable.
        conv_stack = [
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # [3, 224, 224] -> [48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> [48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # -> [128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> [128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # -> [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # -> [128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> [128, 6, 6]
        ]
        head = [
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Sequential(*head)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape ``[N, num_classes]`` for image batch ``x``."""
        flat = torch.flatten(self.features(x), start_dim=1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Re-initialise parameters in place.

        Conv2d: Kaiming-normal weights (mode='fan_out', ReLU gain), zero bias.
        Linear: normal weights (mean 0, std 0.01), zero bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
3.网络训练
import os
import sys
import json
from sympy import im
import torch
import torch.nn as nn
from torchvision import transforms,datasets,utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet
def _make_transforms():
    """Return the preprocessing pipelines for the "train" and "val" splits."""
    return {
        # Random crop + flip augmentation; Normalize(0.5, 0.5) maps [0, 1] to [-1, 1].
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        # Resize((224, 224)) forces an exact 224x224 output. Resize(224) would
        # only scale the shorter side to 224 (keeping the aspect ratio), so
        # images could differ in size and the default collate could not stack them.
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }


def _num_workers(batch_size, max_workers=8):
    """Choose a DataLoader worker count bounded by CPU count, batch size and *max_workers*."""
    cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    return min(cpu_count, batch_size if batch_size > 1 else 0, max_workers)


def _save_class_indices(class_to_idx, path='class_indices.json'):
    """Write the inverse mapping {index: class_name} to *path* and return it.

    JSON object keys are always strings, so readers must look classes up with
    ``str(index)``.
    """
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    with open(path, 'w') as json_file:
        json_file.write(json.dumps(idx_to_class, indent=4))
    return idx_to_class


def _train_one_epoch(net, loader, loss_function, optimizer, device, epoch, epochs):
    """Run one optimisation pass over *loader* and return the summed batch losses."""
    net.train()
    running_loss = 0.0
    train_bar = tqdm(loader, file=sys.stdout)
    for images, labels in train_bar:
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        # Fetch the scalar once and reuse it for both the total and the bar text.
        loss_value = loss.item()
        running_loss += loss_value
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss_value)
    return running_loss


def _evaluate(net, loader, device):
    """Return the number of samples in *loader* whose top-1 prediction is correct."""
    net.eval()
    correct = 0.0
    with torch.no_grad():
        for val_images, val_labels in tqdm(loader, file=sys.stdout):
            outputs = net(val_images.to(device))
            predict_y = outputs.argmax(dim=1)
            correct += torch.eq(predict_y, val_labels.to(device)).sum().item()
    return correct


def main():
    """Train AlexNet on ./flower_data/{train,val} and keep the best checkpoint.

    Side effects: writes class_indices.json (index -> class name) and saves the
    state_dict with the highest validation accuracy so far to ./Alexnet.pth.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    data_transform = _make_transforms()

    image_path = './flower_data'
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    # class_to_idx looks like {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
    cla_dict = _save_class_indices(train_dataset.class_to_idx)

    batch_size = 32
    nw = _num_workers(batch_size)
    print("Using {} dataloader workers every process".format(nw))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=nw,
    )

    validate_dataset = datasets.ImageFolder(
        root=os.path.join(image_path, "val"),
        transform=data_transform["val"],
    )
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(
        validate_dataset,
        batch_size=4,
        shuffle=False,
        num_workers=nw,
    )
    print("using {} images for training,{} images for validation.".format(train_num, val_num))

    # Optional: preview one validation batch (the range(4) matches batch_size=4 above).
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)
    # def imshow(img):
    #     img = img / 2 + 0.5  # undo Normalize(0.5, 0.5)
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    # Size the output layer from the class folders instead of hard-coding it
    # (5 for the flower dataset).
    net = AlexNet(num_classes=len(cla_dict), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './Alexnet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        running_loss = _train_one_epoch(net, train_loader, loss_function,
                                        optimizer, device, epoch, epochs)
        val_accurate = _evaluate(net, validate_loader, device) / val_num
        print("[epoch %d] train_loss:%.3f val_accuracy:%.3f" % (epoch + 1, running_loss / train_steps, val_accurate))
        # Keep only the best-so-far weights on disk.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print("Finished Training!!!")


if __name__ == '__main__':
    main()
4.网络检验
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet
def main():
    """Classify ./tulip.jpg with the trained AlexNet and display the result.

    Reads ./class_indices.json (index -> class name, as written by the training
    script) and the weights in ./Alexnet.pth, prints the probability of every
    class and shows the image titled with the top-1 prediction.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Same preprocessing as the "val" transform used during training.
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # load image
    img_path = "./tulip.jpg"
    assert os.path.exists(img_path), "file:'{}' does not exist.".format(img_path)
    # convert("RGB") drops an alpha channel / expands grayscale so the tensor
    # has the 3 channels that Normalize and the first conv layer expect.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    img = data_transform(img)          # [C, H, W]
    img = torch.unsqueeze(img, dim=0)  # add batch dimension -> [N, C, H, W]

    # read class_indict; JSON object keys are strings, hence str(index) lookups below
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file:'{}' does not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model; the output size follows the saved class mapping
    model = AlexNet(num_classes=len(class_indict)).to(device)
    # load model weights
    weights_path = './Alexnet.pth'
    assert os.path.exists(weights_path), "file:'{}' does not exist.".format(weights_path)
    # map_location lets weights saved from a GPU run load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class:{} prob:{:.3f}".format(class_indict[str(predict_cla)], predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        # Print each class's own probability (the original printed the top-1
        # probability for every row).
        print("class:{:10} prob:{:.3}".format(class_indict[str(i)], prob.item()))
    plt.show()


if __name__ == '__main__':
    main()