FATE —— 数据集方法源码

代码主要包括 base.py、table.py、nlp_tokenizer.py 和 image.py 四个文件。

base.py
from torch.utils.data import Dataset as Dataset_
from federatedml.nn.backend.utils.common import ML_PATH
import importlib
import abc
import numpy as np


class Dataset(Dataset_):
    """Base class for FATE NN datasets, built on ``torch.utils.data.Dataset``.

    Adds a dataset type tag (``'local'`` by default), optional sample-id
    support and a train/eval switch on top of the torch dataset protocol.
    Subclasses implement ``load``, ``__getitem__`` and ``__len__``;
    ``get_classes`` and ``get_sample_ids`` raise NotImplementedError unless
    overridden.
    """

    def __init__(self, **kwargs):
        # **kwargs is accepted but not used by this base class.
        super(Dataset, self).__init__()
        self._type = 'local'  # train/predict
        # Becomes True once get_sample_ids() output has passed validation.
        self._check = False
        self._generated_ids = None
        self.training = True

    @property
    def dataset_type(self):
        """The dataset type tag; AttributeError if Dataset.__init__ was skipped."""
        if not hasattr(self, '_type'):
            raise AttributeError(
                'type variable not exists, call __init__ of super class')
        return self._type

    @dataset_type.setter
    def dataset_type(self, val):
        self._type = val

    def has_dataset_type(self):
        """Return the dataset type tag (truthy when a non-empty tag is set)."""
        return self.dataset_type

    def set_type(self, _type):
        self.dataset_type = _type

    def get_type(self):
        return self.dataset_type

    def has_sample_ids(self):
        """Return True if this dataset provides usable sample ids.

        Returns False when ``get_sample_ids`` is not implemented or returns
        None. The first non-None result is validated: it must be a list of
        str/int ids with exactly ``len(self)`` entries. Later calls skip the
        validation.

        Raises:
            AssertionError: if the ids are not a list, or their count differs
                from the dataset length.
            RuntimeError: if an id is neither a str nor an int.
        """
        try:
            sample_ids = self.get_sample_ids()
        except NotImplementedError:
            # get_sample_ids() was not implemented by the subclass
            return False

        if sample_ids is None:
            return False

        if not self._check:
            # Explicit raises instead of ``assert`` so this validation still
            # runs under ``python -O``; AssertionError is kept for callers.
            if not isinstance(sample_ids, list):
                raise AssertionError(
                    'get_sample_ids() must return a list contains str or integer')
            for id_ in sample_ids:
                if not isinstance(id_, (str, int)):
                    raise RuntimeError(
                        'get_sample_ids() must return a list contains str or integer: got id of type {}:{}'.format(
                            id_, type(id_)))
            if len(sample_ids) != len(self):
                raise AssertionError(
                    'sample id len:{} != dataset length:{}'.format(len(sample_ids), len(self)))
            self._check = True
        return True

    def init_sid_and_getfunc(self, prefix: str = None):
        """Generate ids ``'<prefix>_<i>'`` for every sample and install them.

        ``prefix`` defaults to the current dataset type. The generated ids are
        stored in ``_generated_ids`` and returned by a ``get_sample_ids``
        function bound on this instance, which shadows the class method.

        Raises:
            AssertionError: if ``prefix`` is given but is not a str.
        """
        if prefix is not None:
            if not isinstance(prefix, str):
                raise AssertionError(
                    'prefix must be a str, but got {}'.format(prefix))
        else:
            prefix = self._type
        self._generated_ids = [prefix + '_' + str(i) for i in range(len(self))]

        def get_func():
            return self._generated_ids
        self.get_sample_ids = get_func

    """
    Functions for users
    """

    def train(self, ):
        """Switch to training mode (sets ``training`` to True)."""
        self.training = True

    def eval(self, ):
        """Switch to evaluation mode (sets ``training`` to False)."""
        self.training = False

    # Functions to be implemented by subclasses

    # NOTE: no ABCMeta metaclass is declared on this class, so this decorator
    # does not block instantiation; load() simply raises when not overridden.
    @abc.abstractmethod
    def load(self, file_path):
        raise NotImplementedError(
            'You must implement load function so that Client can pass file-path to this '
            'class')

    def __getitem__(self, item):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def get_classes(self):
        raise NotImplementedError()

    def get_sample_ids(self):
        raise NotImplementedError()


