import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
# from A_alexnet.tools.my_dataset import CatDogDataset
from A_alexnet.tools.my_dataset import CatDogDataset
# Allow the process to keep running when more than one OpenMP runtime is
# loaded. NOTE(review): presumably a workaround for the "libiomp5 already
# initialized" abort seen with some torch/numpy/matplotlib installs -- confirm
# it is still needed in the target environment.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# Directory containing this script; used as the anchor for data/weight paths.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# GPU auto-selection is currently disabled; the script is pinned to the CPU.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
def get_model(path_state_dict, vis_model=False):
    """
    Create an AlexNet model and load parameters from a saved state dict.

    :param path_state_dict: path of the serialized state dict file that is
        handed to ``torch.load``.
    :param vis_model: when true, print a per-layer summary (via the optional
        ``torchsummary`` package) for a ``(3, 224, 224)`` input before the
        model is moved to the target device.
    :return: the AlexNet instance with the loaded weights, moved to the
        module-level ``device``.
    """
    model = models.alexnet()

    # Deserialize onto the CPU first so that a checkpoint written from a CUDA
    # device can still be loaded on a machine without a GPU; the model is
    # moved to the target device afterwards.
    pretrained_state_dict = torch.load(path_state_dict, map_location="cpu")
    model.load_state_dict(pretrained_state_dict)

    if vis_model:
        # Imported lazily so torchsummary is only required when a summary
        # is actually requested.
        from torchsummary import summary
        summary(model, input_size=(3, 224, 224), device="cpu")

    model.to(device)
    return model
# Script entry point: training configuration.
# NOTE(review): the indentation of this section was lost and the script is
# truncated after the config values; restore both from the original source
# before running.
if __name__ == "__main__":
# config
# Training images and pretrained weights are both resolved relative to this
# file, under ../data.
data_dir = os.path.join(BASE_DIR, "..", "data", "train")
path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")
num_classes = 2
MAX_EPOCH = 10 # tunable
BATCH_SIZE = 128 # tunable
LR = 0.001 # tunable
log_interval = 1 # tunable
val_interval = 1 # tunable
# NOTE(review): same value as num_classes above; which of the two is actually
# used is not visible in this section -- consider keeping only one.
classes = 2
start_epoch = -1
lr_decay_step = 1 # tunable
# ============================ step 1/5: data ============================
norm_mean = [0.485, 0.456, 0.406
AlexNet 训练模型
最新推荐文章于 2024-07-31 15:55:27 发布
![](https://img-home.csdnimg.cn/images/20240711042549.png)