(1) Deep Learning Project Code Structure

1. Code structure

(Figure: code structure)
Reference: Hung-yi Lee, Machine Learning 2021, HW2 Phoneme Classification

2. Code details

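Getting the device
The two forms below do not return the same type: the first returns a string ('cuda' or 'cpu'), while the second returns a torch.device object. Either one can be passed to .to().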

import torch

# check device
def get_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'

# Second form
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
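To make the difference between the two forms concrete, here is a minimal sketch. The tensor `x` and the names `device_str` and `device_obj` are illustrative only; the point is that `.to()` accepts both the string and the `torch.device` object.

```python
import torch

def get_device():
    # Returns a plain Python string: 'cuda' or 'cpu'
    return 'cuda' if torch.cuda.is_available() else 'cpu'

device_str = get_device()
device_obj = torch.device(device_str)  # torch.device object built from the string

print(isinstance(device_str, str))           # the first form is a string
print(isinstance(device_obj, torch.device))  # the second form is a torch.device

# Tensor.to() (and nn.Module.to()) accept either form.
x = torch.zeros(2, 3)
x_from_str = x.to(device_str)
x_from_obj = x.to(device_obj)
print(x_from_str.device.type == x_from_obj.device.type)
```

In practice it rarely matters which one is used, as long as the model and the data are moved to the same device.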

Setting the random seed

import numpy as np
import torch

# fix random seed
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  
    np.random.seed(seed)  
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
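A quick way to check that the seeding works is to call it twice with the same seed and compare the random numbers drawn afterwards. The sketch below repeats the function so that it is self-contained; the variable names are illustrative. Note that it does not seed Python's built-in `random` module, so a project that uses it would also need `random.seed(seed)`.

```python
import numpy as np
import torch

def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Draw random numbers after seeding, then reseed and draw again.
same_seeds(0)
first_torch, first_np = torch.rand(3), np.random.rand(3)

same_seeds(0)
second_torch, second_np = torch.rand(3), np.random.rand(3)

# Both comparisons should report that the two draws are identical.
print(torch.equal(first_torch, second_torch))
print(np.array_equal(first_np, second_np))
```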

Training

# training
train_acc, train_loss = 0, 0.0  # running totals for the epoch
model.train()  # set the model to training mode
for i, data in enumerate(train_loader):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    batch_loss = criterion(outputs, labels)
    _, train_pred = torch.max(outputs, 1)  # index of the highest-scoring class
    batch_loss.backward()
    optimizer.step()

    train_acc += (train_pred.cpu() == labels.cpu()).sum().item()
    train_loss += batch_loss.item()
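The loop above only accumulates `train_acc` and `train_loss`. The following self-contained sketch shows one way to turn those totals into epoch metrics and to add a matching validation pass. The synthetic data, the linear model, and the hyperparameters are all stand-ins chosen for illustration, not the homework's actual setup.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Synthetic stand-in data: 256 samples, 20 features, 4 classes (hypothetical).
features = torch.randn(256, 20)
targets = torch.randint(0, 4, (256,))
train_set = TensorDataset(features[:192], targets[:192])
val_set = TensorDataset(features[192:], targets[192:])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

model = nn.Linear(20, 4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training pass: gradients enabled, parameters updated.
train_acc, train_loss = 0, 0.0
model.train()
for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    batch_loss = criterion(outputs, labels)
    batch_loss.backward()
    optimizer.step()
    train_acc += (outputs.argmax(dim=1) == labels).sum().item()
    train_loss += batch_loss.item()

# Validation pass: eval mode, no gradient tracking, no optimizer step.
val_acc, val_loss = 0, 0.0
model.eval()
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        val_loss += criterion(outputs, labels).item()
        val_acc += (outputs.argmax(dim=1) == labels).sum().item()

# Accuracy is averaged per sample, loss per batch.
print(f"train acc {train_acc / len(train_set):.3f} loss {train_loss / len(train_loader):.3f}")
print(f"val   acc {val_acc / len(val_set):.3f} loss {val_loss / len(val_loader):.3f}")
```

Dividing the accuracy count by the dataset size and the loss sum by the number of batches keeps both numbers comparable across epochs.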

3. Code structure for image-based deep learning

Reference: Hung-yi Lee, Machine Learning 2021, HW3 CNN food image classification
A handy module introduced here is torchvision:

import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder

3.1 Loading data

Add data augmentation transforms to the training set only;
the validation and test sets should not be augmented.

train_tfm = transforms.Compose([
    # Resize the image into a fixed shape (height = width = 128)
    transforms.Resize((128, 128)),
    # You may add some transforms here.
    # ToTensor() should be the last one of the transforms.
    transforms.ToTensor(),
])

test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

3.2 DataLoader

batch_size = 128
# Transforms (including augmentation) are applied as the data is loaded
# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder("food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

# Construct data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
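One caveat with `loader=lambda x: Image.open(x)`: when DataLoader worker processes are started with the spawn method (the default on Windows and macOS), the dataset is pickled and sent to each worker, and a lambda cannot be pickled. A module-level function avoids this. The sketch below uses a hypothetical helper named `load_image`, modeled on how torchvision's default loader opens files; `can_pickle` is only there to demonstrate the difference.

```python
import os
import pickle
import tempfile

from PIL import Image

def load_image(path):
    # Open via a file handle so it is closed promptly, and force 3 channels.
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")

def can_pickle(obj):
    # True if obj can be pickled (needed for spawn-based DataLoader workers).
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False

# Throwaway grayscale image to exercise the loader.
tmp_dir = tempfile.mkdtemp()
sample_path = os.path.join(tmp_dir, "sample.jpg")
Image.new("L", (32, 32), color=128).save(sample_path)

img = load_image(sample_path)
print(img.mode, img.size)

print("named function picklable:", can_pickle(load_image))
print("lambda picklable:", can_pickle(lambda x: Image.open(x)))
```

The function can then be passed in place of the lambda, for example `DatasetFolder("food-11/training/labeled", loader=load_image, extensions="jpg", transform=train_tfm)`.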

3.3 A second Dataset format for images

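When there are too many images to load into memory at once inside the Dataset, store only the file names and their corresponding labels, and read each image on demand. It is enough to implement the three methods __init__, __len__, and __getitem__. For reference, see how the speech homework does it:
Hung-yi Lee, Machine Learning 2021, HW4 Transformer speech classification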

3.4 Progress bars with tqdm

The simplest form wraps tqdm around the DataLoader:

from tqdm import tqdm

for batch in tqdm(train_loader):
    # A batch consists of image data and corresponding labels.
    imgs, labels = batch
    # Forward the data. (Make sure data and model are on the same device.)
    logits = model(imgs.to(device))

Second form, updating the bar manually:

total_steps = 10  # the bar's total should match the number of updates
pbar = tqdm(total=total_steps, ncols=0, desc="Train", unit=" step")
for step in range(total_steps):
    pbar.update()
    pbar.set_postfix(
        loss=f"{0.1:.2f}",
        accuracy=f"{0.2:.2f}",
        step=step + 1,
    )
pbar.close()
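In a real loop the postfix would show live metrics rather than constants. Here is a small sketch of the manual form that reports a running mean loss; the `batch_losses` list is placeholder data standing in for values a training step would produce. Using tqdm as a context manager closes the bar automatically.

```python
from tqdm import tqdm

# Placeholder per-batch losses; in training these would come from the model.
batch_losses = [0.9, 0.7, 0.6, 0.4, 0.3]

running_loss = 0.0
with tqdm(total=len(batch_losses), desc="Train", unit=" step") as pbar:
    for step, loss in enumerate(batch_losses, start=1):
        running_loss += loss
        pbar.update(1)
        # Show the mean loss over the steps completed so far.
        pbar.set_postfix(loss=f"{running_loss / step:.2f}", step=step)
    steps_done = pbar.n  # number of updates recorded by the bar

mean_loss = running_loss / len(batch_losses)
print(steps_done, round(mean_loss, 2))
```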