class ShuffleWrapDataset(Dataset_):
    """Wrap a Dataset and expose its samples in a seeded, shuffled order.

    The wrapped dataset's sample ids are argsorted; that index order is
    shuffled with ``shuffle_seed`` (left unshuffled when the seed is None).
    ``idx_map`` maps each sorted index to its shuffled counterpart, and item
    ``i`` of the wrapper is ``ds[idx_map[idx[i]]]``.
    """

    def __init__(self, dataset: Dataset, shuffle_seed=100):
        super(ShuffleWrapDataset, self).__init__()
        # Validate the argument before calling any Dataset method on it.
        assert isinstance(dataset, Dataset)
        self.ds = dataset
        ids = self.ds.get_sample_ids()
        self.idx = np.argsort(np.array(ids))
        self.shuffled_idx = np.copy(self.idx)
        if shuffle_seed is not None:
            # A dedicated RandomState seeded with the same value reproduces the
            # permutation of np.random.seed(seed) + np.random.shuffle(...)
            # without resetting numpy's process-global RNG as a side effect.
            np.random.RandomState(shuffle_seed).shuffle(self.shuffled_idx)
        self.idx_map = {k: v for k, v in zip(self.idx, self.shuffled_idx)}

    def train(self, ):
        """Forward train mode to the wrapped dataset."""
        self.ds.train()

    def eval(self, ):
        """Forward eval mode to the wrapped dataset."""
        self.ds.eval()

    def __getitem__(self, item):
        return self.ds[self.idx_map[self.idx[item]]]

    def __len__(self):
        return len(self.ds)

    def __repr__(self):
        return self.ds.__repr__()

    def has_sample_ids(self):
        return self.ds.has_sample_ids()

    def set_shuffled_idx(self, idx_map: dict):
        """Replace the sorted->shuffled index mapping (e.g. one received from a peer)."""
        self.shuffled_idx = np.array(list(idx_map.values()))
        self.idx_map = idx_map

    def get_sample_ids(self):
        """Return the wrapped dataset's sample ids in shuffled order."""
        ids = self.ds.get_sample_ids()
        return np.array(ids)[self.shuffled_idx].tolist()

    def get_classes(self):
        return self.ds.get_classes()


def get_dataset_class(dataset_module_name: str):
    """Import ``<ML_PATH>.dataset.<dataset_module_name>`` and return its Dataset subclass.

    A trailing ``.py`` on ``dataset_module_name`` is ignored. Subclasses of
    Dataset defined in the module itself take precedence over subclasses the
    module merely imports (e.g. a base like TableDataset imported in order to
    subclass it); otherwise the first subclass found in the module namespace
    is returned.

    Raises:
        ValueError: if the module contains no subclass of Dataset.
    """
    if dataset_module_name.endswith('.py'):
        # Strip only the suffix; str.replace would also drop '.py' elsewhere.
        dataset_module_name = dataset_module_name[:-len('.py')]
    ds_modules = importlib.import_module(
        '{}.dataset.{}'.format(
            ML_PATH, dataset_module_name))

    candidates = [
        v for v in vars(ds_modules).values()
        if isinstance(v, type) and issubclass(v, Dataset) and v is not Dataset
    ]
    for cls in candidates:
        if cls.__module__ == ds_modules.__name__:
            return cls
    if candidates:
        return candidates[0]
    raise ValueError('Did not find any class in {}.py that is the subclass of Dataset class'.
                     format(dataset_module_name))
table.py
import numpy as np
import pandas as pd
from federatedml.statistic.data_overview import with_weight
from federatedml.nn.dataset.base import Dataset
from federatedml.util import LOGGER


