2021-11-27

 2021SC@SDUSC

本次分析 prepare/generate_prepare 模块。

首先是初始化部分:从命令行参数 sys.argv[2]、sys.argv[3]、sys.argv[4] 分别读取训练(train)、开发(dev)和测试(test)三种数据划分各自的 batch 数量。

# Number of batches in each data split, read from the command line:
# argv[2] = train, argv[3] = dev, argv[4] = test. A missing or non-integer
# argument raises (IndexError / ValueError) when the module is loaded.
# NOTE(review): combine() uses TRAIN_NUM_BATCHES, but generate_bash()
# hard-codes 50 training batches and never reads it.
TRAIN_NUM_BATCHES = int(sys.argv[2])

DEV_NUM_BATCHES = int(sys.argv[3])

TEST_NUM_BATCHES = int(sys.argv[4])

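函数 generate_bash 生成若干 shell 脚本:训练集共 5 个脚本(cmd_extract_train1.sh ~ cmd_extract_train5.sh,每个脚本在后台并行处理 10 个 batch,共 50 个),开发集和测试集各 1 个脚本(cmd_extract_dev.sh、cmd_extract_test.sh)。每个脚本以后台方式(&)启动 extract_property.py,最后用 wait 等待全部任务结束。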

def generate_bash(*, concept_seed=None, dataset="./amr_data/amr_2.0/csqa",
                  train_batches=50, train_batches_per_script=10,
                  dev_batches=None, test_batches=None, out_dir="."):
    """Write shell scripts that run extract_property.py over every AMR batch.

    With the default arguments this writes the following files into the
    current directory:

    * cmd_extract_train1.sh ... cmd_extract_train5.sh, covering training
      batches 1-10, 11-20, ..., 41-50 with two processors per job;
    * cmd_extract_dev.sh, covering every dev batch with one processor per job;
    * cmd_extract_test.sh, covering every test batch with one processor per job.

    Each job is started in the background with ``&``. Every script ends with
    ``wait``, so a script returns only once all of its jobs have finished.

    Args:
        concept_seed: value passed as ``--concept_seed`` (e.g. question_amr
            or question_token). Defaults to ``sys.argv[5]``.
        dataset: directory containing the ``*.pred*.txt`` AMR files.
        train_batches: number of training batches. Defaults to 50, the
            value the script originally hard-coded; TRAIN_NUM_BATCHES is
            deliberately not read here.
        train_batches_per_script: number of training batches launched by
            each training script. Must be positive.
        dev_batches: number of dev batches. Defaults to DEV_NUM_BATCHES.
        test_batches: number of test batches. Defaults to TEST_NUM_BATCHES.
        out_dir: directory the ``.sh`` files are written to.
    """
    import os
    import shlex

    if concept_seed is None:
        concept_seed = sys.argv[5]
    if dev_batches is None:
        dev_batches = DEV_NUM_BATCHES
    if test_batches is None:
        test_batches = TEST_NUM_BATCHES

    # Every value interpolated into a shell command is quoted. A seed or path
    # containing spaces or shell metacharacters therefore cannot break (or
    # inject into) the generated scripts. Plain values such as "question_amr"
    # or the default dataset path are emitted unchanged.
    train_data = shlex.quote("%s/train.pred.txt" % dataset)
    seed = shlex.quote(str(concept_seed))

    def write_script(name, split, batch_ids, nprocessors):
        # One background job per batch, then a final ``wait``.
        commands = [
            "python3 extract_property.py --train_data %s --amr_files %s "
            "--nprocessors %d --concept_seed %s &\n"
            % (train_data,
               shlex.quote("%s/%s.pred_%d.txt" % (dataset, split, i)),
               nprocessors, seed)
            for i in batch_ids
        ]
        with open(os.path.join(out_dir, name), "w") as f:
            f.writelines(commands)
            f.write("wait\n")

    # Split the training batches into consecutive chunks, one script per chunk.
    starts = range(1, train_batches + 1, train_batches_per_script)
    for script_no, start in enumerate(starts, 1):
        stop = min(start + train_batches_per_script, train_batches + 1)
        write_script("cmd_extract_train%d.sh" % script_no, "train",
                     range(start, stop), 2)

    write_script("cmd_extract_dev.sh", "dev", range(1, dev_batches + 1), 1)
    write_script("cmd_extract_test.sh", "test", range(1, test_batches + 1), 1)

函数 copy_files 将源文件的内容复制到目标路径,并对常见错误(源与目标为同一文件、目标是目录、权限不足以及其他错误)进行捕获,打印相应的提示信息而不中断程序。

def copy_files(source, destination):
    """Copy the contents of *source* to *destination*, reporting the outcome.

    The copy is best-effort: failures are reported on stdout instead of
    being raised, so a failed copy never aborts the caller.

    Args:
        source: path of the file to copy.
        destination: path of the file to create or overwrite.

    Returns:
        True if the copy succeeded, False otherwise. Callers that ignore the
        return value behave exactly as before.
    """
    try:
        shutil.copyfile(source, destination)
    except shutil.SameFileError:
        # source and destination refer to the same file
        print("Source and destination represents the same file.")
    except IsADirectoryError:
        # NOTE(review): also raised when *source* itself is a directory
        print("Destination is a directory.")
    except PermissionError:
        print("Permission denied.")
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate; print the cause so the failure is diagnosable.
        print("Error occurred while copying file.", err)
    else:
        print("File copied successfully.")
        return True
    return False

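函数 combine 将各个 batch 的 JSON 结果文件(如 dev.pred_1_cn_extended_final.json)逐条读取,合并为一个 JSON 数组文件(如 dev_pred_cn_extended_real_final.json),并在去掉末尾多余逗号之前保留一份副本(*_final_original.json)。该函数根据命令行参数 mode 分为训练、开发、测试三种情况分别处理。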

def combine(mode=None, path='/mnt/cn_data/amr_2.0/csqa/'):
    """Merge the per-batch ConceptNet-extended JSON files of one data split.

    For split S with N batches, this reads ``S.pred_1_cn_extended_final.json``
    through ``S.pred_N_cn_extended_final.json`` under *path*. It writes every
    item into one JSON array file ``S_pred_cn_extended_real_final.json``.

    Args:
        mode: 'dev' or 'test' select those splits (with DEV_NUM_BATCHES /
            TEST_NUM_BATCHES batches). Any other value selects 'train'
            (TRAIN_NUM_BATCHES). Defaults to ``sys.argv[5]``.
        path: directory prefix (with trailing slash) of all input and output
            files.
    """
    if mode is None:
        mode = sys.argv[5]

    if mode == 'dev':
        split, num_batches = 'dev', DEV_NUM_BATCHES
    elif mode == 'test':
        split, num_batches = 'test', TEST_NUM_BATCHES
    else:
        # Any unrecognised mode falls back to the training split.
        split, num_batches = 'train', TRAIN_NUM_BATCHES

    _merge_batches(path, split, num_batches)


def _merge_batches(path, split, num_batches):
    """Concatenate the items of every batch file of *split* into one JSON array.

    The output looks like ``[item ,item ,item ]``, or ``[]`` when the batches
    contain no items. Before the trailing separator is removed, a copy of the
    unfinished file is saved as ``..._final_original.json`` (via copy_files).
    """
    import decimal
    import os

    def to_json(obj):
        # ijson yields decimal.Decimal for non-integer numbers by default,
        # which json.dump cannot serialise on its own.
        if isinstance(obj, decimal.Decimal):
            return float(obj)
        raise TypeError("Object of type %s is not JSON serializable"
                        % type(obj).__name__)

    target = path + "%s_pred_cn_extended_real_final.json" % split
    wrote_any = False
    with open(target, 'w') as out:
        out.write('[')
        for batch in range(1, num_batches + 1):
            print('i th batch', batch)
            batch_file = path + "%s.pred_%d_cn_extended_final.json" % (split, batch)
            # try_parse is defined elsewhere; presumably it validates or
            # repairs the batch file before it is streamed (TODO confirm).
            try_parse(batch_file)
            print('done_parsing')
            with open(batch_file, 'rb') as fp:
                # Stream the top-level array items instead of loading the
                # whole batch file into memory.
                for item in ijson.items(fp, 'item'):
                    json.dump(item, out, default=to_json)
                    out.write(' ,')
                    wrote_any = True

    with open(target, 'rb+') as fh:
        # Keep a copy of the merged file before the trailing separator is
        # removed, as the original pipeline did.
        backup = target[:-len('final.json')] + 'final_original.json'
        copy_files(target, backup)
        if wrote_any:
            # Drop the ',' of the last ' ,' separator; the leftover space
            # before ']' is valid JSON whitespace.
            fh.seek(-1, os.SEEK_END)
            fh.truncate()
        else:
            # Only '[' was written: removing a byte here would delete the '['
            # itself and leave the invalid document "]".
            fh.seek(0, os.SEEK_END)
        fh.write(b']')

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值