【DKN】(七)dataset.py【未完】

内容

代码中用到的部分函数的说明见这里:https://blog.csdn.net/qq_35222729/article/details/119882362


# Resolve the configuration class named "<model_name>Config" from the `config`
# module and bind it to `config`; abort the program when the module has no
# attribute with that name.
try:
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    print(f"{model_name} not included!")
    # `exit()` is a helper injected by the `site` module for interactive use
    # and may be absent (e.g. `python -S`); raising SystemExit always works,
    # and the non-zero status lets the calling shell detect the failure.
    raise SystemExit(1)


class BaseDataset(Dataset):
    """Dataset over parsed user-behavior rows backed by a parsed news table.

    Each item is a dict built from one row of the behaviors table: the click
    labels, the candidate news, and the user's clicked-news history, left-padded
    to ``config.num_clicked_news_a_user`` entries.  Every news entry is a dict
    mapping an attribute name to a tensor.

    Which attributes are loaded is driven by ``config.dataset_attributes``
    (keys ``'news'`` and ``'record'``).
    """

    def __init__(self, behaviors_path, news_path, roberta_embedding_dir):
        """Load both tables and pre-convert every news attribute to a tensor.

        Args:
            behaviors_path: path of the behaviors table (read with
                ``pd.read_table``).
            news_path: path of the news table; must contain an ``id`` column
                plus every attribute listed in
                ``config.dataset_attributes['news']``.
            roberta_embedding_dir: not used by the code visible here.

        Raises:
            AssertionError: if the configured news/record attributes contain
                a name outside the supported sets.
        """
        # NOTE(review): `roberta_embedding_dir` is accepted but never read in
        # this block, and `self.roberta_embedding` (used by `_news2dict`) is
        # not assigned here — presumably that loading code is missing from
        # this excerpt; confirm against the full file.
        super().__init__()

        supported_news_attributes = (
            'category', 'subcategory', 'title', 'abstract', 'title_entities',
            'abstract_entities', 'title_roberta', 'title_mask_roberta',
            'abstract_roberta', 'abstract_mask_roberta',
        )
        # The configured news attributes must all be supported ones.
        news_attributes = config.dataset_attributes['news']
        assert all(name in supported_news_attributes
                   for name in news_attributes)

        # Same check for the per-record attributes.
        record_attributes = config.dataset_attributes['record']
        assert all(name in ('user', 'clicked_news_length')
                   for name in record_attributes)

        # Load the parsed behaviors table.
        self.behaviors_parsed = pd.read_table(behaviors_path)

        # These columns are stored as Python literals in text form; run
        # literal_eval on them so they are restored to their original types.
        literal_columns = (
            'title', 'abstract', 'title_entities', 'abstract_entities',
            'title_roberta', 'title_mask_roberta', 'abstract_roberta',
            'abstract_mask_roberta',
        )
        column_converters = {
            name: literal_eval
            for name in news_attributes
            if name in literal_columns
        }
        self.news_parsed = pd.read_table(
            news_path,
            index_col='id',
            usecols=['id'] + news_attributes,
            converters=column_converters,
        )

        # News id -> positional row number in the news table.
        self.news_id2int = {
            news_id: position
            for position, news_id in enumerate(self.news_parsed.index)
        }

        # News id -> {attribute: tensor}.
        self.news2dict = {
            news_id: {
                name: torch.tensor(value)
                for name, value in attributes.items()
            }
            for news_id, attributes in
            self.news_parsed.to_dict('index').items()
        }

        # All-zero stand-in for a missing news entry, restricted to the
        # configured attributes.
        title_zeros = [0] * config.num_words_title
        abstract_zeros = [0] * config.num_words_abstract
        padding_template = {
            'category': 0,
            'subcategory': 0,
            'title': title_zeros,
            'abstract': abstract_zeros,
            'title_entities': title_zeros,
            'abstract_entities': abstract_zeros,
            'title_roberta': title_zeros,
            'title_mask_roberta': title_zeros,
            'abstract_roberta': abstract_zeros,
            'abstract_mask_roberta': abstract_zeros,
        }
        self.padding = {
            name: torch.tensor(value)
            for name, value in padding_template.items()
            if name in news_attributes
        }

    def _news2dict(self, id):
        """Return the attribute dict of the news entry with the given id.

        When ``model_name`` is ``'Exp2'`` and ``config.fine_tune`` is falsy,
        the configured ``title``/``abstract`` entries are overwritten with
        the row of ``self.roberta_embedding`` for this news entry.  The
        stored dict itself is updated and returned (no copy is made).
        """
        news = self.news2dict[id]
        if model_name != 'Exp2' or config.fine_tune:
            return news

        text_fields = {'title', 'abstract'}.intersection(
            config.dataset_attributes['news'])
        for field in text_fields:
            news[field] = self.roberta_embedding[field][self.news_id2int[id]]
        return news

    def __len__(self):
        """Return the number of rows in the behaviors table."""
        behaviors = self.behaviors_parsed
        return len(behaviors)

    def __getitem__(self, idx):
        """Build the sample for row ``idx`` of the behaviors table.

        Returns:
            A dict with ``clicked`` (list of ints), ``candidate_news`` and
            ``clicked_news`` (lists of news dicts) and, when configured,
            ``user`` and ``clicked_news_length`` (history length before
            padding).
        """
        row = self.behaviors_parsed.iloc[idx]
        record_attributes = config.dataset_attributes['record']

        item = {}
        if 'user' in record_attributes:
            item['user'] = row.user
        item["clicked"] = [int(label) for label in row.clicked.split()]
        item["candidate_news"] = [
            self._news2dict(news_id)
            for news_id in row.candidate_news.split()
        ]

        # Keep at most `max_history` clicked news, then left-pad with the
        # all-zero entry so every sample has exactly `max_history` of them.
        max_history = config.num_clicked_news_a_user
        history = [
            self._news2dict(news_id)
            for news_id in row.clicked_news.split()[:max_history]
        ]
        repeated_times = max_history - len(history)
        assert repeated_times >= 0
        item["clicked_news"] = [self.padding] * repeated_times + history

        if 'clicked_news_length' in record_attributes:
            item['clicked_news_length'] = len(history)

        return item

补充

1. ast.literal_eval

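Python 中,如果要将字符串形式的 list、tuple、dict 还原成原有的类型,该怎么办?这个时候你自然会想到 eval。eval 函数在 Python 中做数据类型的转换很有用:它的作用就是把字符串还原成它本身所表示的、或者能够转化成的数据类型。不过 eval 会执行任意表达式,用在不可信的输入上并不安全;ast.literal_eval 只解析字面量(字符串、数字、元组、列表、字典、集合、布尔值和 None),所以上面的代码用它来还原这些列。下面先用 eval 演示转换效果: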
string <=> list

In [1]: s = '[1, 2, 3, 4]'

In [2]: l = eval(s)

In [3]: s
Out[3]: '[1, 2, 3, 4]'

In [4]: l
Out[4]: [1, 2, 3, 4]

In [5]: type(s)
Out[5]: str

In [6]: type(l)
Out[6]: list
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值