class TableDataset(Dataset):

    """
     A Table Dataset, load data from a given csv path or pandas DataFrame, or transform a FATE DTable
     Parameters
     ----------
     label_col: str or int, name of label column in csv, if None, will automatically take 'y' or 'label' or 'target' as label
     feature_dtype: dtype of feature, supports int, long, float, double
     label_dtype: dtype of label, supports int, long, float, double
     label_shape: list or tuple, the shape of label
     flatten_label: bool, flatten extracted label column or not, default is False
     """

    def __init__(
            self,
            label_col=None,
            feature_dtype='float',
            label_dtype='float',
            label_shape=None,
            flatten_label=False):

        super(TableDataset, self).__init__()
        self.with_label = True
        self.with_sample_weight = False
        self.features: np.ndarray = None
        self.label: np.ndarray = None
        self.sample_weights: np.ndarray = None
        self.origin_table: pd.DataFrame = pd.DataFrame()
        self.label_col = label_col
        self.f_dtype = self.check_dtype(feature_dtype)
        self.l_dtype = self.check_dtype(label_dtype)
        if label_shape is not None:
            assert isinstance(label_shape, tuple) or isinstance(
                label_shape, list), 'label shape is {}'.format(label_shape)
        self.label_shape = label_shape
        self.flatten_label = flatten_label

        # ids, match ids is for FATE match id system
        self.sample_ids = None
        self.match_ids = None

        if self.label_col is not None:
            assert isinstance(self.label_col, str) or isinstance(
                self.label_col, int), 'label columns parameter must be a str or an int'

    @staticmethod
    def check_dtype(dtype):
        """Map a dtype name to its numpy type; None is returned unchanged.

        'long' -> np.int64, 'int' -> np.int32, 'float' -> np.float32,
        'double' -> np.float64. Any other value fails an assertion.
        """
        if dtype is None:
            return dtype
        dtype_map = {
            'long': np.int64,
            'int': np.int32,
            'float': np.float32,
            'double': np.float64,
        }
        avail = list(dtype_map)
        assert dtype in avail, 'available dtype is {}, but got {}'.format(
            avail, dtype)
        return dtype_map[dtype]

    def __getitem__(self, item):
        """Return features, or (features, label), or (features, (label, weight)).

        The sample weight is only included in training mode when the loaded
        DTable carried weights.
        """
        if self.with_label:
            if self.with_sample_weight and self.training:
                return self.features[item], (self.label[item], self.sample_weights[item])
            else:
                return self.features[item], self.label[item]
        else:
            return self.features[item]

    def __len__(self):
        return len(self.origin_table)

    def load(self, file_path):
        """Load a csv path, a pandas DataFrame or a FATE DTable.

        An 'id' (or else 'sid') column is moved into ``sample_ids``. The label
        column is ``label_col`` or, when that is None, the first of
        'y'/'label'/'target' present; if none is found the dataset is marked
        as label-free. Features and labels are cast to the configured dtypes.

        Raises:
            ValueError: if an explicit ``label_col`` is not in the table.
        """
        if isinstance(file_path, str):
            self.origin_table = pd.read_csv(file_path)
        elif isinstance(file_path, pd.DataFrame):
            self.origin_table = file_path
        else:
            # anything else is treated as a FATE DTable
            self._load_fate_table(file_path)

        self._pop_id_column()
        label = self._resolve_label_col()

        if self.with_label:
            self.label = self.origin_table[label].values
            self.features = self.origin_table.drop(columns=[label]).values

            if self.l_dtype:
                self.label = self.label.astype(self.l_dtype)

            if self.label_shape:
                self.label = self.label.reshape(self.label_shape)
            else:
                # one row per sample
                self.label = self.label.reshape((len(self.features), -1))

            if self.flatten_label:
                self.label = self.label.flatten()

        else:
            self.label = None
            self.features = self.origin_table.values

        if self.f_dtype:
            self.features = self.features.astype(self.f_dtype)

    def _load_fate_table(self, data_inst):
        """Collect a FATE DTable into ``origin_table``, rows ordered by sorted key.

        Also fills ``with_sample_weight``, ``sample_weights`` (1 when the table
        carries no weights) and ``match_ids`` (key -> inst_id).
        """
        self.with_sample_weight = with_weight(data_inst)
        LOGGER.info('collecting FATE DTable, with sample weight is {}'.format(self.with_sample_weight))
        header = data_inst.schema["header"]
        LOGGER.debug('input dtable header is {}'.format(header))
        data = list(data_inst.collect())

        # Sort the keys once; row i holds the instance with the i-th smallest key.
        sorted_keys = sorted(key for key, _ in data)
        key_to_row = {key: row for row, key in enumerate(sorted_keys)}

        n_rows = len(sorted_keys)
        x_ = [None] * n_rows
        y_ = [None] * n_rows
        sample_weights = [1] * n_rows
        match_ids = {}

        for key, inst in data:
            row = key_to_row[key]
            x_[row] = inst.features
            y_[row] = inst.label
            match_ids[key] = inst.inst_id
            if self.with_sample_weight:
                sample_weights[row] = inst.weight

        df = pd.DataFrame(np.asarray(x_))
        df.columns = header
        df['id'] = sorted_keys
        df['label'] = np.asarray(y_)
        # host data has no label, so this columns will all be None
        if df['label'].isna().all():
            df = df.drop(columns=['label'])

        self.origin_table = df
        self.sample_weights = np.array(sample_weights)
        self.match_ids = match_ids

    def _pop_id_column(self):
        """Move the first of the 'id'/'sid' columns found into ``sample_ids``."""
        for id_col in ('id', 'sid'):
            if id_col in self.origin_table:
                self.sample_ids = self.origin_table[id_col].values.tolist()
                self.origin_table = self.origin_table.drop(columns=[id_col])
                break

    def _resolve_label_col(self):
        """Return the label column name, or None (clearing ``with_label``) if absent."""
        label = self.label_col
        if label is None:
            for candidate in ('y', 'label', 'target'):
                if candidate in self.origin_table:
                    return candidate
            self.with_label = False
            LOGGER.warning(
                'label default setting is "auto", but found no "y"/"label"/"target" in input table')
            return None
        if label not in self.origin_table:
            raise ValueError(
                'label column {} not found in input table'.format(label))
        return label

    def get_classes(self):
        """Return the sorted unique label values; ValueError if there is no label."""
        if self.label is not None:
            return np.unique(self.label).tolist()
        else:
            raise ValueError(
                'no label found, please check if self.label is set')

    def get_sample_ids(self):
        return self.sample_ids

    def get_match_ids(self):
        """Return the key -> inst_id dict built from a DTable (None for csv/DataFrame input)."""
        return self.match_ids
