label = 1 if ‘dog’ in img_path.split(‘/’)[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
def __len__(self):
    """Return the number of entries in ``self.imgs``.

    The built-in ``len()`` only looks up the ``__len__`` dunder; a method
    called plain ``len`` is never used by it. Both
    ``DataLoader(..., shuffle=True)`` (its RandomSampler calls
    ``len(dataset)``) and ``len(train_loader.dataset)`` need this method.
    """
    return len(self.imgs)


def len(self):
    """Backward-compatible alias for ``__len__``."""
    return self.__len__()
然后我们在train.py调用DogCat读取数据
dataset_train = DogCat('data/train', transforms=transform, train=True)
# NOTE(review): the validation set also points at 'data/train'; presumably
# DogCat selects a held-out subset when train=False -- confirm.
dataset_test = DogCat('data/train', transforms=transform_test, train=False)

# Inspect the samples DogCat collected for training.
print(dataset_train.imgs)

# Wrap the datasets in loaders; only the training data is shuffled.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
设置模型
====
使用CrossEntropyLoss作为损失函数，模型采用AlexNet并加载预训练权重。修改全连接层，将最后一层的输出类别数设置为2，然后将模型移动到DEVICE上。优化器选用Adam。
# Instantiate the model and move it to DEVICE (the GPU, per the original note).
criterion = nn.CrossEntropyLoss()
# NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
# favour of ``weights=AlexNet_Weights.DEFAULT``; kept here for compatibility
# with whichever torchvision version this project uses.
model_ft = alexnet(pretrained=True)
# Assumes ``alexnet`` is torchvision's, whose classifier ends in
# Linear(4096, 1000). Replace only that head with a 2-class layer; rebuilding
# the whole classifier would randomly re-initialise the two pretrained hidden
# fully-connected layers and discard their weights.
model_ft.classifier[-1] = nn.Linear(model_ft.classifier[-1].in_features, 2)
model_ft.to(DEVICE)
# Plain Adam optimizer with a lowered learning rate.
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
def adjust_learning_rate(optimizer, epoch, *, base_lr=None, step_size=50, gamma=0.1):
    """Apply a step-decay schedule to every parameter group of *optimizer*.

    The new rate is ``base_lr * gamma ** (epoch // step_size)``; with the
    defaults the rate is divided by 10 every 50 epochs.

    Args:
        optimizer: Object exposing ``param_groups``; each group's ``'lr'``
            entry is overwritten.
        epoch: Current epoch index.
        base_lr: Initial learning rate; defaults to the module-level ``modellr``.
        step_size: Number of epochs between decays.
        gamma: Multiplicative factor applied at each decay step.

    Returns:
        The learning rate written into the optimizer.
    """
    if base_lr is None:
        base_lr = modellr
    new_lr = base_lr * (gamma ** (epoch // step_size))
    print("lr:", new_lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr
设置训练和验证
=======
# Training procedure.
def train(model, device, train_loader, optimizer, epoch, *, loss_fn=None, log_interval=50):
    """Run one training epoch of *model* over *train_loader*.

    Args:
        model: Network to optimise; it is switched to training mode.
        device: Device each batch is moved to (e.g. ``DEVICE``).
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Epoch number, used only in log messages.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar loss
            tensor; defaults to the module-level ``criterion``.
        log_interval: Print a progress line every this many batches.

    Returns:
        Mean per-batch loss over the epoch (0.0 if the loader is empty).
    """
    if loss_fn is None:
        loss_fn = criterion
    model.train()
    sum_loss = 0.0
    total_num = len(train_loader.dataset)
    num_batches = len(train_loader)
    print(total_num, num_batches)
    seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
        # plain tensors track gradients directly.
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = loss_fn(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        sum_loss += batch_loss
        # Count actual samples so a smaller final batch is not over-reported.
        seen += len(data)
        if (batch_idx + 1) % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, total_num,
                100. * (batch_idx + 1) / num_batches, batch_loss))
    # Guard against an empty loader instead of dividing by zero.
    ave_loss = sum_loss / max(num_batches, 1)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
    return ave_loss
验证过程
def val(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
total_num = len(test_loader.dataset)
print(total_num, len(test_loader))
<

（原文剩余部分为付费内容，此处未收录。）



