Reading notes on make train_demo and train.py

train.py uses dataset.py internally.

Run the make train_demo command:

train_demo:
	python bin/train.py demo_data/filelist.txt output/bayer \
	--pretrained pretrained_models/bayer \
	--val_data demo_data/filelist.txt --batch_size 1

This runs train.py with the training file list (demo_data/filelist.txt) and the output directory (output/bayer) as positional arguments, plus the --pretrained, --val_data and --batch_size options.
In this repo the pretrained weights are stored as .npy files, NumPy's binary array format.
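As a quick illustration (a hedged sketch only: the file name below is hypothetical, and the real mapping from pretrained files to network layers is done by demosaic.converter.Converter, which appears further down):

import numpy as np
import torch as th

# Hypothetical example of loading one .npy weight array; the actual layer
# names and file layout are handled by converter.Converter.
w = np.load("pretrained_models/bayer/example_weight.npy")  # hypothetical file name
t = th.from_numpy(w)  # convert the numpy array to a torch tensor
print(t.shape, t.dtype)
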
Now let's walk through train.py itself.
The imports:

import argparse

# logging module
import logging
import os

# used to change the process title
import setproctitle

# time-related utilities
import time

# torch-related modules
import numpy as np
import torch as th
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms

from torchlib.trainer import Trainer

# demosaic packages
import demosaic.dataset as dset
import demosaic.modules as modules
import demosaic.losses as losses
import demosaic.callbacks as callbacks
import demosaic.converter as converter
import torchlib.callbacks as default_callbacks
import torchlib.optim as toptim

Get a logger instance named "demosaick" via getLogger:
log = logging.getLogger("demosaick")

Next, the main function.

def main(args, model_params):
  # fix_seed defaults to False
  if args.fix_seed:
    # with a fixed seed, the same sequence of random numbers is generated on every run
    np.random.seed(0)
    # seed torch's CPU RNG so results are reproducible
    th.manual_seed(0)

  # ------------ Set up datasets ----------------------------------------------
  # Use demosaic.dataset's ToTensor transform to convert the samples produced by dataset.py into torch tensors.
  # The transform itself holds no data; it is applied to each sample (e.g. mosaic, mask, image) when the dataset is indexed.

  xforms = [dset.ToTensor()]
  # not used by default (green_only defaults to False)
  if args.green_only:
    xforms.append(dset.GreenOnly())
  xforms = transforms.Compose(xforms)
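
A side note (not part of train.py): a minimal sketch of what a ToTensor-style transform does, assuming each sample is a dict of numpy arrays; the real keys and behaviour are defined in dataset.py, not here.

import numpy as np
import torch as th
from torchvision import transforms

class ToTensorSketch(object):
  # Hypothetical stand-in for dset.ToTensor: convert every numpy array in the
  # sample dict (e.g. mosaic, mask, target image) to a torch tensor.
  def __call__(self, sample):
    return {k: th.from_numpy(np.ascontiguousarray(v)) if isinstance(v, np.ndarray) else v
            for k, v in sample.items()}

# transforms.Compose chains callables: each transform receives the previous
# transform's output, so the list order matters.
xform_sketch = transforms.Compose([ToTensorSketch()])

Back to train.py: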
  # not used by default (xtrans defaults to False)
  if args.xtrans:
    data = dset.XtransDataset(args.data_dir, transform=xforms, augment=True, linearize=args.linear)
  else:
    data = dset.BayerDataset(args.data_dir, transform=xforms, augment=True, linearize=args.linear)
  # touch the first sample once (this calls the dataset's __getitem__), as a quick sanity check that loading works
  data[0]

  if args.val_data is not None:
    if args.xtrans:
      val_data = dset.XtransDataset(args.val_data, transform=xforms, augment=False)
    else:
      val_data = dset.BayerDataset(args.val_data, transform=xforms, augment=False)
  else:
    val_data = None
  # ---------------------------------------------------------------------------

  model = modules.get(model_params)
  log.info("Model configuration: {}".format(model_params))

  if args.pretrained:
    log.info("Loading Caffe weights")
    if args.xtrans:
      model_ref = modules.get({"model": "XtransNetwork"})
      cvt = converter.Converter(args.pretrained, "XtransNetwork")
    else:
      model_ref = modules.get({"model": "BayerNetwork"})
      cvt = converter.Converter(args.pretrained, "BayerNetwork")
    cvt.convert(model_ref)
    model_ref.cuda()
  else:
    model_ref = None

  if args.green_only:
    model = modules.GreenOnly(model)
    model_ref = modules.GreenOnly(model_ref)

  if args.subsample:
    dx = 1
    dy = 0
    if args.xtrans:
      period = 6
    else:
      period = 2
    model = modules.Subsample(model, period, dx=dx, dy=dy)
    model_ref = modules.Subsample(model_ref, period, dx=dx, dy=dy)

  if args.linear:
    model = modules.DeLinearize(model)
    model_ref = modules.DeLinearize(model_ref)

  name = os.path.basename(args.output)
  cbacks = [
      default_callbacks.LossCallback(env=name),
      callbacks.DemosaicVizCallback(val_data, model, model_ref, cuda=True, 
                                    shuffle=False, env=name),
      callbacks.PSNRCallback(env=name),
      ]

  metrics = {
      "psnr": losses.PSNR(crop=4)
      }
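
Side note: PSNR is 10 * log10(MAX^2 / MSE), i.e. -10 * log10(MSE) for images in [0, 1]. A minimal sketch of what losses.PSNR(crop=4) presumably computes (the crop argument most likely discards a border so boundary artifacts are ignored; the repo's implementation may differ):

import torch as th

def psnr_sketch(pred, target, crop=4, eps=1e-12):
  # Hypothetical PSNR for images in [0, 1]; crop removes a pixel border
  # before computing the mean squared error.
  if crop > 0:
    pred = pred[..., crop:-crop, crop:-crop]
    target = target[..., crop:-crop, crop:-crop]
  mse = th.mean((pred - target) ** 2)
  return -10.0 * th.log10(mse + eps)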

  log.info("Using {} loss".format(args.loss))
  if args.loss == "l2":
    criteria = { "l2": losses.L2Loss(), }
  elif args.loss == "l1":
    criteria = { "l1": losses.L1Loss(), }
  elif args.loss == "gradient":
    criteria = { 
      "gradient": losses.GradientLoss(), 
    }
  elif args.loss == "laplacian":
    criteria = { 
      "laplacian": losses.LaplacianLoss(), 
    }
  elif args.loss == "vgg":
    criteria = { "vgg": losses.VGGLoss(), }
  else:
    raise ValueError("not implemented")
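
Side note: assuming losses.L2Loss is essentially a mean-squared-error criterion (the actual implementation may crop borders or weight terms differently), a minimal sketch:

import torch as th
import torch.nn as nn

class L2LossSketch(nn.Module):
  # Hypothetical equivalent of losses.L2Loss: MSE between output and target.
  def forward(self, output, target):
    return th.mean((output - target) ** 2)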

  optimizer = optim.Adam
  optimizer_params = {}
  if args.optimizer == "sgd":
    optimizer = optim.SGD
    optimizer_params = {"momentum": 0.9}
  train_params = Trainer.Parameters(
      viz_step=100, lr=args.lr, batch_size=args.batch_size,
      optimizer=optimizer, optimizer_params=optimizer_params)
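  # Note: "optimizer" here is the class itself (optim.Adam or optim.SGD), not an
  # instance. The Trainer presumably constructs the real optimizer later, roughly
  # like: opt = optimizer(model.parameters(), lr=args.lr, **optimizer_params)
  # (a sketch of the usual PyTorch pattern, not the repo's actual code).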

  trainer = Trainer(
      data, model, criteria, output=args.output, 
      params = train_params,
      model_params=model_params, verbose=args.debug, 
      callbacks=cbacks,
      metrics=metrics,
      valset=val_data, cuda=True)

  trainer.train()

Finally, the script entry point that runs main:

#when the file is executed as a script (not imported), this condition is True
if __name__ == "__main__":
  # create an ArgumentParser object
  parser = argparse.ArgumentParser()
  # I/O params
  # I/O arguments (positional and optional)
  parser.add_argument('data_dir')
  parser.add_argument('output')
  parser.add_argument('--val_data')
  parser.add_argument('--checkpoint')
  parser.add_argument('--pretrained')

  # Training
  # training hyperparameters
  parser.add_argument('--batch_size', type=int, default=16)
  parser.add_argument('--lr', type=float, default=1e-4)
  parser.add_argument('--fix_seed', dest="fix_seed", action="store_true")
  parser.add_argument('--loss', default="l2", choices=["l1", "vgg", "l2", "gradient", "laplacian"])
  parser.add_argument('--optimizer', default="adam", choices=["adam", "sgd"])

  # Monitoring
  parser.add_argument('--debug', dest="debug", action="store_true")

  # Model
  #dest is the attribute name on the parsed args (e.g. args.xtrans, which defaults to False)
  #action="store_true" means: if the flag is present on the command line, set that attribute to True
  parser.add_argument('--xtrans', dest="xtrans", action="store_true")
  parser.add_argument('--green_only', dest="green_only", action="store_true")
  parser.add_argument('--subsample', dest="subsample", action="store_true")
  parser.add_argument('--linear', dest="linear", action="store_true")
  parser.add_argument(
      '--params', nargs="*", default=["model=BayerNetwork"])

  # default values for the boolean flags
  parser.set_defaults(debug=False, fix_seed=False, xtrans=False,
                      green_only=False, subsample=False, linear=False)
  
  # parse the command line; every add_argument becomes an attribute of the args namespace
  args = parser.parse_args()

  # empty dict for the model parameters
  params = {}
  # --params has a default value, so this is normally not None
  if args.params is not None:
    for p in args.params:
      # str.split: break each "key=value" string at the '=' into k and v
      k, v = p.split("=")
      # if the value is all digits, convert it to an int
      if v.isdigit():
        v = int(v)
      # if the value is the string "True"/"False", convert it to a bool
      elif v == "False":
        v = False
      elif v == "True":
        v = True

      # store the converted value under its key in params
      params[k] = v
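  # Example (parameter names other than "model" are hypothetical):
  #   --params model=BayerNetwork depth=15 normalize=True
  # would be parsed into:
  #   params == {"model": "BayerNetwork", "depth": 15, "normalize": True}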
  # configure logging: show the process id, log level, filename, line number and message
  logging.basicConfig(
      format="[%(process)d] %(levelname)s %(filename)s:%(lineno)s | %(message)s")
  # levels: DEBUG < INFO < WARNING < ERROR < CRITICAL; at level INFO, DEBUG messages are not recorded
  if args.debug:
    log.setLevel(logging.DEBUG)
  else:
    log.setLevel(logging.INFO)
  # os.path.basename(args.output) returns the last path component, here "bayer" (inside the output folder)
  # str.format builds the title string, giving "demosaic_bayer"
  # setproctitle.setproctitle renames the process; without it the process would just show up as "python"
  setproctitle.setproctitle(
      'demosaic_{}'.format(os.path.basename(args.output)))
  # run the main function
  main(args, params)

Further reading:
The purpose of if __name__ == "__main__":
The difference between `if x` and `if x is not None` in Python 3
A small experiment with str.split
The official Python logging documentation
