"""Cat-vs-dog binary classifier: fine-tune a pretrained VGG19, then predict single images.

Resources (Baidu Netdisk): https://pan.baidu.com/s/1H2tQwYP1i-VIziS6XbLllQ
Extraction code: q5mt
"""
import time

import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader

# ---------------------------------------------------------------------------
# Build the model: VGG19 with its last classifier layer swapped for a
# 2-way linear head.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` -- confirm against the installed version before changing.
# ---------------------------------------------------------------------------
model = torchvision.models.vgg19(pretrained=True)
model.classifier[-1] = torch.nn.Linear(4096, 2)
print(model)

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
SHOW_DATA_DETAILS = False   # dump dataset / loader attributes for inspection
DO_TRAIN = False            # train before loading the checkpoint in __main__

sp = '\n' + '--------' * 20 + '\n'   # visual separator used by the debug dump
root = './data'
train_path = root + '/train'
test_path = root + '/test'
bs = 16          # batch size
lr = 0.0001      # learning rate
epoch = 20       # number of training epochs
# Fall back to the CPU so the script still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=lr)
# Training-time preprocessing (random crop acts as augmentation).
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.ToTensor()
])
# Inference-time preprocessing: deterministic, so the same image always
# yields the same prediction.
predict_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor()
])

# ---------------------------------------------------------------------------
# Load the data
# ---------------------------------------------------------------------------
train_data = torchvision.datasets.ImageFolder(train_path, transform)
classes = train_data.classes
train_iterator = DataLoader(train_data, bs, shuffle=True)

# ---------------------------------------------------------------------------
# Show data details (debugging aid, disabled by default)
# ---------------------------------------------------------------------------
if SHOW_DATA_DETAILS:
    print(train_data, end=sp)
    print('class_to_idx: ', train_data.class_to_idx)
    print('classes: ', train_data.classes)
    print('extension: ', train_data.extensions)
    print('extra_repr: ', train_data.extra_repr())
    print('imgs: ', train_data.imgs)
    print('loader: ', train_data.loader)
    print('root: ', train_data.root)
    print('samples: ', train_data.samples)
    print('target_transform: ', train_data.target_transform)
    print('targets: ', train_data.targets)
    print('transform: ', train_data.transform)
    print('transforms: ', train_data.transforms)
    print('\n\n', end=sp)

    print(train_iterator, end=sp)
    print('batch_sampler: ', train_iterator.batch_sampler)
    print('batch_size: ', train_iterator.batch_size)
    print('collate_fn: ', train_iterator.collate_fn)
    print('dataset: ', train_iterator.dataset)
    print('drop_last: ', train_iterator.drop_last)
    print('generator: ', train_iterator.generator)
    print('multiprocessing_context: ', train_iterator.multiprocessing_context)
    print('num_workers: ', train_iterator.num_workers)
    print('persistent_workers: ', train_iterator.persistent_workers)
    print('pin_memory: ', train_iterator.pin_memory)
    print('prefetch_factor: ', train_iterator.prefetch_factor)
    print('sampler: ', train_iterator.sampler)
    print('timeout: ', train_iterator.timeout)
    print('worker_init_fn: ', train_iterator.worker_init_fn)
    print('\n\n', end=sp)


def _batch_accuracy(outputs, label):
    """Return the fraction of rows in `outputs` whose argmax equals `label`, as a float."""
    pre = torch.argmax(outputs, dim=1)
    return (pre == label).sum().item() / len(label)


def train(model, iterator, optimizer, criterion):
    """Run one training epoch over `iterator`.

    Args:
        model: network to train; it is moved to the module-level `device`.
        iterator: iterable of ``(images, labels)`` batches that supports ``len()``.
        optimizer: optimizer stepping the model's parameters.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        A tuple ``(mean_loss, mean_accuracy, elapsed_seconds)``; the first two
        are Python floats averaged over the number of batches.
    """
    start_time = time.monotonic()
    epoch_loss = 0.0
    epoch_acc = 0.0

    model = model.to(device)
    model.train()
    for images, labels in iterator:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers: summing the loss tensors themselves
        # would chain autograd history across the whole epoch.
        epoch_loss += loss.item()
        epoch_acc += _batch_accuracy(outputs, labels)

    cost_time = time.monotonic() - start_time
    return epoch_loss / len(iterator), epoch_acc / len(iterator), cost_time


def classification(path):
    """Predict the class of the image at `path` and print its name and probability.

    Args:
        path: filesystem path of the image to classify.
    """
    # Force three channels so grayscale / RGBA files match the model's input;
    # the context manager closes the file handle once the pixels are read.
    with Image.open(path) as img:
        image = predict_transform(img.convert('RGB'))
    # Add the batch dimension and put the input on the same device as the model.
    image = torch.unsqueeze(image, 0).to(device)

    model.to(device)
    model.eval()  # disable dropout so the prediction is deterministic
    with torch.no_grad():
        out = model(image)
    poss = torch.softmax(out, dim=1)
    index = int(torch.argmax(out, dim=1))
    print('name: ', classes[index], '\tpossibility: ', float(poss[0, index]))


if __name__ == '__main__':
    # Optional training phase
    if DO_TRAIN:
        # Distinct loop variable: do not shadow the `epoch` count setting.
        for epoch_idx in range(epoch):
            loss, acc, cost_t = train(model, train_iterator, optimizer, criterion)
            print(f'epoch: {epoch_idx}\tcost time: {cost_t}\nloss: {loss}\tacc: {acc}')
            # Checkpoint after every epoch so an interrupted run keeps its progress.
            torch.save(model.state_dict(), 'cat_dog_classification.pth')

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load('cat_dog_classification.pth', map_location=device))

    # Optionally classify a user-supplied image
    choice = input('是否选择图片进行预测? ')
    if choice.strip() in {'y', 'Y'}:
        path = input('输入图片路径(仅限于jpg): ').strip()
        classification(path)
猫狗二分类
最新推荐文章于 2023-10-19 17:45:57 发布