ML.NET有很多内置的转换器,但是我们不可能涵盖所有内容。不可避免地,您将需要执行自定义的用户定义操作。为此,我们添加了MLContext.Transforms.CustomMapping
就是为了这个目的:这是用户定义的数据的任意映射。
假设我们有一个带有float数据的'Income'列的数据集,我们要计算'Label',如果收入超过50000,则等于true
,否则等于false
。
这是我们如何通过自定义转换器执行此操作的方法:
// 为我们打算使用的所有输入列定义一个类。
class InputRow
{
public float Income { get; set; }
}
// 为我们打算产生的所有输出列定义一个类。
class OutputRow
{
public bool Label { get; set; }
}
public static IDataView PrepareData(MLContext mlContext, IDataView data)
{
// 定义操作代码。
Action<InputRow, OutputRow> mapping = (input, output) => output.Label = input.Income > 50000;
// 创建一个定制的估计器并转换数据。
var estimator = mlContext.Transforms.CustomMapping(mapping, null);
return estimator.Fit(data).Transform(data);
}
您还可以在估计器管道中插入自定义映射:
public static ITransformer TrainModel(MLContext mlContext, IDataView trainData)
{
// 使用自定义操作。
Action<InputRow, OutputRow> mapping = (input, output) => output.Label = input.Income > 50000;
// 构建学习管道。
var estimator = mlContext.Transforms.CustomMapping(mapping, null)
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label"));
return estimator.Fit(trainData);
}
请注意,您需要将mapping
操作变成“纯函数”
它应该是可重入的(我们将从多个线程同时调用它)
它不应该有副作用(我们可以在任何时候任意调用,或忽略调用)
一个重要的警告是:如果希望自定义转换成为已保存模型的一部分,则需要为其提供contractName
。在加载时,您需要向MLContext注册自定义转换器。
下面是一个完整的示例,用于保存和加载带有自定义映射的模型。
/// <summary>
/// 一个类包含我们的模型所需的自定义映射功能。
///
/// It has a <see cref="CustomMappingFactoryAttributeAttribute"/> on it and
/// derives from <see cref="CustomMappingFactory{TSrc, TDst}"/>.
/// </summary>
[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
public class CustomMappings : CustomMappingFactory<InputRow, OutputRow>
{
// 这是自定义映射。我们现在将它分离为一个方法,以便在训练和加载中都可以使用它。
public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;
// 当加载模型以获取映射操作时,将调用此工厂方法。
public override Action<InputRow, OutputRow> GetMapping()
{
return IncomeMapping;
}
}
// 构建学习管道。请注意,我们现在为自定义映射提供了一个约定名称:否则我们将无法保存模型。
var estimator = mlContext.Transforms.CustomMapping<InputRow, OutputRow>(CustomMappings.IncomeMapping, nameof(CustomMappings.IncomeMapping))
.Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label"));
// 如果内存足够,我们可以将数据缓存在内存中,以避免在多次访问文件时从文件中加载数据。
var cachedTrainData = mlContext.Data.Cache(trainData);
// 训练模型
var model = estimator.Fit(cachedTrainData);
// 保存模型。
using (var fs = File.Create(modelPath))
mlContext.Model.Save(model, fs);
// 现在假设我们在一个不同的过程中。
// 向ComponentCatalog注册包含“CustomMappings”的程序集,以便在加载模型时可以找到它。
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
// 现在我们可以加载模型了。
ITransformer loadedModel = newContext.Model.Load(modelPath, out var schema);
欢迎关注我的个人公众号”My IO“