字符串写时拷贝实现原理:
通过对象的引用计数来减少内存的申请开销,经过拷贝构造或赋值函数得到的所有对象,在还没有对其中某个某个对象进行修改操作时,都共享一个对象的内存。否则,如果原有对象的引用计数大于1时,将会为修改的对象分配新的内存,并在原有的对象引用计数中减一。
方法1
成员变量是字符指针类型,字符串前四个字节用来存放当前对象被引用的次数。
class CString
{
public:
CString(char *p = NULL)
{
if (p == NULL)
{
mptr = new char[4+1];
((int*)mptr)[0] = 1;
mptr[4] = '\0';
}
else
{
mptr = new char[strlen(p)+4+1];
((int *)mptr)[0] = 1;
strcpy_s(mptr+4, strlen(p) + 1, p);
}
}
CString(const CString &src)
{
mptr = src.mptr;
((int*)mptr)[0]++;
}
CString& operator=(const CString &src)
{
if (this != &src)
{
((int*)mptr)[0]--;
if (((int*)mptr)[0] == 0)
{
delete []mptr;
mptr = NULL;
}
mptr = src.mptr;
((int*)mptr)[0]++;
}
return *this;
}
~CString()
{
((int*)mptr)[0]--;
if (((int*)mptr)[0] == 0)
{
delete []mptr;
mptr = NULL;
}
}
void erase(const char &ch) // 'a'
{
//处理写时拷贝
if (((int*)mptr)[0] > 1)
{
((int*)mptr)[0]--;
char *ptmp = new char[strlen(mptr + 4) + 4 + 1];
strcpy_s(ptmp+4, strlen(mptr + 4) + 1,mptr+4);
mptr = ptmp;
((int*)mptr)[0] = 1;
}
char *p = mptr + 4;
for (; *p != '\0'; ++p)
{
if (*p == ch)
{
char *q = p + 1;
for (; *q != '\0'; ++q)
{
*(q - 1) = *q;
}
*(q - 1) = '\0';
return;
}
}
}
private:
char *mptr;
friend ostream& operator<<(ostream &out, const CString& src);
};
ostream& operator<<(ostream &out, const CString& src)
{
out << src.mptr + 4;
return out;
}
方法2
成员变量类型设置为结构体类型,其中有一个整型变量专门用来计数,另一个变量是字符指针类型。
class CString
{
public:
CString(char *p = NULL)
{
mpnode = new Node(p);
}
CString(const CString& src)
{
mpnode = src.mpnode;
mpnode->cnt++;
}
CString& operator=(const CString& src)
{
if (this == &src)
{
return *this;
}
mpnode->cnt--;
if (mpnode->cnt == 0)
{
delete []mpnode->mptr;
mpnode->mptr = NULL;
delete mpnode;
mpnode = NULL;
}
mpnode = src.mpnode;
mpnode->cnt++;
return *this;
}
~CString()
{
mpnode->cnt--;
if (mpnode->cnt == 0)
{
delete []mpnode->mptr;
mpnode->mptr = NULL;
delete mpnode;
mpnode = NULL;
}
}
void erase(const char &ch)
{
//处理写时拷贝
if (mpnode->cnt > 1)
{
mpnode->cnt--;
Node *pnode = new Node(mpnode->mptr);
mpnode = pnode;
}
char *p = mpnode->mptr;
int size = strlen(mpnode->mptr) + 1;
for (int i = 0; i < size - 1; ++i)
{
if (p[i] == ch)
{
for (int j = i; j < size - 1; ++j)
{
p[j] = p[j+1];
}
return;
}
}
}
private:
struct Node
{
Node(char *ptr = NULL) :cnt(1)
{
if (ptr != NULL)
{
mptr = new char[strlen(ptr) + 1];
strcpy_s(mptr, strlen(ptr) + 1, ptr);
}
else
{
mptr = new char[1];
*mptr = 0;
}
}
int cnt;
char *mptr;
};
Node *mpnode;
friend ostream& operator<<(ostream &out, const CString& src);
};
ostream& operator<<(ostream &out,const CString &src)
{
out << src.mpnode->mptr;
return out;
}
测试:
int main()
{
CString str1 = "hello";
CString str2 = str1;
cout << str1 << endl;
cout << str2 << endl;
str1.erase('e');
cout << str1 << endl;
cout << str2 << endl;
CString str3;
str3 = str1;
cout << str3 << endl;
return 0;
}
VS运行结果:
hello
hello
hllo
hello
hllo