DAN源码解读(龙明盛Xlearn)

代码来自https://github.com/thuml/Xlearn/tree/master/pytorch
摘要由CSDN通过智能技术生成

目录

 

main函数

transfer_classification()

prep.image_train()

prep.image_test_10crop()

prep.image_test()

image_classification_test()

补充

util_data.DataLoader()(源码分析)

nn.Linear()

nn.Sequential()


main函数

if __name__ == "__main__":
    # CLI for the Office-31 transfer-learning experiment (DAN/JAN).
    # All flags are optional; defaults reproduce the amazon -> webcam JAN run.
    parser = argparse.ArgumentParser(description='Transfer Learning')
    cli_flags = [
        ('--gpu_id', str, '0', "device id to run"),
        ('--source', str, 'amazon', "source data"),
        ('--target', str, 'webcam', "target data"),
        ('--loss_name', str, 'JAN', "loss name"),
        ('--tradeoff', float, 1.0, "tradeoff"),
        ('--using_bottleneck', int, 1, "whether to use bottleneck"),
    ]
    for flag, flag_type, flag_default, flag_help in cli_flags:
        parser.add_argument(flag, type=flag_type, nargs='?', default=flag_default, help=flag_help)
    args = parser.parse_args()

    # Restrict CUDA to the requested device before any GPU work happens.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Single nested-dict experiment config consumed by transfer_classification().
    config = {
        "num_iterations": 20000,
        "test_interval": 500,
        # Per-domain image preprocessing: resize to 256, crop to 224, 10-crop eval.
        "prep": [
            dict(name="source", type="image", test_10crop=True, resize_size=256, crop_size=224),
            dict(name="target", type="image", test_10crop=True, resize_size=256, crop_size=224),
        ],
        "loss": {"name": args.loss_name, "trade_off": args.tradeoff},
        # Image-list files follow the Office-31 "<domain>_list.txt" naming scheme.
        "data": [
            dict(name="source", type="image",
                 list_path={"train": "../data/office/" + args.source + "_list.txt"},
                 batch_size={"train": 36, "test": 4}),
            dict(name="target", type="image",
                 list_path={"train": "../data/office/" + args.target + "_list.txt"},
                 batch_size={"train": 36, "test": 4}),
        ],
        "network": {"name": "ResNet50", "use_bottleneck": args.using_bottleneck, "bottleneck_dim": 256},
        # SGD with an inverse-decay learning-rate schedule; base lr=1.0 is
        # scaled by the schedule's init_lr inside the training loop.
        "optimizer": {
            "type": "SGD",
            "optim_params": {"lr": 1.0, "momentum": 0.9, "weight_decay": 0.0005, "nesterov": True},
            "lr_type": "inv",
            "lr_param": {"init_lr": 0.0003, "gamma": 0.0003, "power": 0.75},
        },
    }
    print(config["loss"])
    transfer_classification(config)

transfer_classification()

def transfer_classification(config):
    ## set pre-process
    prep_dict = {}
    for prep_config in config["prep"]:
        prep_dict[prep_config["name"]] = {}   # pre_dict["source"] = {}
        if prep_config["type"] == "image":   # True
            # prep_dict["source"]["test_10crop"] = True
            prep_dict[prep_config["name"]]["test_10crop"] = prep_config["test_10crop"]
            # prep_dict["source"]["train"] = prep.image_train(resize_size = 256, crop_size = 224)
            # 把test的参数设置好:图片大小/随机切割/随机反转/正则化/totensor
            prep_dict[prep_config["name"]]["train"]  = prep.image_train(resize_size=prep_config["resize_size"], crop_size=prep_config["crop_size"])

            if prep_config["test_10crop"]:      # True
                # 设置test_10crop的参数:图片大小/按规则切割/随机翻转/正则化/totensor
                prep_dict[prep_config["name"]]["test"] = prep.image_test_10crop(resize_size=prep_config["resize_size"], crop_size=prep_config["crop_size"])
            else:
                prep_dict[prep_config["name"]]["test"] = prep.image_test(resize_size=prep_config["resize_size"], crop_size=prep_config["crop_size"])
               
    ## set loss
    class_criterion = nn.CrossEntropyLoss()  # loss函数
    loss_config = config["loss"]          # loss_config = {"name":'JAN', "trade_off":1.0 }
    transfer_criterion = loss.loss_dict[loss_config["name"]]  # transfer_criterion = JAN  # loss的内置函数
    if "params" not in loss_config:
        loss_config["params"] = {}

    ## prepare data
    dsets = {}
    dset_loaders = {}
    for data_config in config["data"]:
        # name设为key:source/target
        dsets[data_config["name"]] = {}
        dset_loaders[data_config["name"]] = {}
        ## image data
        if data_config["type"] == "image":
            # file.readlines(): 返回文件内的所有行,得到图片的路径
            # transform = 上面设置好的参数
            # 最终得到[image, label]的一个list
            dsets[data_config["name"]]["train"] = ImageList(open(data_config["list_path"]["train"]).readlines(), transform=prep_dict[data_config["name"]]["train"])
            # 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
            # dataset (Dataset) – 加载数据的数据集。
            # batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
            # shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
            # num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
            dset_loaders[data_config["name"]]["train"] = util_data.DataLoade
  • 7
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值