# This file combines the code of base.py, table.py, nlp_tokenizer.py and image.py
# base.py
from torch.utils.data import Dataset as Dataset_
from federatedml.nn.backend.utils.common import ML_PATH
import importlib
import abc
import numpy as np
class Dataset(Dataset_):
    """
    Base class for FATE datasets, built on torch's Dataset.

    Subclasses must implement load(), __getitem__ and __len__;
    get_classes() and get_sample_ids() are optional extension points.
    """

    def __init__(self, **kwargs):
        super(Dataset, self).__init__()
        # dataset type tag, e.g. 'local' (default), 'train', 'predict'
        self._type = 'local'
        # becomes True once ids returned by get_sample_ids() were validated
        self._check = False
        # ids produced by init_sid_and_getfunc(), if it was called
        self._generated_ids = None
        # train/eval switch, toggled by train()/eval()
        self.training = True

    @property
    def dataset_type(self):
        if not hasattr(self, '_type'):
            raise AttributeError(
                'type variable not exists, call __init__ of super class')
        return self._type

    @dataset_type.setter
    def dataset_type(self, val):
        self._type = val

    def has_dataset_type(self):
        # NOTE: returns the type string itself (truthy), not a bool
        return self.dataset_type

    def set_type(self, _type):
        self.dataset_type = _type

    def get_type(self):
        return self.dataset_type

    def has_sample_ids(self):
        """
        Return True if get_sample_ids() yields a valid id list, False if it
        is not implemented or returns None.

        The id list is validated only once (must be a list of str/int whose
        length equals len(self)); the validation result is cached in _check.
        """
        # if get_sample_ids is not implemented, report no ids
        try:
            sample_ids = self.get_sample_ids()
        except NotImplementedError:
            return False
        if sample_ids is None:
            return False
        if not self._check:
            assert isinstance(
                sample_ids, list), 'get_sample_ids() must return a list contains str or integer'
            for id_ in sample_ids:
                if (not isinstance(id_, str)) and (not isinstance(id_, int)):
                    raise RuntimeError(
                        'get_sample_ids() must return a list contains str or integer: got id of type {}:{}'.format(
                            id_, type(id_)))
            assert len(sample_ids) == len(
                self), 'sample id len:{} != dataset length:{}'.format(len(sample_ids), len(self))
            self._check = True
        return True

    def init_sid_and_getfunc(self, prefix: str = None):
        """
        Generate sample ids of the form '<prefix>_<index>' and install a
        get_sample_ids() on this instance that returns them.

        Parameters
        ----------
        prefix : str or None, id prefix; defaults to the dataset type tag
        """
        if prefix is not None:
            assert isinstance(
                prefix, str), 'prefix must be a str, but got {}'.format(prefix)
        else:
            prefix = self._type
        # one id per sample, e.g. 'train_0', 'train_1', ...
        self._generated_ids = [prefix + '_' + str(i) for i in range(len(self))]

        def get_func():
            return self._generated_ids
        # shadow get_sample_ids on the instance with the generated-id getter
        self.get_sample_ids = get_func

    """
    Functions for users
    """

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    # Functions to be implemented by subclasses
    @abc.abstractmethod
    def load(self, file_path):
        raise NotImplementedError(
            'You must implement load function so that Client can pass file-path to this '
            'class')

    def __getitem__(self, item):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def get_classes(self):
        raise NotImplementedError()

    def get_sample_ids(self):
        raise NotImplementedError()
