#!/usr/bin/env python
import fileinput
import json
import sys
from collections import Counter, OrderedDict

import numpy
import torch
# Count the words recorded in a .pt checkpoint file
# # Count the source-side vocabulary
# # src_vocab_a = checkpoint_a['src'].base_field.vocab
# tgt_vocab_a = checkpoint_a['tgt'].base_field.vocab
#
# # 总单词数为28313815
# freqs_counter_all = tgt_vocab_a.freqs
# freqs_all = sum(freqs_counter_all.values())
# print('总单词数为{}'.format(freqs_all))
#
# # 词汇表为20k内的单词
# freqs_counter_z = freqs_counter_all.most_common(20000)
# freqs_z = sum(map(lambda x: x[1], freqs_counter_z))
# print('词汇表为20k内的单词频数之和:{0},占总单词数:{1}%'.format(freqs_z,(freqs_z/freqs_all)*100))
#
# # 词汇表为30k内的单词
# freqs_counter_x = freqs_counter_all.most_common(30000)
# freqs_x = sum(map(lambda x: x[1], freqs_counter_x))
# print('词汇表为30k内的单词频数之和:{0},占总单词数:{1}%'.format(freqs_x,(freqs_x/freqs_all)*100))
#
# # 词汇表为50k内的单词
# freqs_counter_a = freqs_counter_all.most_common(50000)
# freqs_a = sum(map(lambda x: x[1], freqs_counter_a))
# print('词汇表为50k内的单词频数之和:{0},占总单词数:{1}%'.format(freqs_a,(freqs_a/freqs_all)*100))
#
# # 词汇表为80k内的单词
# freqs_counter_c = freqs_counter_all.most_common(80000)
# freqs_c = sum(map(lambda x: x[1], freqs_counter_c))
# print('词汇表为80k内的单词频数之和:{0},占总单词数:{1}%'.format(freqs_c,(freqs_c/freqs_all)*100))
#
# # 词汇表为100k内的单词
# freqs_counter_y = freqs_counter_all.most_common(100000)
# freqs_y = sum(map(lambda x: x[1], freqs_counter_y))
# print('词汇表为100k内的单词频数之和:{0},占总单词数:{1}%'.format(freqs_y,(freqs_y/freqs_all)*100))
# Count the number of sentences
# a = '/home/think/OpenNMT-py-master-sdc/data/ldc-zh-en-data/data/corpus.tc.zh'
# b = '/home/think/OpenNMT-py-master-sdc/data/ldc-zh-en-data/data/corpus.tc.en'
#
# with open(a, 'r', encoding='utf-8') as f:
# print('总句子数为:{}'.format(len(f.readlines())))
#
# # with open(b, 'r', encoding='utf-8') as f:
# # print(len(f.readlines()))
#
# sents_num = 0
# length = 60
# with open(a, 'r', encoding='utf-8') as f:
# for line in f:
# words_in = line.strip().split(' ')
# if len(words_in) <= length:
# sents_num += 1
# print('长度小于等于{0}的句子有{1}个'.format(length, sents_num))
# Count the words in a dataset (one or more text files passed on the command line)
def main():
for filename in sys.argv[1:]:
print('Processing', filename)
word_freqs = OrderedDict()
with open(filename, 'r', encoding='utf-8') as f:
for line in f:
words_in = line.strip().split(' ')
for w in words_in:
if w not in word_freqs:
word_freqs[w] = 0
word_freqs[w] += 1
words = word_freqs.keys()
freqs = word_freqs.values()
# 总单词数(未去重)
print(numpy.sum(list(freqs)))
# 总单词数(去重)
print(len(words))
if __name__ == '__main__':
main()