import os.path
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "./model/googleNet_cuda2.pth"

# Build GoogLeNet with torchvision's default (untrained) weights on the chosen
# device. This is the model the script ends up using; constructing a separate
# pretrained copy first would only download weights that are then discarded.
model = models.googlenet().to(device)

# Resume from a previously saved state dict if the checkpoint file exists.
if os.path.exists(model_path):
    # map_location remaps stored tensors onto the active device, so a
    # checkpoint saved on CUDA still loads on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
def train():
    """Put the global ``model`` in training mode and build the training dataset.

    NOTE(review): in this copy the function ends after constructing the
    dataset; the DataLoader / optimizer / training loop that presumably
    followed (and a commented-out validation ImageFolder under
    "./flower_datas/val...") appear to have been lost — TODO restore them
    from the original source.
    """
    model.train()
    # Resize the PIL image first, then convert it to a tensor. This order works
    # on every torchvision version (older releases only accept PIL images in
    # Resize) and resizes the uint8 image rather than a float tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # Named ``train_set`` so it does not shadow this function's own name.
    # assumes ./flower_datas/train has one sub-directory per class, which is
    # the layout ImageFolder expects — TODO confirm.
    train_set = ImageFolder("./flower_datas/train", transform=transform)
11-18
6379
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
05-13
83万+
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
07-30
2472
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
07-10
553
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
12-18
1382
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)