class ShuffleWrapDataset(Dataset_):
    """
    Wrap a Dataset and access it in a deterministic shuffled order.

    Sample ids are sorted, then shuffled with a fixed seed; idx_map maps
    each sorted position to its shuffled position.
    """

    def __init__(self, dataset: Dataset, shuffle_seed=100):
        """
        Parameters
        ----------
        dataset : Dataset, the wrapped dataset; must provide get_sample_ids()
        shuffle_seed : int or None, seed of the deterministic shuffle;
            None disables shuffling (identity order)
        """
        super(ShuffleWrapDataset, self).__init__()
        self.ds = dataset
        ids = self.ds.get_sample_ids()
        sort_idx = np.argsort(np.array(ids))
        assert isinstance(dataset, Dataset)
        self.idx = sort_idx
        if shuffle_seed is not None:
            # use a local RandomState so the global numpy RNG state is not
            # mutated; RandomState(seed).shuffle matches the legacy
            # np.random.seed(seed) + np.random.shuffle behavior exactly
            rng = np.random.RandomState(shuffle_seed)
            self.shuffled_idx = np.copy(self.idx)
            rng.shuffle(self.shuffled_idx)
        else:
            self.shuffled_idx = np.copy(self.idx)
        # sorted position -> shuffled position
        self.idx_map = {k: v for k, v in zip(self.idx, self.shuffled_idx)}

    def train(self):
        self.ds.train()

    def eval(self):
        self.ds.eval()

    def __getitem__(self, item):
        # position -> sorted idx -> shuffled idx, then read the wrapped ds
        return self.ds[self.idx_map[self.idx[item]]]

    def __len__(self):
        return len(self.ds)

    def __repr__(self):
        return self.ds.__repr__()

    def has_sample_ids(self):
        return self.ds.has_sample_ids()

    def set_shuffled_idx(self, idx_map: dict):
        """Install an externally supplied mapping (sorted idx -> shuffled idx)."""
        self.shuffled_idx = np.array(list(idx_map.values()))
        self.idx_map = idx_map

    def get_sample_ids(self):
        # ids of the wrapped dataset, reordered by the shuffle
        ids = self.ds.get_sample_ids()
        return np.array(ids)[self.shuffled_idx].tolist()

    def get_classes(self):
        return self.ds.get_classes()
def get_dataset_class(dataset_module_name: str):
    """
    Import the federatedml dataset module '<ML_PATH>.dataset.<name>' and
    return the first Dataset subclass defined in it.

    Parameters
    ----------
    dataset_module_name : str, module name; a trailing '.py' is tolerated

    Raises
    ------
    ValueError if the module contains no Dataset subclass
    """
    # strip only a trailing '.py'; str.replace() would also remove the
    # substring '.py' occurring in the middle of the module name
    if dataset_module_name.endswith('.py'):
        dataset_module_name = dataset_module_name[:-3]
    ds_modules = importlib.import_module(
        '{}.dataset.{}'.format(
            ML_PATH, dataset_module_name))
    for k, v in ds_modules.__dict__.items():
        if isinstance(v, type) and issubclass(v, Dataset) and v is not Dataset:
            return v
    raise ValueError('Did not find any class in {}.py that is the subclass of Dataset class'.
                     format(dataset_module_name))
