pytorch
Average article quality score: 77
sinat_24395003
First learn to use the wheel, then learn how the wheel is built, then build your own wheel
NoRepeatNGramLogitsProcessor's _calc_banned_ngram_tokens
#transformer.generation_logits_process — the purpose of NoRepeatNGramLogitsProcessor's _calc_banned_ngram_tokens is to produce n-grams that do not repeat.

import torch
from typing import List, Iterable

def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngra..

Original post · 2021-04-09 15:43:11 · 412 views · 1 comment
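The preview above is cut off, but the underlying idea is small enough to sketch. Below is a minimal pure-Python version of the no-repeat-n-gram logic (hypothetical standalone helpers, not the actual transformers implementation, which operates on torch tensors per beam hypothesis): record which token has followed each (n-1)-token prefix, then ban any token that would complete an already-seen n-gram.

```python
# Hypothetical pure-Python sketch of the no-repeat-ngram idea behind
# NoRepeatNGramLogitsProcessor; the real code in transformers works on
# torch tensors, one hypothesis at a time.

def get_ngrams(ngram_size, token_ids):
    """Map each (n-1)-token prefix to the tokens that have followed it."""
    generated_ngrams = {}
    for i in range(len(token_ids) - ngram_size + 1):
        ngram = tuple(token_ids[i:i + ngram_size])
        prefix, last = ngram[:-1], ngram[-1]
        generated_ngrams.setdefault(prefix, []).append(last)
    return generated_ngrams

def calc_banned_tokens(ngram_size, token_ids):
    """Tokens that would complete an n-gram already present in token_ids."""
    if len(token_ids) + 1 < ngram_size:
        return []  # sequence too short for any n-gram to repeat yet
    ngrams = get_ngrams(ngram_size, token_ids)
    # the (n-1) most recent tokens form the prefix of the next n-gram
    prefix = tuple(token_ids[len(token_ids) - ngram_size + 1:])
    return ngrams.get(prefix, [])

# With ngram_size=2, after [5, 3, 5] the prefix is (5,) and 3 already
# followed 5 once, so 3 is banned to avoid repeating the bigram (5, 3).
print(calc_banned_tokens(2, [5, 3, 5]))  # [3]
```

In generation, the logits of the banned token ids are set to -inf before sampling, which is how the processor actually enforces the ban.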
pad_sequence, pack_padded_sequence, pad_packed_sequence
See https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch

import torch
from torch import nn

seq_batch = [torch.tensor([[1, 1], [2, 2], [3, 3], [4.

Original post · 2020-12-10 17:35:46 · 122 views · 0 comments
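The linked answer explains why packing exists: an RNN should not waste compute on padding positions. As a rough intuition for what pack_padded_sequence produces, here is a pure-Python sketch (an assumed simplification — the real torch function takes tensors and returns a PackedSequence): flatten the batch time-step by time-step, keeping only the sequences still active at each step, and record how many were active.

```python
# Hypothetical pure-Python sketch of what pack_padded_sequence computes.
# padded: batch of equal-length sequences (already padded), lengths: true
# lengths, sorted descending (as torch requires by default).

def pack_padded(padded, lengths):
    """Return (data, batch_sizes): values in time-major order with padding
    dropped, and the number of still-active sequences per time step."""
    max_len = max(lengths)
    data, batch_sizes = [], []
    for t in range(max_len):
        active = sum(1 for length in lengths if length > t)
        batch_sizes.append(active)
        for b in range(active):  # only sequences longer than t contribute
            data.append(padded[b][t])
    return data, batch_sizes

seqs = [[1, 2, 3, 4], [5, 6, 0, 0], [7, 0, 0, 0]]  # 0 is the pad value
lengths = [4, 2, 1]
data, batch_sizes = pack_padded(seqs, lengths)
print(data)         # [1, 5, 7, 2, 6, 3, 4]
print(batch_sizes)  # [3, 2, 1, 1]
```

batch_sizes is exactly what lets the RNN shrink its effective batch at each step; pad_packed_sequence reverses the transformation back into a padded tensor plus lengths.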
不均衡样本的sampler构建 Imbalanced Dataset Sampler
from fastNLP.io import SST2Pipe
from fastNLP import DataSetIter
from torchsampler import ImbalancedDatasetSampler

pipe = SST2Pipe()
databundle = pipe.process_from_file()
vocab = databundle.vocabs['words']
print(databundle)
print(databundle.datasets['train.

Original post · 2020-08-12 10:06:51 · 1392 views · 4 comments
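The core trick behind an imbalanced-dataset sampler is simple enough to show without the library: weight each example by the inverse frequency of its class, then draw indices with replacement using those weights, so every class contributes roughly equally per epoch. A minimal sketch (hypothetical helper, not torchsampler's API — the real ImbalancedDatasetSampler subclasses torch.utils.data.Sampler):

```python
# Hypothetical sketch of the class-rebalancing weights used by samplers
# like ImbalancedDatasetSampler / torch's WeightedRandomSampler.
from collections import Counter
import random

def class_balanced_weights(labels):
    """Weight each index by 1 / count(label of that index), so each class
    gets roughly equal total sampling probability."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

labels = [0, 0, 0, 0, 1]            # heavily skewed toward class 0
weights = class_balanced_weights(labels)
print(weights)                      # [0.25, 0.25, 0.25, 0.25, 1.0]

# Drawing a resampled "epoch" with replacement, as a weighted sampler would:
epoch_indices = random.choices(range(len(labels)), weights=weights, k=10)
```

In torch itself, the same weights can be passed to torch.utils.data.WeightedRandomSampler and handed to the DataLoader's sampler argument.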