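1. Code structure
Reference: Hung-yi Lee 2021 Machine Learning HW2 Phoneme Classification
2. Code details
Getting the device
The first version returns a plain string ('cuda' or 'cpu'); the second returns a torch.device object. Either one can be passed to .to().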
import torch

# check device
def get_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'

# Second version
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
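The difference between the two versions is easy to check directly. The following is a minimal sketch (the variable names device_str and device_obj are made up for illustration) showing that one value is a str, the other a torch.device, and that a tensor ends up on the same device either way:

```python
import torch

def get_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'

device_str = get_device()               # plain Python string
device_obj = torch.device(device_str)   # torch.device object

x = torch.zeros(2, 3)
a = x.to(device_str)   # .to() accepts a string ...
b = x.to(device_obj)   # ... and a torch.device

print(type(device_str).__name__, type(device_obj).__name__)
print(a.device == b.device)
```

In practice the two forms are interchangeable for .to(); the torch.device object is mainly convenient when you want to inspect device.type later.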
Setting the random seed
import numpy as np
import torch

# fix random seed
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # trade some speed for deterministic cuDNN behaviour
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
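A quick way to see what the seed function buys you is to call it twice with the same seed and compare the random numbers drawn afterwards. This is only a sketch on CPU tensors and NumPy arrays; it does not prove that every GPU kernel is deterministic.

```python
import numpy as np
import torch

def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

same_seeds(0)
first = (torch.rand(3), np.random.rand(3))

same_seeds(0)  # reseeding restarts both generators
second = (torch.rand(3), np.random.rand(3))

print(torch.equal(first[0], second[0]))
print(np.array_equal(first[1], second[1]))
```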
Training
# training
train_acc = 0.0   # running count of correct predictions
train_loss = 0.0  # running sum of batch losses
model.train()  # set the model to training mode
for i, data in enumerate(train_loader):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    batch_loss = criterion(outputs, labels)
    _, train_pred = torch.max(outputs, 1)  # get the index of the class with the highest score
    batch_loss.backward()
    optimizer.step()
    train_acc += (train_pred.cpu() == labels.cpu()).sum().item()
    train_loss += batch_loss.item()
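The loop above relies on model, criterion, optimizer and train_loader being defined elsewhere. To make it runnable end to end, here is a minimal sketch of one epoch. The toy linear model, the random data and the Adam optimizer are assumptions chosen purely for illustration and are not the homework's actual setup.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hypothetical toy data: 64 samples, 10 features, 3 classes.
features = torch.randn(64, 10)
targets = torch.randint(0, 3, (64,))
train_set = TensorDataset(features, targets)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

model = nn.Linear(10, 3).to(device)       # stand-in for the real network
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_acc, train_loss = 0.0, 0.0
model.train()
for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    batch_loss = criterion(outputs, labels)
    _, train_pred = torch.max(outputs, 1)
    batch_loss.backward()
    optimizer.step()
    train_acc += (train_pred == labels).sum().item()
    train_loss += batch_loss.item()

# Accuracy is averaged per sample, loss per batch.
epoch_acc = train_acc / len(train_set)
epoch_loss = train_loss / len(train_loader)
print(f"acc = {epoch_acc:.3f}, loss = {epoch_loss:.3f}")
```

Note that the accumulated accuracy is divided by the number of samples, while the accumulated loss is divided by the number of batches, because criterion already averages within each batch.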
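3. Code structure for image-based deep learning
Reference: Hung-yi Lee 2021 Machine Learning HW3 CNN fruit image classification
A handy new module: torchvision
import torchvision.transforms as transforms
from PIL import Image  # used by the DatasetFolder loaders below
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder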
3.1 Loading the data
Add data augmentation transforms to the training set;
the test and validation sets do not need augmentation.
train_tfm = transforms.Compose([
    # Resize the image into a fixed shape (height = width = 128)
    transforms.Resize((128, 128)),
    # You may add some transforms here.
    # ToTensor() should be the last one of the transforms.
    transforms.ToTensor(),
])
test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
3.2 DataLoader
batch_size = 128
# Apply data augmentation when loading the data
# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder("food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
# Construct data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
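3.3 A second Dataset format for images
When there are too many images to load all at once inside the Dataset, it is enough to store each file name together with its label and read the image only when it is requested. You just need to implement the three methods (__init__, __getitem__ and __len__). See the speech recognition version for reference:
Hung-yi Lee 2021 Machine Learning HW4 transformer speech classification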
3.4 Progress bars with tqdm
tqdm wraps the DataLoader directly:
for batch in tqdm(train_loader):
    # A batch consists of image data and corresponding labels.
    imgs, labels = batch
    # Forward the data. (Make sure data and model are on the same device.)
    logits = model(imgs.to(device))
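If you keep a reference to the tqdm object instead of creating it inline, you can also attach live statistics to the bar while iterating. The sketch below uses a toy TensorDataset as a stand-in for the image data; the displayed batch_size field is just an example of a value you might show.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Toy data: 32 samples split into 4 batches of 8.
train_set = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
train_loader = DataLoader(train_set, batch_size=8)

seen_batches = 0
pbar = tqdm(train_loader, desc="Train")  # total is inferred from len(train_loader)
for imgs, labels in pbar:
    seen_batches += 1
    # Any keyword arguments are rendered at the end of the bar.
    pbar.set_postfix(batch_size=len(labels))
```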
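Second approach:
from tqdm import tqdm

valid_steps = 10  # total number of steps shown on the bar
pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")
for step in range(valid_steps):
    pbar.update()
    pbar.set_postfix(
        loss=f"{0.1:.2f}",      # placeholder value
        accuracy=f"{0.2:.2f}",  # placeholder value
        step=step + 1,
    )
pbar.close()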