# table.py
import numpy as np
import pandas as pd
from federatedml.statistic.data_overview import with_weight
from federatedml.nn.dataset.base import Dataset
from federatedml.util import LOGGER
class TableDataset(Dataset):
    """
    A Table Dataset, load data from a give csv path, or transform FATE DTable

    Parameters
    ----------
    label_col str, name of label column in csv, if None, will automatically take 'y' or 'label' or 'target' as label
    feature_dtype dtype of feature, supports int, long, float, double
    label_dtype: dtype of label, supports int, long, float, double
    label_shape: list or tuple, the shape of label
    flatten_label: bool, flatten extracted label column or not, default is False
    """

    def __init__(
            self,
            label_col=None,
            feature_dtype='float',
            label_dtype='float',
            label_shape=None,
            flatten_label=False):
        super(TableDataset, self).__init__()
        # with_label is flipped to False in load() when no label column is found
        self.with_label = True
        # set in load() only when the input is a FATE DTable carrying weights
        self.with_sample_weight = False
        self.features: np.ndarray = None
        self.label: np.ndarray = None
        self.sample_weights: np.ndarray = None
        self.origin_table: pd.DataFrame = pd.DataFrame()
        self.label_col = label_col
        self.f_dtype = self.check_dtype(feature_dtype)
        self.l_dtype = self.check_dtype(label_dtype)
        if label_shape is not None:
            assert isinstance(label_shape, tuple) or isinstance(
                label_shape, list), 'label shape is {}'.format(label_shape)
        self.label_shape = label_shape
        self.flatten_label = flatten_label

        # ids, match ids is for FATE match id system
        self.sample_ids = None
        self.match_ids = None

        if self.label_col is not None:
            assert isinstance(self.label_col, str) or isinstance(
                self.label_col, int), 'label columns parameter must be a str or an int'

    @staticmethod
    def check_dtype(dtype):
        """Map a dtype name ('long'/'int'/'float'/'double') to the numpy
        dtype; None is passed through unchanged."""
        if dtype is not None:
            avail = ['long', 'int', 'float', 'double']
            assert dtype in avail, 'available dtype is {}, but got {}'.format(
                avail, dtype)
            if dtype == 'long':
                return np.int64
            if dtype == 'int':
                return np.int32
            if dtype == 'float':
                return np.float32
            if dtype == 'double':
                return np.float64
        return dtype

    def __getitem__(self, item):
        # while training with sample weights, the label is paired with its
        # weight so downstream losses can be weighted
        if self.with_label:
            if self.with_sample_weight and self.training:
                return self.features[item], (self.label[item], self.sample_weights[item])
            else:
                return self.features[item], self.label[item]
        else:
            return self.features[item]

    def __len__(self):
        return len(self.origin_table)

    def load(self, file_path):
        """Load data from a csv path, a pandas DataFrame, or a FATE DTable,
        then extract id column, label column and feature matrix."""
        if isinstance(file_path, str):
            self.origin_table = pd.read_csv(file_path)
        elif isinstance(file_path, pd.DataFrame):
            self.origin_table = file_path
        else:
            # if is FATE DTable, collect data and transform to array format
            data_inst = file_path
            self.with_sample_weight = with_weight(data_inst)
            LOGGER.info('collecting FATE DTable, with sample weight is {}'.format(self.with_sample_weight))
            header = data_inst.schema["header"]
            LOGGER.debug('input dtable header is {}'.format(header))
            data = list(data_inst.collect())
            data_keys = [key for (key, val) in data]
            # sample key -> its slot in the arrays below (slots are ordered
            # by sorted sample key, so rows end up in sorted-key order)
            data_keys_map = dict(zip(sorted(data_keys), range(len(data_keys))))

            keys = [None for idx in range(len(data_keys))]
            x_ = [None for idx in range(len(data_keys))]
            y_ = [None for idx in range(len(data_keys))]
            match_ids = {}
            # default weight is 1 when the DTable has no weights
            sample_weights = [1 for idx in range(len(data_keys))]

            for (key, inst) in data:
                idx = data_keys_map[key]
                keys[idx] = key
                x_[idx] = inst.features
                y_[idx] = inst.label
                match_ids[key] = inst.inst_id
                if self.with_sample_weight:
                    sample_weights[idx] = inst.weight

            x_ = np.asarray(x_)
            y_ = np.asarray(y_)
            df = pd.DataFrame(x_)
            df.columns = header
            df['id'] = sorted(data_keys)
            df['label'] = y_
            # host data has no label, so this columns will all be None
            if df['label'].isna().all():
                df = df.drop(columns=['label'])

            self.origin_table = df
            self.sample_weights = np.array(sample_weights)
            self.match_ids = match_ids

        label_col_candidates = ['y', 'label', 'target']

        # automatically set id columns
        id_col_candidates = ['id', 'sid']
        for id_col in id_col_candidates:
            if id_col in self.origin_table:
                self.sample_ids = self.origin_table[id_col].values.tolist()
                self.origin_table = self.origin_table.drop(columns=[id_col])
                break

        # infer column name
        label = self.label_col
        if label is None:
            for i in label_col_candidates:
                if i in self.origin_table:
                    label = i
                    break
            if label is None:
                self.with_label = False
                LOGGER.warning(
                    'label default setting is "auto", but found no "y"/"label"/"target" in input table')
        else:
            if label not in self.origin_table:
                raise ValueError(
                    'label column {} not found in input table'.format(label))

        if self.with_label:
            self.label = self.origin_table[label].values
            self.features = self.origin_table.drop(columns=[label]).values

            if self.l_dtype:
                self.label = self.label.astype(self.l_dtype)

            if self.label_shape:
                self.label = self.label.reshape(self.label_shape)
            else:
                # default: column vector, one row per sample
                self.label = self.label.reshape((len(self.features), -1))

            if self.flatten_label:
                self.label = self.label.flatten()

        else:
            self.label = None
            self.features = self.origin_table.values

        if self.f_dtype:
            self.features = self.features.astype(self.f_dtype)

    def get_classes(self):
        """Return the sorted unique label values; only valid after load()
        found a label column."""
        if self.label is not None:
            return np.unique(self.label).tolist()
        else:
            raise ValueError(
                'no label found, please check if self.label is set')

    def get_sample_ids(self):
        # list taken from the 'id'/'sid' column, or None when absent
        return self.sample_ids

    def get_match_ids(self):
        # dict sample key -> match (inst) id, only set when loading a DTable
        return self.match_ids
