本文内容：添加 SGD 优化器
目录
1. 步骤一
新建或修改 train.py 文件，添加如下代码：
import argparse
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import wandb
from blocks.Sopia import SophiaG
from blocks.lion_pytorch.lion_pytorch import Lion
from evaluate import evaluate
from models import UNet
from models.unet_model import UNetAttention1, UNetAttention2
from utils.data_loading import BasicDataset
from utils.dice_score import dice_loss
def train_model(
model,
device,
dir_img: Pat