import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
# from A_alexnet.tools.my_dataset import CatDogDataset
from A_alexnet.tools.my_dataset import CatDogDataset
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
def get_model(path_state_dict, vis_model=False):
"""
创建模型,加载参数
:param path_state_dict:
:return:
"""
model = models.alexnet()
pretrained_state_dict = torch.load(path_state_dict)
model.load_state_dict(pretrained_state_dict)
if vis_model:
from torchsummary import summary
summary(model, input_size=(3, 224, 224), device="cpu")
model.to(device)
return model
if __name__ == "__main__":
# config
data_dir = os.path.join(BASE_DIR, "..", "data", "train")
path_state_dict = os.path.join(BASE_DIR, "..", "data", "alexnet-owt-4df8aa71.pth")
num_classes = 2
MAX_EPOCH = 10 # 可自行修改
BATCH_SIZE = 128 # 可自行修改
LR = 0.001 # 可自行修改
log_interval = 1 # 可自行修改
val_interval = 1 # 可自行修改
classes = 2
start_epoch = -1
lr_decay_step = 1 # 可自行修改
# ============================ step 1/5 数据 ============================
norm_mean = [0.485, 0.456, 0.406
Alex net训练模型
最新推荐文章于 2025-01-22 18:46:46 发布

本文深入探讨如何使用Python进行深度学习,重点是AlexNet模型的训练过程。我们将讨论数据预处理、模型架构、损失函数以及优化器的选择,同时涵盖训练技巧和调参策略,帮助读者更好地理解和实现这一经典深度学习模型。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