# nlp_tokenizer.py
from federatedml.nn.dataset.base import Dataset
import pandas as pd
import torch as t
from transformers import BertTokenizerFast
import os
import numpy as np
# avoid tokenizer parallelism (huggingface tokenizers library env switch)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class TokenizerDataset(Dataset):
    """
    A Dataset for some basic NLP Tasks, this dataset will automatically transform raw text into word indices
    using BertTokenizer from transformers library,
    see https://huggingface.co/docs/transformers/model_doc/bert?highlight=berttokenizer for details of BertTokenizer

    Parameters
    ----------
    truncation bool, truncate word sequence to 'text_max_length'
    text_max_length int, max length of word sequences
    tokenizer_name_or_path str, name of bert tokenizer(see transformers official for details) or path to local
                           transformer tokenizer folder
    return_label bool, return label or not, this option is for host dataset, when running hetero-NN
    """

    def __init__(self, truncation=True, text_max_length=128,
                 tokenizer_name_or_path="bert-base-uncased",
                 return_label=True):
        super(TokenizerDataset, self).__init__()
        # raw csv content (pandas DataFrame), set by load()
        self.text = None
        # padded tensor of token indices, set by load()
        self.word_idx = None
        # label array, set by load() when with_label is True
        self.label = None
        self.tokenizer = None
        self.sample_ids = None
        self.truncation = truncation
        self.max_length = text_max_length
        self.with_label = return_label
        self.tokenizer_name_or_path = tokenizer_name_or_path

    def load(self, file_path):
        """Read a csv with a 'text' column (optionally 'label' and 'id'),
        and tokenize all texts into one padded tensor of word indices."""
        tokenizer = BertTokenizerFast.from_pretrained(
            self.tokenizer_name_or_path)
        self.tokenizer = tokenizer
        self.text = pd.read_csv(file_path)
        text_list = list(self.text.text)
        # pad to the longest sequence in the file, truncate to max_length
        self.word_idx = tokenizer(
            text_list,
            padding=True,
            return_tensors='pt',
            truncation=self.truncation,
            max_length=self.max_length)['input_ids']
        if self.with_label:
            self.label = t.Tensor(self.text.label).detach().numpy()
            self.label = self.label.reshape((len(self.word_idx), -1))
        if 'id' in self.text:
            self.sample_ids = self.text['id'].values.tolist()

    def get_classes(self):
        # consistent with TableDataset: fail loudly when labels are absent
        # (np.unique(None) previously produced a meaningless [None])
        if self.label is not None:
            return np.unique(self.label).tolist()
        else:
            raise ValueError(
                'no label found, please check if self.label is set')

    def get_vocab_size(self):
        return self.tokenizer.vocab_size

    def get_sample_ids(self):
        return self.sample_ids

    def __getitem__(self, item):
        if self.with_label:
            return self.word_idx[item], self.label[item]
        else:
            return self.word_idx[item]

    def __len__(self):
        return len(self.word_idx)

    def __repr__(self):
        return self.tokenizer.__repr__()
