The State Of Knowledge Distillation For Classification Tasks: Code Notes

import argparse
from pathlib import Path

from distillers import *  # provides the run_*_distillation helpers used below
from data_loader import get_cifar
from models.model_factory import create_model
from trainer import BaseTrainer, KDTrainer, MultiTrainer, TripletTrainer
from plot import plot_results
import util

BATCH_SIZE = 128
TESTFOLDER = "results"
USE_ID = True


def parse_arguments():  # command-line configuration; example invocation: python3 evaluate_kd.py --epochs 200 --teacher resnet18 --student resnet8 --dataset cifar10 --teacher-checkpoint pretrained/resnet18_cifar10.pth --mode nokd kd
    parser = argparse.ArgumentParser(
        description="Test parameters")
    parser.add_argument("--epochs", default=100, type=int,
                        help="number of total epochs to run")
    parser.add_argument("--dataset", default="cifar100", type=str,
                        help="dataset. can be either cifar10 or cifar100")
    parser.add_argument("--batch-size", default=BATCH_SIZE,
                        type=int, help="batch_size")
    parser.add_argument("--learning-rate", default=0.1,
                        type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9,
                        type=float, help="SGD momentum")
    parser.add_argument("--weight-decay", default=5e-4,
                        type=float, help="SGD weight decay (default: 5e-4)")
    parser.add_argument("--teacher", default="WRN22_4", type=str,
                        dest="t_name", help="teacher student name")
    parser.add_argument("--student", "--model", default="resnet18",
                        dest="s_name", type=str, help="teacher student name")
    parser.add_argument("--optimizer", default="sgd",
                        dest="optimizer", type=str,
                        help="Which optimizer to use")
    parser.add_argument("--scheduler", default="multisteplr",
                        dest="scheduler", type=str,
                        help="Which scheduler to use")
    parser.add_argument("--teacher-checkpoint", default="",
                        dest="t_checkpoint", type=str,
                        help="optional pretrained checkpoint for teacher")
    parser.add_argument("--mode", default=["KD"], dest="modes",
                        type=str, nargs='+',
                        help="What type of distillation to use")
    parser.add_argument("--results-dir", default=TESTFOLDER,
                        dest="results_dir", type=str,
                        help="Where all results are collected")
    args = parser.parse_args()
    return args


def setup_teacher(t_name, params):
    # Teacher Model
    num_classes = params["num_classes"]
    t_net = create_model(t_name, num_classes, params["device"])
    teacher_config = params.copy()
    teacher_config["test_name"] = t_name + "_teacher"

    if params["t_checkpoint"]:
        # Just validate the performance
        print("---------- Loading Teacher -------")
        best_teacher = params["t_checkpoint"]
    else:
        # Teacher training
        print("---------- Training Teacher -------")
        teacher_trainer = BaseTrainer(t_net, config=teacher_config)
        teacher_trainer.train()
        best_teacher = teacher_trainer.best_model_file

    # reload and get the best model
    t_net = util.load_checkpoint(t_net, best_teacher)
    teacher_trainer = BaseTrainer(t_net, config=teacher_config)
    best_t_acc = teacher_trainer.validate()

    # also save the teacher's accuracy in a csv file for plotting; since
    # the teacher is not retrained, write it as a flat line over all epochs
    name = teacher_config["test_name"] + "_val"
    acc_file_name = params["results_dir"].joinpath(f"{name}.csv")
    with acc_file_name.open("w+") as acc_file:
        acc_file.write("Training Loss,Validation Loss\n")
        for _ in range(params["epochs"]):
            acc_file.write(f"0.0,{best_t_acc}\n")
    return t_net, best_teacher, best_t_acc


def setup_student(s_name, params):
    # Student Model
    num_classes = params["num_classes"]
    s_net = create_model(s_name, num_classes, params["device"])
    return s_net


def freeze_teacher(t_net):
    # freeze the layers of the teacher so no gradients flow into it
    for param in t_net.parameters():
        param.requires_grad = False
    # eval mode also fixes BatchNorm statistics and disables dropout
    t_net.eval()
    return t_net


def test_nokd(s_net, t_net, params):
    # t_net is unused; the argument only keeps the
    # test_<mode> dispatch signature uniform
    print("---------- Training NOKD -------")
    nokd_config = params.copy()
    nokd_trainer = BaseTrainer(s_net, config=nokd_config)
    best_acc = nokd_trainer.train()
    return best_acc


def test_kd(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    print("---------- Training KD -------")
    kd_config = params.copy()
    kd_trainer = KDTrainer(s_net, t_net=t_net, config=kd_config)
    best_acc = kd_trainer.train()
    return best_acc
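
KDTrainer is defined in trainer.py and is not reproduced in these notes. For reference, the classic Hinton-style objective it combines, weighted by lambda_student (alpha below) and softened by T_student (T below, both set in start_evaluation), can be sketched as follows; the function name kd_loss and its exact signature are my own illustration, not the repo's API:

import torch.nn.functional as F

def kd_loss(s_logits, t_logits, targets, alpha=0.5, T=5.0):
    # hard loss: ordinary cross-entropy against the ground-truth labels
    hard = F.cross_entropy(s_logits, targets)
    # soft loss: KL divergence between temperature-softened distributions;
    # the T*T factor keeps the soft-loss gradient scale comparable across T
    soft = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                    F.softmax(t_logits / T, dim=1),
                    reduction="batchmean") * T * T
    return (1 - alpha) * hard + alpha * soft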


def test_triplet(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    print("---------- Training TRIPLET -------")
    kd_config = params.copy()
    kd_trainer = TripletTrainer(s_net, t_net=t_net, config=kd_config)
    best_acc = kd_trainer.train()
    return best_acc


def test_multikd(s_net, t_net1, params):
    t_net1 = freeze_teacher(t_net1)
    print("---------- Training MULTIKD -------")
    kd_config = params.copy()
    # NOTE: the extra teachers use hard-coded CIFAR-10 checkpoints
    kd_config["t2_name"] = "WRN22_4"
    t_net2 = create_model(
        kd_config["t2_name"], params["num_classes"], params["device"])
    t_net2 = util.load_checkpoint(
        t_net2, "pretrained/WRN22_4_cifar10.pth")
    t_net2 = freeze_teacher(t_net2)

    kd_config["t3_name"] = "resnet18"
    t_net3 = create_model(
        kd_config["t3_name"], params["num_classes"], params["device"])
    t_net3 = util.load_checkpoint(
        t_net3, "pretrained/resnet18_cifar10.pth")
    t_net3 = freeze_teacher(t_net3)

    # distill from all three frozen teachers at once
    t_nets = [t_net1, t_net2, t_net3]
    kd_trainer = MultiTrainer(s_net, t_nets=t_nets, config=kd_config)
    best_acc = kd_trainer.train()
    return best_acc


def test_takd(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    num_classes = params["num_classes"]
    # TAKD distills the teacher into a mid-sized teacher assistant
    # (resnet20 here) and then the assistant into the student
    params["ta_name"] = "resnet20"
    ta_model = create_model(
        params["ta_name"], num_classes, params["device"])
    best_acc = run_takd_distillation(s_net, ta_model, t_net, **params)
    return best_acc


def test_uda(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    best_acc = run_uda_distillation(s_net, t_net, **params)
    return best_acc


def test_ab(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    best_acc = run_ab_distillation(s_net, t_net, **params)
    return best_acc


def test_rkd(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    best_acc = run_rkd_distillation(s_net, t_net, **params)
    return best_acc


def test_pkd(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    best_acc = run_pkd_distillation(s_net, t_net, **params)
    return best_acc


def test_oh(s_net, t_net, params):
    # do not freeze the teacher in oh distillation
    best_acc = run_oh_distillation(s_net, t_net, **params)
    return best_acc


def test_fd(s_net, t_net, params):
    t_net = freeze_teacher(t_net)
    best_acc = run_fd_distillation(s_net, t_net, **params)
    return best_acc


def test_allkd(s_name, params):
    teachers = ["resnet8", "resnet14", "resnet20", "resnet26",
                "resnet32", "resnet44", "resnet56",
                # "resnet34", "resnet50", "resnet101", "resnet152",
                ]
    accs = {}
    for t_name in teachers:
        params_t = params.copy()
        params_t["teacher_name"] = t_name
        t_net, best_teacher, best_t_acc = setup_teacher(t_name, params_t)
        t_net = util.load_checkpoint(t_net, best_teacher, params_t["device"])
        t_net = freeze_teacher(t_net)
        s_net = setup_student(s_name, params_t)
        params_t["test_name"] = f"{s_name}_{t_name}"
        params_t["results_dir"] = params_t["results_dir"].joinpath("allkd")
        util.check_dir(params_t["results_dir"])
        best_acc = test_kd(s_net, t_net, params_t)
        accs[t_name] = (best_t_acc, best_acc)

    best_acc = 0
    best_t_acc = 0
    for t_name, acc in accs.items():
        if acc[0] > best_t_acc:
            best_t_acc = acc[0]
        if acc[1] > best_acc:
            best_acc = acc[1]
        print(f"Best results teacher {t_name}: {acc[0]}")
        print(f"Best results for {s_name}: {acc[1]}")

    return best_t_acc, best_acc


def test_kdparam(s_net, t_net, params):
    temps = [1, 5, 10, 15, 20]
    alphas = [0.1, 0.4, 0.5, 0.7, 1.0]
    param_pairs = [(a, T) for T in temps for a in alphas]
    accs = {}

    for alpha, T in param_pairs:
        params_s = params.copy()
        params_s["lambda_student"] = alpha
        params_s["T_student"] = T
        s_name = params_s["student_name"]
        s_net = setup_student(s_name, params_s)
        params_s["test_name"] = f"{s_name}_{T}_{alpha}"
        print(f"Testing {s_name} with alpha {alpha} and T {T}.")
        best_acc = test_kd(s_net, t_net, params_s)
        accs[params_s["test_name"]] = (alpha, T, best_acc)

    best_kdparam_acc = 0
    for test_name, acc in accs.items():
        alpha = acc[0]
        T = acc[1]
        kd_acc = acc[2]
        if acc[2] > best_kdparam_acc:
            best_kdparam_acc = acc[2]
        print(f"Best results for {s_name} with a {alpha} and T {T}: {kd_acc}")

    return best_kdparam_acc


def run_benchmarks(modes, params, s_name, t_name):
    results = {}

    # if we test allkd we do not need to train an individual teacher
    if "allkd" in modes:
        best_t_acc, results["allkd"] = test_allkd(s_name, params)
        modes.remove("allkd")
    if modes:
        # the remaining modes share a single teacher
        t_net, best_teacher, best_t_acc = setup_teacher(t_name, params)

    for mode in modes:
        mode = mode.lower()
        params_s = params.copy()
        # reset the teacher
        t_net = util.load_checkpoint(t_net, best_teacher, params["device"])

        # load the student and create a results directory for the mode
        s_net = setup_student(s_name, params)
        params_s["test_name"] = s_name
        params_s["results_dir"] = params_s["results_dir"].joinpath(mode)
        util.check_dir(params_s["results_dir"])
        # start the test; each mode name maps to a test_<mode> function
        try:
            run_test = globals()[f"test_{mode}"]
        except KeyError:
            raise RuntimeError(f"Training mode {mode} not supported!")
        results[mode] = run_test(s_net, t_net, params_s)

    # Dump the overall results
    print(f"Best results teacher {t_name}: {best_t_acc}")
    for name, acc in results.items():
        print(f"Best results for {s_name} with {name} method: {acc}")


def start_evaluation(args):
    device = util.setup_torch()  # select the GPU or CPU device
    num_classes = 100 if args.dataset == "cifar100" else 10
    train_loader, test_loader = get_cifar(num_classes,
                                          batch_size=args.batch_size)

    # for benchmarking, decide whether we want to use unique test folders
    if USE_ID:
        test_id = util.generate_id()
    else:
        test_id = ""
    results_dir = Path(args.results_dir).joinpath(test_id, args.dataset)
    util.check_dir(results_dir)

    # Parsing arguments and prepare settings for training
    params = {
        "epochs": args.epochs,
        "modes": args.modes,
        "t_checkpoint": args.t_checkpoint,
        "results_dir": results_dir,
        "train_loader": train_loader,
        "test_loader": test_loader,
        "batch_size": args.batch_size,
        # model configuration
        "device": device,
        "teacher_name": args.t_name,
        "student_name": args.s_name,
        "num_classes": num_classes,
        # hyperparameters
        "weight_decay": args.weight_decay,
        "learning_rate": args.learning_rate,
        "momentum": args.momentum,
        "sched": args.scheduler,
        "optim": args.optimizer,
        # fixed knowledge distillation parameters
        "lambda_student": 0.5,
        "T_student": 5,
    }
    test_conf_name = results_dir.joinpath("test_config.json")
    util.dump_json_config(test_conf_name, params)
    run_benchmarks(args.modes, params, args.s_name, args.t_name)
    plot_results(results_dir, test_id=test_id)


if __name__ == "__main__":
    ARGS = parse_arguments()
    start_evaluation(ARGS)

The helper below (from util.py) generates the name of the results folder:

import random
import string


def generate_id():
    # four random alphanumeric characters used as a unique run id
    return "".join(random.choice(string.ascii_letters + string.digits)
                   for _ in range(4))
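
With the imports in place, generate_id() returns a four-character alphanumeric tag (something like "aB3x"), which start_evaluation joins into the results path as test_id so that repeated benchmark runs do not overwrite each other's folders.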
