https://blog.csdn.net/qq_19446965/article/details/120110169
Method 1: Mark unseen labels as Unknown
If a LabelEncoder is fitted on the training set, calling transform on test data that contains a value not seen during fit raises a ValueError.
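A minimal reproduction of the failure, assuming scikit-learn is installed; the labels are made up for illustration:

```python
from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
encoder.fit(['Canada', 'France', 'Italy'])  # illustrative training labels

try:
    encoder.transform(['Canada', 'India'])  # 'India' was never seen during fit
    raised = False
except ValueError as err:
    # scikit-learn reports the unseen label in the error message
    raised = True
    print('transform failed:', err)
```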
We can subclass LabelEncoder and override fit and transform so that any new label is assigned to an Unknown class.
from sklearn.preprocessing import LabelEncoder as LEncoder


class LabelEncoder(LEncoder):

    def fit(self, y):
        """
        Fit the encoder on all unique values in y plus an extra
        'Unknown' class.
        :param y: A list of strings
        :return: self
        """
        return super().fit(list(y) + ['Unknown'])

    def transform(self, y):
        """
        Transform y to an array of ids; values not seen during fit
        are assigned to the 'Unknown' class.
        :param y: A list of strings
        :return: array-like of shape [n_samples]
        """
        known = set(self.classes_)  # build the lookup set once, not per element
        new_y = [x if x in known else 'Unknown' for x in y]
        return super().transform(new_y)
Sample usage:
# 'Argentina, US' is a single label here, which is why it appears as its own class in the output
country_list = ['Argentina', 'Australia', 'Canada', 'France', 'Italy', 'Spain', 'US', 'Canada', 'Argentina, US']
label_encoder = LabelEncoder()
label_encoder.fit(country_list)
print('country_list: ', label_encoder.classes_)  # the extra Unknown class is listed here
print('encode_country_list: ', label_encoder.transform(country_list))
new_country_list = ['Canada', 'France', 'Italy', 'Spain', 'US', 'India', 'Pakistan', 'South Africa']
print('new_encode_country_list: ', label_encoder.transform(new_country_list))
Output:
country_list: ['Argentina' 'Argentina, US' 'Australia' 'Canada' 'France' 'Italy' 'Spain' 'US' 'Unknown']
encode_country_list: [0 2 3 4 5 6 7 3 1]
new_encode_country_list: [3 4 5 6 7 8 8 8]
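The three unseen countries (India, Pakistan, South Africa) are all encoded as 8, the index of the Unknown class.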
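One consequence worth noting is that this mapping is lossy. The sketch below repeats the subclass under a hypothetical name, `UnknownLabelEncoder`, with made-up labels, and uses the inherited `inverse_transform` to show that every unseen label decodes back to 'Unknown' rather than to its original value:

```python
from sklearn.preprocessing import LabelEncoder


class UnknownLabelEncoder(LabelEncoder):  # hypothetical name for this sketch
    def fit(self, y):
        # Reserve an extra 'Unknown' class alongside the training labels.
        return super().fit(list(y) + ['Unknown'])

    def transform(self, y):
        known = set(self.classes_)
        return super().transform([x if x in known else 'Unknown' for x in y])


encoder = UnknownLabelEncoder().fit(['Canada', 'France', 'Italy'])
codes = encoder.transform(['France', 'India', 'Pakistan'])
decoded = encoder.inverse_transform(codes)
print(codes)
print(decoded)
```

Here `codes` is `[1 3 3]`, and both unseen countries decode to 'Unknown', so the original values cannot be recovered from the ids.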
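Method 2: Update the encoding
Another approach is to extend the encoding as new labels arrive:
- maintain a list, features, of the unique labels seen so far;
- when new labels appear, append them to features and refit the encoder.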
from sklearn.preprocessing import LabelEncoder as LEncoder


class LabelEncoder(LEncoder):

    def __init__(self):
        """
        Differs from sklearn's LabelEncoder by assigning new ids to
        classes first seen at transform time.
        """
        self.features = tuple()

    def fit(self, y):
        """
        Fit the encoder on all unique values in y.
        :param y: A sequence of strings
        :return: self
        """
        # set() iteration order is arbitrary, so ids may differ between runs
        self.set_features(tuple(set(y)))
        return super().fit(self.encode_seqs(y))

    def add_features(self, new_features):
        """
        Append features.
        :param new_features: A tuple of features (strings)
        """
        self.features = self.features + tuple(new_features)

    def set_features(self, new_features):
        """
        Replace the stored features.
        :param new_features: A tuple of features (strings)
        """
        self.features = tuple(new_features)

    def transform(self, y):
        """
        Transform y to an array of ids; values not seen before are
        appended to features and get new ids.
        :param y: A list of strings
        :return: array-like of shape [n_samples]
        """
        known = set(self.features)
        # dict.fromkeys drops repeated new values while keeping their order
        increase_features = tuple(dict.fromkeys(x for x in y if x not in known))
        if increase_features:
            self.add_features(increase_features)
            super().fit(self.encode_seqs(self.features))
            print(f"new classes_: {self.classes_}")
        return super().transform(self.encode_seqs(y))

    def encode_seqs(self, seqs):
        """
        Encode a sequence of labels as a list of feature indices.
        :param seqs: A list of strings
        :return: list of length n_seqs
        """
        return [self.features.index(x) for x in seqs]
Output, using the same sample code as in Method 1:
country_list: [0 1 2 3 4 5 6 7]
encode_country_list: [1 0 7 6 5 4 3 7 2]
new classes_: [ 0 1 2 3 4 5 6 7 8 9 10]
new_encode_country_list: [ 7 6 5 4 3 8 9 10]
The three new countries are encoded as 8, 9, and 10.
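The append-only idea behind this method can also be sketched without scikit-learn. The class below, `GrowingVocabulary`, is a hypothetical illustration rather than the author's code. It assumes a dict from label to id is acceptable, which makes lookups constant-time instead of a linear `tuple.index` scan, and it assigns ids in first-seen order so they are reproducible across runs:

```python
class GrowingVocabulary:
    """Hypothetical append-only label-to-id mapping (illustration only)."""

    def __init__(self):
        self.label_to_id = {}

    def fit(self, labels):
        # Start from an empty mapping, then register the training labels.
        self.label_to_id = {}
        return self.update(labels)

    def update(self, labels):
        for label in labels:
            # setdefault only assigns an id the first time a label appears
            self.label_to_id.setdefault(label, len(self.label_to_id))
        return self

    def transform(self, labels):
        # Register unseen labels first, so existing ids never change.
        self.update(labels)
        return [self.label_to_id[label] for label in labels]


vocab = GrowingVocabulary().fit(['Canada', 'France', 'Italy', 'Canada'])
first = vocab.transform(['Italy', 'Canada'])
second = vocab.transform(['France', 'India', 'Pakistan', 'India'])
print(first)   # [2, 0]
print(second)  # [1, 3, 4, 3]
```

Labels seen earlier keep their ids when new ones are appended, which is the property this method relies on.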
————————————————
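Copyright notice: This is an original article by CSDN blogger "Rnan-prince", licensed under CC 4.0 BY-SA. When reposting, please include the original source link and this notice.
Original link: https://blog.csdn.net/qq_19446965/article/details/120110169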