词向量源码解析:(6.3)fasttext源码解析之文本分类2

下面我们看看怎么样一步一步进行的文本分类

首先看main.cc函数

int main(int argc, char** argv) {//处理输入的参数,第一个参数是执行文件本身
  if (argc < 2) {
    printUsage();
    exit(EXIT_FAILURE);
  }
  std::string command(argv[1]);//第二个参数高旭fastttext做什么
  if (command == "skipgram" || command == "cbow" || command == "supervised") {
    train(argc, argv);//训练
  } else if (command == "test") {
    test(argc, argv);//测试
  } else if (command == "quantize") {
    quantize(argc, argv);
  } else if (command == "print-word-vectors") {
    printWordVectors(argc, argv);
  } else if (command == "print-sentence-vectors") {
    printSentenceVectors(argc, argv);
  } else if (command == "print-ngrams") {
    printNgrams(argc, argv);
  } else if (command == "nn") {
    nn(argc, argv);
  } else if (command == "analogies") {
    analogies(argc, argv);
  } else if (command == "predict" || command == "predict-prob" ) {//预测
    predict(argc, argv);
  } else {
    printUsage();
    exit(EXIT_FAILURE);
  }
  return 0;
}

这里提供了很多选择,我们就看训练,测试和预测。我们先用字词连载讲解一下训练。实例化fasttext类,调用类中的train方法。这里用args类对输入的参数进行了包装。

void train(int argc, char** argv) {
  std::shared_ptr<Args> a = std::make_shared<Args>();
  a->parseArgs(argc, argv);
  FastText fasttext;
  fasttext.train(a);
}

下面是args类的内容,首先是构造函数,里面是默认参数

Args::Args() {
  lr = 0.05;
  dim = 100;
  ws = 5;
  epoch = 5;
  minCount = 5;
  minCountLabel = 0;
  neg = 5;
  wordNgrams = 1;
  loss = loss_name::ns;
  model = model_name::sg;
  bucket = 2000000;
  minn = 3;
  maxn = 6;
  thread = 12;
  lrUpdateRate = 100;
  t = 1e-4;
  label = "__label__";
  verbose = 2;
  pretrainedVectors = "";
  saveOutput = 0;


  qout = false;
  retrain = false;
  qnorm = false;
  cutoff = 0;
  dsub = 2;
}

我们再看parseArgs。这个韩式后面都是对输入的参数进行解析,前面的几行对有监督训练,也就是文本分类,进行特别的参数设置。loss用softmax,不再过滤低频词了

void Args::parseArgs(int argc, char** argv) {
  std::string command(argv[1]);
  if (command == "supervised") {
    model = model_name::sup;
    loss = loss_name::softmax;
    minCount = 1;
    minn = 0;
    maxn = 0;
    lr = 0.1;
  } else if (command == "cbow") {
    model = model_name::cbow;
  }

然后我们进入fasttext中的train方法,这个函数构建了要训练的参数并且进行了初始化,然后开启多线程进行训练。

void FastText::train(std::shared_ptr<Args> args) {
  args_ = args;
  dict_ = std::make_shared<Dictionary>(args_);//基本所有自然语言处理的第一步都是构建一个词典
  if (args_->input == "-") {
    // manage expectations
    std::cerr << "Cannot use stdin for training!" << std::endl;
    exit(EXIT_FAILURE);
  }
  std::ifstream ifs(args_->input);
  if (!ifs.is_open()) {
    std::cerr << "Input file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  dict_->readFromFile(ifs);//把语料给词典,词典通过语料去构建词典
  ifs.close();


  if (args_->pretrainedVectors.size() != 0) {//和word2vec一样,参数由两部分组成,一部分是词向量,一部分是上下文向量/label向量,也就input向量和output向量
    loadVectors(args_->pretrainedVectors);
  } else {
    input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);//词向量的个数是单词的数量加上桶的个数。多个ngram或是字符ngram会在一个桶中共享一个向量
    input_->uniform(1.0 / args_->dim);//初始化词向量
  }


  if (args_->model == model_name::sup) {//对于有监督任务就是label向量
    output_ = std::make_shared<Matrix>(dict_->nlabels(), args_->dim);
  } else {//对于训练词向量来说就是上下文向量
    output_ = std::make_shared<Matrix>(dict_->nwords(), args_->dim);
  }
  output_->zero();//全部用0初始化


  start = clock();
  tokenCount = 0;
  if (args_->thread > 1) {
    std::vector<std::thread> threads;
    for (int32_t i = 0; i < args_->thread; i++) {//多线程
      threads.push_back(std::thread([=]() { trainThread(i); }));
    }
    for (auto it = threads.begin(); it != threads.end(); ++it) {
      it->join();
    }
  } else {
    trainThread(0);
  }
  model_ = std::make_shared<Model>(input_, output_, args_, 0);//最后训练的结果就是词向量和上下文向量,也叫输入向量和输出向量


  saveModel();
  if (args_->model != model_name::sup) {
    saveVectors();
    if (args_->saveOutput > 0) {
      saveOutput();
    }
  }
}

先看一下trainthread中是怎么训练输入输出向量的

void FastText::trainThread(int32_t threadId) {
  std::ifstream ifs(args_->input);
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);//根据线程id定位到不同的位置,每个线程一直执行下去,到了文档的最后再重新从第一行去读


  Model model(input_, output_, args_, threadId);//模型主要就包括的是输入输出向量
  if (args_->model == model_name::sup) {
    model.setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model.setTargetCounts(dict_->getCounts(entry_type::word));
  }


  const int64_t ntokens = dict_->ntokens();
  int64_t localTokenCount = 0;
  std::vector<int32_t> line, labels;
  while (tokenCount < args_->epoch * ntokens) {//一共要处理 epoch乘以ntokens个tokens,tokenCount记录一共处理了多少个tokens
    real progress = real(tokenCount) / (args_->epoch * ntokens);//更新alpha的逻辑和word2vec一样
    real lr = args_->lr * (1.0 - progress);
    localTokenCount += dict_->getLine(ifs, line, labels, model.rng);//词典类中的getline得到一行文本line,以及对应的label,line是单词的id数组(向量)
    if (args_->model == model_name::sup) {
      supervised(model, lr, line, labels);//输入模型,learning rate,line以及label

    } else if (args_->model == model_name::cbow) {
      cbow(model, lr, line);
    } else if (args_->model == model_name::sg) {
      skipgram(model, lr, line);
    }
    if (localTokenCount > args_->lrUpdateRate) {//大于一定的阈值打印,这里是100
      tokenCount += localTokenCount;
      localTokenCount = 0;
      if (threadId == 0 && args_->verbose > 1) {
        printInfo(progress, model.getLoss());
      }
    }
  }
  if (threadId == 0 && args_->verbose > 0) {//线程结束以后再打印一下最后的loss
    printInfo(1.0, model.getLoss());
    std::cerr << std::endl;
  }
  ifs.close();
}

下面看supervised函数

void FastText::supervised(Model& model, real lr,
                          const std::vector<int32_t>& line,
                          const std::vector<int32_t>& labels) {
  if (labels.size() == 0 || line.size() == 0) return;
  std::uniform_int_distribution<> uniform(0, labels.size() - 1);//这里假设有多个label,随机去一个,一般就一个label
  int32_t i = uniform(model.rng);
  model.update(line, labels[i], lr);//真正的更新调用model类的update,update会对line中的所有的单词向量取平均,去预测label。更新词向量也是调用这个函数,因为基本的操作都是一样的。
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值