import torch.nn as nn
import torch
import os
import torchvision
from PIL import Image
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision import transforms
from tqdm import tqdm
from torchvision.models import mobilenet,MobileNet_V2_Weights
def _limited(loader, max_batches):
    """Wrap *loader* in a tqdm bar that yields at most *max_batches* batches (None = all)."""
    from itertools import islice

    if max_batches is None:
        return tqdm(loader)
    # islice stops *before* fetching batch max_batches+1 (a manual `break`
    # would load and discard one extra batch), and `total` keeps the bar accurate.
    return tqdm(islice(loader, max_batches), total=min(len(loader), max_batches))


def _build_loaders(batch_size, shuffle_test, root="./data"):
    """Download/load EMNIST ``byclass`` and return ``(train_loader, test_loader)``."""
    tf = transforms.Compose(
        [
            transforms.ToTensor(),
            # EMNIST images are single-channel: replicate to 3 channels for the
            # ImageNet-pretrained backbone.
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Resize((224, 224)),  # match the pretrained model's input size
            transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
        ]
    )
    train_set = torchvision.datasets.EMNIST(
        root=root, split="byclass", transform=tf, train=True, download=True
    )
    print(len(train_set))
    test_set = torchvision.datasets.EMNIST(
        root=root, split="byclass", transform=tf, train=False, download=True
    )
    print(len(test_set))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=shuffle_test)
    return train_loader, test_loader


def _build_model(num_classes, device):
    """Return an ImageNet-pretrained MobileNetV2 with a new *num_classes*-way head on *device*.

    MobileNetV2 has no ``fc`` attribute; ``forward`` goes through
    ``model.classifier``, so that is the module that must be replaced for the
    new head to be used at all.
    """
    model = mobilenet.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    in_features = model.classifier[-1].in_features  # feature width feeding the old head
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),  # keep the dropout that precedes the stock MobileNetV2 head
        nn.Linear(in_features, 512),
        # Without a non-linearity, two stacked Linear layers collapse into one.
        nn.ReLU(inplace=True),
        nn.Linear(512, num_classes),
    )
    return model.to(device)


def _train_one_epoch(model, loader, optimizer, criterion, device, ep, max_batches):
    """Run one optimisation pass over at most *max_batches* batches of *loader*."""
    model.train()
    pbar = _limited(loader, max_batches)
    for x, c in pbar:
        x, c = x.to(device), c.to(device)
        loss = criterion(model(x), c)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(f"epoch:{ep},loss:{loss.item():.4f}")


@torch.no_grad()  # evaluation needs no autograd graph: saves memory and time
def _evaluate(model, loader, device, ep, max_batches):
    """Return ``(correct, total)`` top-1 counts over at most *max_batches* batches of *loader*."""
    model.eval()
    pbar = _limited(loader, max_batches)
    correct_all = 0
    total = 0
    for x, c in pbar:
        x, c = x.to(device), c.to(device)
        pred = torch.argmax(model(x), dim=1)
        correct = (pred == c).sum().item()
        correct_all += correct
        total += c.shape[0]
        pbar.set_description(f"epoch:{ep},correct:{correct}")
    return correct_all, total


def train_data(*, device=None, n_epoch=100, batch_size=16, lr=0.0003,
               max_batches=100, num_classes=None, data_root="./data"):
    """Fine-tune an ImageNet-pretrained MobileNetV2 on EMNIST ``byclass``.

    Each epoch trains on at most ``max_batches`` training batches, prints
    top-1 accuracy on at most ``max_batches`` test batches, then multiplies
    the learning rate by 0.94 (``StepLR`` with ``step_size=1``).

    Args:
        device: torch device; defaults to ``"cuda:0"`` when CUDA is available,
            otherwise ``"cpu"``.
        n_epoch: number of epochs.
        batch_size: mini-batch size for both loaders.
        lr: initial Adam learning rate.
        max_batches: per-epoch cap on train and test batches (the quick-run
            limit of the original script); ``None`` iterates the full loaders.
        num_classes: output size of the new head; defaults to the number of
            classes of the EMNIST split (set it to match your own data).
        data_root: directory EMNIST is downloaded to / read from.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # When only a capped subset of the test set is evaluated, shuffling makes
    # that subset a random sample instead of always the same first batches.
    train_loader, test_loader = _build_loaders(
        batch_size, shuffle_test=max_batches is not None, root=data_root
    )
    if num_classes is None:
        num_classes = len(train_loader.dataset.classes)

    model = _build_model(num_classes, device)
    op = optim.Adam(params=model.parameters(), lr=lr, weight_decay=0.001)
    lr_scheduler = optim.lr_scheduler.StepLR(op, step_size=1, gamma=0.94)  # decay LR every epoch
    criterion = nn.CrossEntropyLoss()

    for ep in range(n_epoch):
        _train_one_epoch(model, train_loader, op, criterion, device, ep, max_batches)
        correct_all, total = _evaluate(model, test_loader, device, ep, max_batches)
        acc = correct_all / total if total else 0.0  # guard against an empty evaluation
        print(f"ep:{ep},acc:{acc : .3f},correct:{correct_all},all:{total}")
        lr_scheduler.step()
if __name__ == "__main__":
    # Script entry point: start training only when this file is executed
    # directly, not when it is imported as a module.
    train_data()
# Residue copied from the blog page this script came from (related-post
# dates and view counts). It is not part of the program and is kept only
# as a comment so the file parses:
# 03-19  6978 views
# 07-03  10,000+ views ("1万+")
# 11-02  741 views