Omnimatte 项目使用教程
Omnimatte 项目地址:https://gitcode.com/gh_mirrors/om/omnimatte
1. 项目的目录结构及介绍
omnimatte/
├── data/
│ ├── __init__.py
│ ├── aligned_dataset.py
│ ├── base_dataset.py
│ ├── image_folder.py
│ ├── single_dataset.py
│ └── unaligned_dataset.py
├── models/
│ ├── __init__.py
│ ├── base_model.py
│ ├── networks.py
│ ├── omnimatte_model.py
│ └── utils.py
├── options/
│ ├── __init__.py
│ ├── base_options.py
│ └── test_options.py
├── scripts/
│ ├── test.py
│ └── train.py
├── utils/
│ ├── __init__.py
│ ├── html.py
│ ├── image_pool.py
│ ├── util.py
│ └── visualizer.py
├── README.md
├── requirements.txt
└── setup.py
目录结构介绍
data/:包含数据集处理的相关文件。
models/:包含模型定义和网络结构的相关文件。
options/:包含命令行选项和配置的相关文件。
scripts/:包含训练和测试脚本。
utils/:包含各种实用工具和辅助函数。
README.md:项目说明文档。
requirements.txt:项目依赖包列表。
setup.py:项目安装脚本。
2. 项目的启动文件介绍
训练脚本
文件路径:scripts/train.py
# scripts/train.py
import os
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
from util.util import save_images
import time
if __name__ == '__main__':
opt = TrainOptions().parse()
dataset = create_dataset(opt)
dataset_size = len(dataset)
print(f'The number of training images = {dataset_size}')
model = create_model(opt)
model.setup(opt)
visualizer = Visualizer(opt)
total_iters = 0
for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
epoch_start_time = time.time()
iter_data_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
if total_iters % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
visualizer.reset()
total_iters += opt.batch_size
epoch_iter += opt.batch_size
model.set_input(data)
model.optimize_parameters()
if total_iters % opt.display_freq == 0:
save_result = total_iters % opt.update_html_freq == 0
model.compute_visuals()
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
if total_iters % opt.print_freq == 0:
losses = model.get_current_losses()
t_comp = (time.time() - iter_start_time) / opt.batch_size
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
if opt.display_id > 0:
visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
if total_iters % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
model.save_networks(save_suffix)
iter_data_time = time.time()
if epoch % opt.save_epoch