abp竞赛-之-文本文件词频查询 优化报告
HouSisong@GMail. 2007.03.15com
tag: abp,单词统计,比赛,hash,速度优化,优化报告
摘要: 以前参加过几次abp论坛的比赛 http://www.allaboutprogram.com/bb (现在的www.cpper.com/c)
其中的一个竞赛的题目是《文本文件词频查询》,本文章把自己的参赛代码的优化的思路整理出来;
很多时候优化后的版本最高达到了STL实现版本的20倍!
(2007.04.09 确认从MFC移植过来的时候引入了一个bug,“FILE* file=fopen(argv[1], "r" ); ”应该为“FILE* file=fopen(argv[1], "rb" ); ” 找了我好久:( 谢谢 hf1414 !)
(2007.03.17修正一个在vc2005编译器下访问vector的bug,将代码 “TNode** end=&(_vbase[_hash_power]); ” 改为 “base_t::iterator end=_vbase.end();” )
(abp现在还能够访问(只读),会员很多都“搬迁”到了 www.cpper.com/c :)
文本文件词频查询 竞赛要求:
Compiler:VC6/VC.net/VC.net 2003
评判标准:正确性+ 速度
截止时间:2003年10月11日前(含)
方法:每个人可以多次提交。每次提交完了,我会告诉你你的成绩和最快的人的成绩。
内容:
一个文件,仅由大小写字母,空格和换行符组成。我们称一个词为连续的大小写字符,两边是空格或者文件头/ 尾。词大小写敏感。
某个词的词频是这个词在这个文件里面出现的次数。
要求,输入一个文件(至少有一个词,并且最大词频的词只有一个),输出那个词频最大的词。
譬如,输入:
aaa bbb
ccc ddd
aaa
输出:
aaa
补充一句:文件可能非常大。(xxxM,xG)
还有就是,文件中不会出现TAB。
(测试程序的时候,我们将vc目录中的源代码文件合成了一个数据文件来作为测试数据)
一个“标准”C++实现版本: (可以作为一个STL使用的实例:)
#include <iostream>
#include <fstream>
#include <string>
#include <map>
#include <time.h>
using namespace std;
int main(int argc, char* argv[])
{
//assert(argc==2);
clock_t start= clock();
const char* file_name=argv[1 ];
ifstream in_file(file_name);
map<string,int> word_table;
string max_word;
long max_count=0 ;
string word;
while(in_file >> word)
{
long old_count_inc=(++word_table[word]);
if(old_count_inc>max_count)
{
max_word=word;
max_count=old_count_inc;
}
}
cout<<" Word: "<<max_word<<" Count: "<<max_count<< endl;
cout<<" Seconds = "<<( (double)(clock()-start)/CLOCKS_PER_SEC )<< endl;
return 0 ;
}
我用的测试编译器vc6.0 , CPU赛扬2.0G
下面的代码很多时候速度是上面的版本的20倍,源代码如下(优化说明在代码之后);
(我以前提交的代码使用了MFC库,为了容易编译和理解,我做了一些代码调整,去除MFC依赖,把一个复杂的代码循环展开删除了,可能慢了10%)
#include < stdio.h >
#include < time.h >
#include < iostream >
#include < string >
#include < vector >
#include < algorithm >
namespace {
class CMyAllot
{
enum { chunk_size = 1024 * 256 }; // 块大小
char * _cur;
char * _end;
std::vector < char *> _vector;
void * _new_else(unsigned int size);
public :
CMyAllot() :_end( 0 ),_cur( 0 ) { }
virtual ~ CMyAllot() { if ( ! _vector.empty()) DelAll(); }
inline void * _fastcall New(unsigned int size)
{
size = ((size + 3 ) >> 2 << 2 ); // 4字节边界对齐
if (( int )size < (_end - _cur)) // 够用
{
char * result = _cur;
_cur += size;
return result;
}
else // 不够用
return _new_else(size);
}
void DelAll()
{
for ( int i = 0 ;i < ( int )_vector.size(); ++ i)
delete [] (_vector[i]);
_vector.clear();
}
};
void * CMyAllot::_new_else(unsigned int size)
{
if (size > (chunk_size >> 2 )) // 不够用,而且需要的空间较大
{
char * result = new char [size];
char * old_back = _vector.back();
_vector[_vector.size() - 1 ] = result;
_vector.push_back(old_back);
return result;
}
else // 不够用,开辟新的空间
{
char * result = new char [chunk_size];
_cur = result + size;
_end = result + chunk_size;
_vector.push_back(result);
return result;
}
}
struct TNode // hash表使用的节点类型(链表)
{
TNode * pNext;
unsigned int count;
char str[ 1 ]; // 不一定只有一个字节,会根据字符串分配空间
struct TComp // 返回时的排序准则
{
bool operator ()( const TNode * l, const TNode * r)
{
if ((l -> count) == (r -> count))
{
return std:: string ( & l -> str[ 0 ]) < ( & r -> str[ 0 ]);
}
else
return (l -> count) > (r -> count);
}
};
};
inline unsigned int _fastcall hash_value( char * begin, char * end)
{
unsigned int result = 0 ;
do {
result = 5 * result + ( * begin); // 利用asm: lea reg0,[reg1*4+reg1],并且5是质数
} while (( ++ begin) != end);
return result;
}
inline unsigned int _fastcall hash_value( char * pstr)
{
unsigned int result = 0 ;
do { result = 5 * result + ( * pstr); ; // 利用asm: lea reg0,[reg1*4+reg1],并且5是质数
} while (( * ( ++ pstr)));
return result;
}
// 测试字符串是否相同, 如果需要不区分大小写,修改这个函数和hash函数就可以了
inline bool _fastcall test_str_EQ( char * begin, char * end, char * str)
{
// for (;begin!=end;++begin,++str)
// if ( (*begin)!=*(str) ) return false;
do {
if ( ( * begin) !=* (str) ) return false ;
++ begin; ++ str;
} while (begin != end);
return true ;
}
}
class CHashSet
{
typedef std::vector < TNode *> base_t;
inline unsigned int hash_index( char * begin, char * end) const
{ return hash_value(begin,end) & (_hash_mask); }
inline unsigned int hash_index( char * pstr) const
{ return hash_value(pstr) & (_hash_mask); }
void resize();
void _fastcall move_insert(base_t & v,TNode * pOldNode) const ;
TNode * _fastcall NewNode( char * begin, char * end);
void Sort(base_t & v,unsigned int sortCount);
unsigned int _hash_power;
unsigned int _hash_mask;
unsigned int _node_count;
base_t _vbase;
CMyAllot _allot;
void _fastcall else_insert(TNode * pNode, char * begin, char * end);
public :
CHashSet();
virtual ~ CHashSet();
unsigned int size() const { return _node_count; }
unsigned int sum();
void _fastcall insert( char * begin, char * end);
void GetStrList(std::ostream & cout,unsigned int sortCount);
};
CHashSet::CHashSet()
:_hash_power( 2 ),_vbase((unsigned int )(_hash_power),(TNode * ) 0 ) // 注意次序
{
_node_count = 0 ;
_hash_mask = _hash_power - 1 ; // _hash_power=1<<n;
}
CHashSet:: ~ CHashSet()
{
_allot.DelAll();
}
unsigned int CHashSet::sum()
{
unsigned int sum = 0 ;
if (_node_count > 0 )
{
base_t::iterator end = _vbase.end();
for (base_t::iterator i = _vbase.begin();i < end; ++ i)
{
TNode * pNode = ( * i);
while (pNode != 0 )
{
sum += pNode -> count;
pNode = pNode -> pNext;
}
}
}
return sum;
}
void _fastcall CHashSet::insert( char * begin, char * end)
{
unsigned int index = hash_index(begin,end);
TNode * pNode = _vbase[index];
if ( ! pNode) // 节点还没有使用
{
_vbase[index] = NewNode(begin,end);
++ _node_count;
}
else
{
if (test_str_EQ(begin,end,pNode -> str)) // 累加
++ (pNode -> count);
else
else_insert(pNode,begin,end);
}
}
void _fastcall CHashSet::else_insert(TNode * pNode, char * begin, char * end)
{
while ( true )
{
if ( ! (pNode -> pNext))
{
pNode -> pNext = NewNode(begin,end);
++ _node_count;
if (_node_count >= (_hash_power))
resize();
break ;
}
else if (test_str_EQ(begin,end,pNode -> pNext -> str))
{
++ (pNode -> pNext -> count);
break ;
}
pNode = pNode -> pNext;
};
}
void _fastcall CHashSet::move_insert(base_t & v,TNode * pOldNode) const
{
TNode *& pNode = v[hash_index(pOldNode -> str)];
pOldNode -> pNext = 0 ;
if ( ! pNode) // 节点还没有使用
{
pNode = pOldNode;
}
else
{
if ( ! pNode -> pNext)
{
pNode -> pNext = pOldNode;
}
else
{
TNode * pListNode = pNode -> pNext;
while (pListNode -> pNext != 0 )
{ pListNode = pListNode -> pNext; }
pListNode -> pNext = pOldNode;
}
}
}
TNode * _fastcall CHashSet::NewNode( char * begin, char * end)
{
TNode * pNode = (TNode * )(_allot.New( sizeof (TNode) + end - begin));
pNode -> pNext = 0 ;
pNode -> count = 1 ;
char * i = pNode -> str;
// for (;begin!=end;++i,++begin)
// (*i)=(*begin);
do {
( * i) = ( * begin); ++ i, ++ begin;
} while (begin != end);
( * i) = char ( 0 );
return pNode;
}
void CHashSet::resize()
{
if (_node_count >= (_hash_power))
{
base_t::iterator end = _vbase.end();
_hash_power <<= 2 ;
_hash_mask = (_hash_power) - 1 ;
base_t new_vbase(_hash_power,(TNode * ) 0 );
for (base_t::iterator i = _vbase.begin();i != end; ++ i)
{
TNode * pNode = ( * i);
while (pNode != 0 )
{
TNode * temp = pNode -> pNext;
move_insert(new_vbase,pNode);
pNode = temp;
}
}
_vbase.swap(new_vbase);
}
}
/// /
void CHashSet::Sort(base_t & v,unsigned int sortCount)
{
if (sortCount == 1 )
{
v.resize( 1 );
base_t::iterator end = _vbase.end();
TNode * maxNode = _vbase[ 0 ];
TNode::TComp op;
for (base_t::iterator i = _vbase.begin();i != end; ++ i)
{
TNode * pNode = ( * i);
while (pNode != 0 )
{
if ( (maxNode == 0 ) || (op(pNode,maxNode)) )
maxNode = pNode;
pNode = pNode -> pNext;
}
}
v[ 0 ] = maxNode;
}
else
{
v.resize(_node_count);
int index = 0 ;
if (_node_count > 0 )
{
TNode ** end =& (_vbase[_hash_power]);
for (TNode ** i =& (_vbase[ 0 ]);i != end; ++ i)
{
TNode * pNode = ( * i);
while (pNode != 0 )
{
v[index] = pNode;
++ index;
pNode = pNode -> pNext;
}
}
}
std::partial_sort(v.begin(),v.begin() + sortCount,v.end(),TNode::TComp());
}
}
void CHashSet::GetStrList(std::ostream & cout,unsigned int sortCount)
{
if (_node_count >= 1 )
{
if (sortCount == 0 )
sortCount = _node_count;
else if (_node_count < sortCount)
sortCount = _node_count;
base_t v;
Sort(v,sortCount);
for ( int i = 0 ;i < ( int )sortCount; ++ i)
{
std::cout << " 单词: " << ( & (v[i] -> str[ 0 ])) << " 计数: " << (v[i] -> count) << std::endl;
}
}
}
class CWords
{
private :
enum { cibuf_size = 4096 }; // 缓冲区最佳大小
int buf_size; // 动态缓冲区大小
char * pBuf; // 指向缓冲区
static void CreateGainTab(); // 构造“词”分析用的表
int privateGainWord( int dx, int start_offset, bool isEndGain);
inline int GainWord( int dx, int start_offset); // 从缓冲区获取词;
inline void endGainWord( int dx, int start_offset); // 从缓冲区获取词,处理文件尾;
void _fastcall PushWord( char * begin, char * end);
__int64 _CPUCount;
CHashSet _hash_set;
public :
CWords();
virtual ~ CWords();
void toDo(FILE * file); // 循环读取文件数据到内存缓冲区
void GetResult(std::ostream & cout,unsigned int sortCount);
};
namespace {
static unsigned int GainTab[ 256 ]; // 进行词法分析的表
}
// 构造“词”分析用的表
void CWords::CreateGainTab()
{
//
static bool IsDo = false ;
if (IsDo) return ;
for ( int i = 0 ;i < 256 ; ++ i)
{
if ( ((i >= ' A ' ) && (i <= ' Z ' ))
|| ((i >= ' a ' ) && (i <= ' z ' ))
// || (i=='_')
// || ((i>='0')&&(i<='9'))
)
GainTab[i] = unsigned int ( - 1 );
else
GainTab[i] = 0 ;
}
IsDo = true ;
}
CWords::CWords()
{
}
CWords:: ~ CWords()
{
}
#define asm __asm
__declspec( naked ) __int64 CPUCycleCounter() // 获取当前CPU周期计数(CPU周期数)
{
asm
{
RDTSC // 0F 31 // eax,edx
ret
}
}
// 循环读取文件数据到内存缓冲区
void CWords::toDo(FILE * file)
{
_CPUCount = ::CPUCycleCounter();
std::vector < char > BufData(cibuf_size);
buf_size = BufData.size();
pBuf =& BufData[ 0 ];
// get file length
fseek(file, 0 ,SEEK_END);
int file_length = ftell(file);
fseek(file, 0 ,SEEK_SET);
int file_pos = 0 ;
CreateGainTab();
int dx = 0 ;
int start_offset = 0 ;
while ( true )
{
if (file_pos + (buf_size - dx) <= file_length)
{
fread(pBuf + dx,buf_size - dx, 1 ,file);
file_pos += (buf_size - dx);
dx = GainWord(dx,start_offset);
start_offset = 0 ;
if (dx < 0 ) // 处理超长单词
{
start_offset = buf_size + dx; // 放大缓冲区
dx = buf_size;
BufData.resize(dx * 2 );
buf_size = BufData.size();
pBuf =& BufData[ 0 ];
}
else // if ( (dx<(cibuf_size>>1)) && (buf_size>(cibuf_size<<1)) )
{
// BufData.resize(cibuf_size); // 减小缓冲区
// pBuf=&BufData[0];
// buf_size=BufData.size();
}
}
else
{
int bordercount = ( int )(file_length - file_pos);
if (bordercount > 0 )
{
fread(pBuf + dx,bordercount, 1 ,file);
buf_size = dx + bordercount;
// file_pos+=(bordercount);
endGainWord(dx,start_offset);
}
break ; // end while
}
}
_CPUCount = ::CPUCycleCounter() - _CPUCount;
}
int CWords::privateGainWord( int dx, int start_offset, bool isEndGain)
{
char * pStart = pBuf + start_offset;
char * pEnd = pBuf + buf_size;
int IsInWord = (dx != 0 ) ? int ( - 1 ): 0 ; // 是否处于“词”中
char * i = pBuf + dx;
for (;i != pEnd; ++ i)
{
if (IsInWord ^ GainTab[ * (unsigned char * )i])
{
if (IsInWord)
PushWord(pStart,i);
else
pStart = i;
IsInWord = ( ~ IsInWord);
}
}
/
dx = 0 ;
if (IsInWord)
{
if (isEndGain)
PushWord(pStart,pEnd); // 最末尾的一个词
else
{
dx = pEnd - pStart;
if (dx > (buf_size >> 1 )) // 超长单词特殊处理
dx = ( - dx); // 特殊标记!
else
{
for ( int i = 0 ;i < dx; ++ i) // 把没有处理完的单词拷贝到缓冲区开头
pBuf[i] = pStart[i];
}
}
}
return dx;
}
int CWords::GainWord( int dx, int start_offset)
{
return privateGainWord(dx,start_offset, false );
}
void CWords::endGainWord( int dx, int start_offset)
{
privateGainWord(dx,start_offset, true );
}
void CWords::GetResult(std::ostream & cout, unsigned int sortCount)
{
std::cout << " 无重复单词数: " << _hash_set.size() << " 单词总数: " << _hash_set.sum() << std::endl;
std::cout << " CPU周期计数: " << ( long )_CPUCount << std::endl;
_hash_set.GetStrList(cout,sortCount);
}
inline void _fastcall CWords::PushWord( char * begin, char * end)
{
_hash_set.insert(begin,end);
}
// /
int CreateTxtFile( char * argv[]);
int toWork( int argc, char * argv[]);
const char sParameter [] = " Cpt_hss filename [/N] " ;
// 主程序
int main( int argc, char * argv[])
{
if (argc <= 1 )
{
std::cout << ( " 请输入文件名称! " );
std::cout << sParameter;
std::cout << std::endl;
return 0 ;
}
if (std:: string (argv[ 1 ]) == " /? " )
{
std::cout << ( " 统计文件中单词出现频率。 " );
std::cout << (sParameter);
std::cout << ( " filename 指定需要进行统计的文件的名称 " );
std::cout << ( " [/N] 显示出现频率最高的前N个单词; " );
std::cout << ( " 如果单词出现频率相同,则按字母顺序排列; " );
std::cout << ( " N默认为1; " );
std::cout << ( " 当N=0时,表示全部显示。 " );
std::cout << std::endl;
return 0 ;
}
return toWork(argc,argv);
}
int toWork( int argc, char * argv[])
{
clock_t start = clock();
FILE * file = fopen(argv[ 1 ], " rb " );
if (file == 0 )
{
std::cout << ( " 打开文件时发生错误! " );
std::cout << (sParameter);
return 0 ;
}
unsigned int sortCount = 1 ;
if (argc == 3 )
sortCount = atoi(argv[ 2 ] + 1 );
CWords words;
words.toDo(file);
fclose(file);
words.GetResult(std::cout,sortCount);
std::cout << " Seconds = " << ( ( double )(clock() - start) / CLOCKS_PER_SEC ) << std::endl;
return 0 ;
}
重点优化说明: (这是本篇文章的重点,讲解一些基本的优化策略)
1.在读取文件方面,使用了一个自己管理的内存缓冲区来读取文件的数据;
(这样处理以后读文件占的时间约占总时间的1/7,还可以进一步优化:
进一步改进方案a:可以考虑用另一个线程异步来加载文件数据(当前处理大量文件数据的高效方案);
进一步改进方案b:如果文件不太大可以考虑使用内存映射技术来优化这一块,代码也简单很
多,而单词的表示也可以采用一个指针加一个长度(或者用头尾两个指针,或者一个指针+哨兵
位(推荐))来表示,从而避免一次深拷贝)
2.建立了一个查询表GainTab[256]用来判断一个字母是否是单词还是空白区域;
比如:可以把( ((C>='A')&&(C<='Z'))||((C>='a')&&(C<='z')) ) 简写为 ( GainTab[C]!=0 )
(其实也可以建立一个64k的表来捕捉状态,同时用两个字节来查表...)
3.把查找单词的扫描过程理解为从单词区域到空白区域的状态转换(这句可能不好理解);
比如一般常见的实现伪代码:
while (i!=pEnd)
{
while((i!=pEnd)&&(!GainTab[*(unsigned char*)i])) //寻找到单词开头
++i;
pStart=i;
while((i!=pEnd)&&(GainTab[*(unsigned char*)i])) //寻找该单词结束位置
++i;
if (pStart!=i)
PushWord(pStart,i);
}
我的代码:
{
if (IsInWord^GainTab[*(unsigned char*)i]) //捕捉所属区域状态的变化
{
if (IsInWord)
PushWord(pStart,i);
else
pStart= i;
IsInWord=(~ IsInWord);
}
}
该算法处理两个状态:是否在单词中、“是否在单词中”的状态是否改变;
从而消除了内部的一个循环框架,这在单词和空际较小时将带来更多好处;
(在本程序中可能所起作用不大,这里耗的时间不多,反而调用PushWord的花掉的时间更多)
(还有一个有用的见解:“经过3个字节最多能够计数一个单词”,比如利用这个观点可以建立
2字节或3字节的查询表(表的大小的取舍需要考虑CPU的缓冲区大小),同时处理更多的字节)
4.为了优化单词使用的内存,减少动态内存分配,自定义了一个CMyAllot类来管理内存的分配
5.我使用了一个自定义的hash表CHashSet(准确点应该叫做map)来储存找到的单词(hash表具有平
均常数时间的单词查找能力),表的大小会随着无重复单词数的增加而动态增长:某个HashItem
不可用时,会把新的单词加到HashItem后面,即HashItem形成一个单向list,当hash表的负债超过
某个阈值的时候,就会增大表的大小,然后所有的元素重新转移到新的表;
6.我的hash表的大小只可能为2的整数次方,所以hash值在映射到HashItem的序号时可以使用快速
的&运算(hash_value&_hash_mask); 等价于(hash_value%hash_size) , 优化掉一次求余运算
(求余和除法都是很慢的操作)
补充: 我尝试过把char字符流当作wchar_t* 流来处理,希望提高吞吐量
但为保证结果正确代码逻辑变得稍微复杂了一些,结果在我的机子上速度几乎没有变!