model.py(两个类;四个函数)
class DoubleConv ( nn. Module) :
def __init__ ( ) :
def forward ( ) :
class UNET ( nn. Module) :
def __init__ ( ) :
def forward ( ) :
train.py(两个函数)
def train_fn ( loader, model, optimizer, loss_fn, scaler) :
for batch_idx, ( data, targets) in enumerate ( loop) :
def main ( ) :
train_transform
val_transforms
model = UNET( in_channels= 3 , out_channels= 1 ) . to( DEVICE)
train_loader, val_loader = get_loaders( )
for epoch in range ( NUM_EPOCHS) :
train_fn( )
save_checkpoint( )
check_accuracy( val_loader, model, device= DEVICE )
dataset.py(一个类;三个函数)
class CarvanaDataset ( Dataset) :
def __init__ ( self, image_dir, mask_dir, transform = None ) :
def __len__ ( self) :
def __getitem__ ( self, index) :
utils.py(多个函数)
def save_checkpoint ( ) :
def load_checkpoint ( checkpoint, model) :
def get_loaders (
train_dir,
train_maskdir,
val_dir,
val_maskdir,
batch_size,
train_transform,
val_transform,
num_workers= 4 ,
pin_memory= True
) :
def check_accuracy ( loader, model, device= "cuda" ) :
def save_predictions_as_imgs ( ) :