Rust：使用 fasttext crate 训练模型并进行预测（在测试集上评估准确率）
- Cargo.toml dependencies
[dependencies]
fasttext = "*"
- src/main.rs
use fasttext::{FastText, Args, ModelName, LossName};
const TRAIN_FILE: &str = "data.train";
const TEST_FILE: &str = "data.test";
const MODEL: &str = "model.bin";
fn main() -> Result<(), Box<Error>> {
// train
let mut args = Args::new();
args.set_input(TRAIN_FILE);
args.set_model(ModelName::SUP);
args.set_loss(LossName::SOFTMAX);
let mut ft_model = FastText::new();
ft_model.train(&args).unwrap();
// eval
let preds = test_data.iter().map(
|x| ft_model.predict(x.text.as_str(), 1, 0.0)
);
let test_labels = test_data.iter().map(|x| x.into_labels());
let mut hits = 0;
let mut correct_hits = 0;
let preds_clone = preds.clone();
for (predicted, actual) in preds.zip(test_labels) {
let predicted = predicted?;
// take the first prediction
let predicted = &predicted[0];
if predicted.clone().label == actual {
correct_hits += 1;
}
hits += 1;
}
assert_eq!(hits, preds_clone.len());
println!("accuracy={} ({}/{} correct)", correct_hits as f32 / hits as f32, correct_hits, preds_clone.len());
ft_model.save_model(MODEL)?;
Ok(())
}