2021SC@SDUSC
最后对 train.py 中的 main 函数进行分析:
首先读取数据集,并分别为有标签训练集、无标签训练集、验证集和测试集建立数据加载器(DataLoader):
def main():
global best_acc
train_labeled_set, train_unlabeled_set, val_set, test_set, n_labels = get_data(
args.data_path, args.n_labeled, args.un_labeled, model=args.model, train_aug=args.train_aug)
labeled_trainloader = Data.DataLoader(
dataset=train_labeled_set, batch_size=args.batch_size, shuffle=True)
unlabeled_trainloader = Data.DataLoader(
dataset=train_unlabeled_set, batch_size=args.batch_size_u, shuffle=True)
val_loader = Data.DataLoader(
dataset=val_set, batch_size=512, shuffle=False)
test_loader = Data.DataLoader(
dataset=test_set, batch_size=512, shuffle=False)
接着定义模型(实例化 MixText 并移至 GPU,再用 nn.DataParallel 包装),并定义 AdamW 优化器:
model = MixText(n_labels, args.mix_option).cuda()
model = nn.DataParallel(model)
optimizer = AdamW(
[