nlp_tokenizer.py
from federatedml.nn.dataset.base import Dataset
import pandas as pd
import torch as t
from transformers import BertTokenizerFast
import os
import numpy as np

# Disable the tokenizer library's internal parallelism by setting the
# TOKENIZERS_PARALLELISM environment variable at import time. This changes the
# environment of the whole process. Presumably this avoids the HuggingFace
# tokenizers fork warning/deadlock when DataLoader workers fork; confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class TokenizerDataset(Dataset):
    """
    A Dataset for some basic NLP Tasks, this dataset will automatically transform raw text into word indices
    using BertTokenizer from transformers library,
    see https://huggingface.co/docs/transformers/model_doc/bert?highlight=berttokenizer for details of BertTokenizer
    Parameters
    ----------
    truncation bool, truncate word sequence to 'text_max_length'
    text_max_length int, max length of word sequences
    tokenizer_name_or_path str, name of bert tokenizer(see transformers official for details) or path to local
                                transformer tokenizer folder
    return_label bool, return label or not, this option is for host dataset, when running hetero-NN

    The csv given to ``load`` needs a 'text' column, a 'label' column when
    return_label is True, and may have an 'id' column used as sample ids.
    """

    def __init__(self, truncation=True, text_max_length=128,
                 tokenizer_name_or_path="bert-base-uncased",
                 return_label=True):

        super(TokenizerDataset, self).__init__()
        self.text = None
        self.word_idx = None
        self.label = None
        self.tokenizer = None
        self.sample_ids = None
        self.truncation = truncation
        self.max_length = text_max_length
        self.with_label = return_label
        self.tokenizer_name_or_path = tokenizer_name_or_path

    def load(self, file_path):
        """Read the csv at ``file_path`` and tokenize its 'text' column.

        All texts are padded to a common length (and truncated to
        ``max_length`` when ``truncation`` is set); ``word_idx`` holds the
        resulting ``input_ids`` tensor. Labels, when enabled, are stored as a
        float32 numpy array with one row per sample.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained(
            self.tokenizer_name_or_path)
        self.text = pd.read_csv(file_path)
        self.word_idx = self.tokenizer(
            list(self.text['text']),
            padding=True,
            return_tensors='pt',
            truncation=self.truncation,
            max_length=self.max_length)['input_ids']
        if self.with_label:
            # Convert directly with numpy instead of a torch.Tensor round-trip.
            self.label = np.asarray(self.text['label'], dtype=np.float32).reshape(
                (len(self.word_idx), -1))

        if 'id' in self.text:
            self.sample_ids = self.text['id'].values.tolist()

    def get_classes(self):
        """Return the sorted unique label values.

        Raises:
            ValueError: if no labels were loaded (return_label=False or load()
                not called), consistent with TableDataset.get_classes.
        """
        if self.label is None:
            raise ValueError(
                'no label found, please check if self.label is set')
        return np.unique(self.label).tolist()

    def get_vocab_size(self):
        """Vocabulary size of the loaded tokenizer (requires load() first)."""
        return self.tokenizer.vocab_size

    def get_sample_ids(self):
        return self.sample_ids

    def __getitem__(self, item):
        if self.with_label:
            return self.word_idx[item], self.label[item]
        else:
            return self.word_idx[item]

    def __len__(self):
        return len(self.word_idx)

    def __repr__(self):
        return self.tokenizer.__repr__()
image.py
import torch
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np


class ImageDataset(Dataset):

    """
    A basic Image Dataset built on pytorch ImageFolder, supports simple image transform
    Given a folder path, ImageDataset will load images from this folder, images in this
    folder need to be organized in a Torch-ImageFolder format, see
    https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html for details.
    Image name will be automatically taken as the sample id.
    Parameters
    ----------
    center_crop : bool, use center crop transformer
    center_crop_shape: tuple or list, required when center_crop is True
    generate_id_from_file_name: bool, whether to take image name as sample id
    file_suffix: str, default is '.jpg', if generate_id_from_file_name is True, this suffix is removed from the
                 end of the file name, result will be the sample id
    return_label: bool, return label or not, this option is for host dataset, when running hetero-NN
    float64: bool, returned image tensors will be transformed to double precision
    label_dtype: str, long, float, or double, the dtype of return label
    """

    def __init__(self, center_crop=False, center_crop_shape=None,
                 generate_id_from_file_name=True, file_suffix='.jpg',
                 return_label=True, float64=False, label_dtype='long'):

        super(ImageDataset, self).__init__()
        self.image_folder: ImageFolder = None
        self.center_crop = center_crop
        self.size = center_crop_shape
        self.return_label = return_label
        self.generate_id_from_file_name = generate_id_from_file_name
        self.file_suffix = file_suffix
        self.float64 = float64
        self.dtype = torch.float64 if self.float64 else torch.float32
        self.sample_ids = None

        label_dtype_map = {
            'float': torch.float32,
            'long': torch.int64,
            'double': torch.float64,
        }
        avail_label_type = ['float', 'long', 'double']
        assert label_dtype in avail_label_type, 'available label dtype : {}'.format(
            avail_label_type)
        self.label_dtype = label_dtype_map[label_dtype]

    def load(self, folder_path):
        """Build the ImageFolder over ``folder_path`` (str or os.PathLike).

        When ``generate_id_from_file_name`` is set, each image's sample id is
        its file name with ``file_suffix`` removed from the end.
        """
        import os

        # read image from folders
        if self.center_crop:
            transformer = transforms.Compose(
                [transforms.CenterCrop(size=self.size), transforms.ToTensor()])
        else:
            transformer = transforms.Compose([transforms.ToTensor()])

        folder_path = os.fspath(folder_path)
        if folder_path.endswith('/'):
            folder_path = folder_path[: -1]
        self.image_folder = ImageFolder(root=folder_path, transform=transformer)

        if self.generate_id_from_file_name:
            # use image name as its sample id; imgs holds (path, class_index)
            ids = []
            for path, _ in self.image_folder.imgs:
                # basename handles the platform path separator (not just '/')
                name = os.path.basename(path)
                # strip the suffix only at the end; str.replace would also
                # remove it from the middle of a name
                if self.file_suffix and name.endswith(self.file_suffix):
                    name = name[:-len(self.file_suffix)]
                ids.append(name)
            self.sample_ids = ids

    def __getitem__(self, item):
        """Return the image tensor (in ``self.dtype``), plus the label tensor if enabled."""
        image, target = self.image_folder[item]
        image = image.type(self.dtype)
        if not self.return_label:
            return image
        return image, torch.tensor(target).type(self.label_dtype)

    def __len__(self):
        return len(self.image_folder)

    def __repr__(self):
        return self.image_folder.__repr__()

    def get_classes(self):
        """Return the sorted unique class indices present in the folder."""
        return np.unique(self.image_folder.targets).tolist()

    def get_sample_ids(self):
        return self.sample_ids


if __name__ == '__main__':
    # Nothing is run when this file is executed directly.
    pass
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值