# Table of contents
# main function
if __name__ == "__main__":
    # Parse command-line options for a transfer-learning run on the Office dataset.
    parser = argparse.ArgumentParser(description='Transfer Learning')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--source', type=str, nargs='?', default='amazon', help="source data")
    parser.add_argument('--target', type=str, nargs='?', default='webcam', help="target data")
    parser.add_argument('--loss_name', type=str, nargs='?', default='JAN', help="loss name")
    parser.add_argument('--tradeoff', type=float, nargs='?', default=1.0, help="tradeoff")
    parser.add_argument('--using_bottleneck', type=int, nargs='?', default=1, help="whether to use bottleneck")
    args = parser.parse_args()

    # Restrict CUDA to the requested device; must happen before any CUDA context is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Build the experiment configuration consumed by transfer_classification().
    config = {}
    config["num_iterations"] = 20000   # total training iterations
    config["test_interval"] = 500      # evaluate every N iterations
    # Pre-processing settings for both domains (10-crop evaluation, 256->224 crops).
    config["prep"] = [{"name":"source", "type":"image", "test_10crop":True, "resize_size":256, "crop_size":224},
                      {"name":"target", "type":"image", "test_10crop":True, "resize_size":256, "crop_size":224}]
    config["loss"] = {"name":args.loss_name, "trade_off":args.tradeoff}
    # Image-list files follow the "<domain>_list.txt" naming convention under ../data/office/.
    config["data"] = [{"name":"source", "type":"image", "list_path":{"train":"../data/office/"+args.source+"_list.txt"}, "batch_size":{"train":36, "test":4} },
                      {"name":"target", "type":"image", "list_path":{"train":"../data/office/"+args.target+"_list.txt"}, "batch_size":{"train":36, "test":4} }]
    config["network"] = {"name":"ResNet50", "use_bottleneck":args.using_bottleneck, "bottleneck_dim":256}
    config["optimizer"] = {"type":"SGD", "optim_params":{"lr":1.0, "momentum":0.9, "weight_decay":0.0005, "nesterov":True},
                           "lr_type":"inv", "lr_param":{"init_lr":0.0003, "gamma":0.0003, "power":0.75} }
    print(config["loss"])
    # BUG FIX: the original source called transfer_classification() a second time with
    # no arguments, which raises TypeError (the function requires `config`). Removed.
    transfer_classification(config)
def transfer_classification(config):
## set pre-process
prep_dict = {}
for prep_config in config["prep"]:
prep_dict[prep_config["name"]] = {} # pre_dict["source"] = {}
if prep_config["type"] == "image": # True
# prep_dict["source"]["test_10crop"] = True
prep_dict[prep_config["name"]]["test_10crop"] = prep_config["test_10crop"]
# prep_dict["source"]["train"] = prep.image_train(resize_size = 256, crop_size = 224)
# 把test的参数设置好:图片大小/随机切割/随机反转/正则化/totensor
prep_dict[prep_config["name"]]["train"] = prep.image_train(resize_size=prep_config["resize_size"], crop_size=prep_config["crop_size"])
if prep_config["test_10crop"]: # True
# 设置test_10crop的参数:图片大小/按规则切割/随机翻转/正则化/totensor
prep_dict[prep_config["name"]]["test"] = prep.image_test_10crop(resize_size=prep_config["resize_size"], crop_size=prep_config["crop_size"])
else:
prep_dict[prep_config["name"]]["test"] = prep.image_test(resize_size=prep_config["resize_size"], crop_size=prep_config["crop_size"])
## set loss
class_criterion = nn.CrossEntropyLoss() # loss函数
loss_config = config["loss"] # loss_config = {"name":'JAN', "trade_off":1.0 }
transfer_criterion = loss.loss_dict[loss_config["name"]] # transfer_criterion = JAN # loss的内置函数
if "params" not in loss_config:
loss_config["params"] = {}
## prepare data
dsets = {}
dset_loaders = {}
for data_config in config["data"]:
# name设为key:source/target
dsets[data_config["name"]] = {}
dset_loaders[data_config["name"]] = {}
## image data
if data_config["type"] == "image":
# file.readlines(): 返回文件内的所有行,得到图片的路径
# transform = 上面设置好的参数
# 最终得到[image, label]的一个list
dsets[data_config["name"]]["train"] = ImageList(open(data_config["list_path"]["train"]).readlines(), transform=prep_dict[data_config["name"]]["train"])
# 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
# dataset (Dataset) – 加载数据的数据集。
# batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
# shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
# num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
dset_loaders[data_config["name"]]["train"] = util_data.DataLoade