# image.py
import torch
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
class ImageDataset(Dataset):
    """
    A basic Image Dataset built on pytorch ImageFolder, supports simple image transform
    Given a folder path, ImageDataset will load images from this folder, images in this
    folder need to be organized in a Torch-ImageFolder format, see
    https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html for details.
    Image name will be automatically taken as the sample id.

    Parameters
    ----------
    center_crop : bool, use center crop transformer
    center_crop_shape: tuple or list
    generate_id_from_file_name: bool, whether to take image name as sample id
    file_suffix: str, default is '.jpg', if generate_id_from_file_name is True, will remove this suffix from file name,
                 result will be the sample id
    return_label: bool, return label or not, this option is for host dataset, when running hetero-NN
    float64: bool, returned image tensors will be transformed to double precision
    label_dtype: str, long, float, or double, the dtype of return label
    """

    def __init__(self, center_crop=False, center_crop_shape=None,
                 generate_id_from_file_name=True, file_suffix='.jpg',
                 return_label=True, float64=False, label_dtype='long'):
        super(ImageDataset, self).__init__()
        self.image_folder: ImageFolder = None
        self.center_crop = center_crop
        self.size = center_crop_shape
        self.return_label = return_label
        self.generate_id_from_file_name = generate_id_from_file_name
        self.file_suffix = file_suffix
        self.float64 = float64
        self.dtype = torch.float32 if not self.float64 else torch.float64
        avail_label_type = ['float', 'long', 'double']
        self.sample_ids = None
        assert label_dtype in avail_label_type, 'available label dtype : {}'.format(
            avail_label_type)
        if label_dtype == 'double':
            self.label_dtype = torch.float64
        elif label_dtype == 'long':
            self.label_dtype = torch.int64
        else:
            self.label_dtype = torch.float32

    def load(self, folder_path):
        """Load images from a Torch-ImageFolder-organized directory and,
        optionally, derive sample ids from the image file names."""
        import os  # local import: only needed here for basename

        # read image from folders
        if self.center_crop:
            transformer = transforms.Compose(
                [transforms.CenterCrop(size=self.size), transforms.ToTensor()])
        else:
            transformer = transforms.Compose([transforms.ToTensor()])

        if folder_path.endswith('/'):
            folder_path = folder_path[: -1]
        image_folder_path = folder_path
        folder = ImageFolder(root=image_folder_path, transform=transformer)
        self.image_folder = folder

        if self.generate_id_from_file_name:
            # use image name as its sample id
            ids = []
            for img_path, _ in self.image_folder.imgs:
                # os.path.basename is portable; the previous '/'-split
                # broke on Windows path separators
                sample_id = os.path.basename(img_path)
                # strip only a trailing suffix; str.replace() would also
                # remove the suffix occurring mid-name (e.g. 'a.jpg.png')
                if self.file_suffix and sample_id.endswith(self.file_suffix):
                    sample_id = sample_id[: -len(self.file_suffix)]
                ids.append(sample_id)
            self.sample_ids = ids

    def __getitem__(self, item):
        if self.return_label:
            item = self.image_folder[item]
            return item[0].type(
                self.dtype), torch.tensor(
                item[1]).type(
                self.label_dtype)
        else:
            return self.image_folder[item][0].type(self.dtype)

    def __len__(self):
        return len(self.image_folder)

    def __repr__(self):
        return self.image_folder.__repr__()

    def get_classes(self):
        return np.unique(self.image_folder.targets).tolist()

    def get_sample_ids(self):
        return self.sample_ids
if __name__ == '__main__':
    # no script entry point; this module is intended to be imported
    pass