今天面试的时候被问到了一个问题,要我用模板类实现一个shared_ptr. 坦白来说这道题我并不会做,因为我对于shared_ptr的认知仅仅停留在引用计数上面。所以查阅了资料写了一份代码出来,希望对于更深入理解共享智能指针有帮助吧。
首先就是每个指针类应该有一个引用计数变量,以及一个模板指针,用于存放对应的不同类型的指针。对于shared_ptr来说,指针指向相同的位置,引用计数要加1. 同时,使用复制构造函数的时候,应该也会让引用变量加1,因为本身复制构造函数相同于使指针指向同一个地址。同时,应该重载=号,因为等于号也相当于让两个指针指向同一地址,在这里如果左操作符引用计数清零0就需要被析构,同时右操作符的引用计数需要加1.
综合上面的思路,写出来下面的代码:
#include <iostream>
#include <string>
#include <vector>
#include <queue>
#include <map>
#include <unordered_map>
#include <set>
#include <unordered_set>
using namespace std;
template <typename T>
class Shared_ptr{
public:
T *ptr;//实际指针
int *p_count;//用于引用计数
Shared_ptr(T *p){//第一次指向T类型的指针
ptr = p;
p_count = new int(1);//初始化,第一次指针指向的位置的引用计数应该是1
}
Shared_ptr(const Shared_ptr<T> & p){//复制构造函数
this->ptr = p.ptr;
this->p_count = p.p_count;
//引用计数应该加1
(*this->p_count)++;
}
Shared_ptr<T> operator = (const Shared_ptr<T> & p){//重载等号
if (ptr != p.ptr){
//等号左边的引用计数减1,等号右边的引用计数加1
*p_count--;
if (*p_count == 0){
delete p_count;
delete ptr;
}
//指向=右边操作符,且计数变量+1
ptr = p.ptr;
p_count = p.p_count;
*p_count++;
}
return *this;
}
T *operator->()
{
return ptr;
}
int use_count(){
return *p_count;
}
};
int main(){
int *p = new int;
*p = 5;
Shared_ptr<int> s_ptr(p);//s_ptr指向了这块地址,pCount = 1
cout << "s_ptr count:" << s_ptr.use_count() << endl;
cout << "--------------------------------------------------" << endl;
Shared_ptr<int> s_ptr1 = s_ptr;//s_ptr1也指向了这块地址,pCount = 2
cout << "s_ptr1: count:" << s_ptr1.use_count() << endl;
cout << "--------------------------------------------------" << endl;
Shared_ptr<int> s_ptr2(p);//s_ptr2也指向了这块地址,不过重新创建了引用计数,pCount1 = 1
cout << "s_ptr2: count:" << s_ptr2.use_count() << endl;
cout << "--------------------------------------------------" << endl;
Shared_ptr<int> s_ptr3(s_ptr1);//s_ptr3指向s_ptr1,这样s_ptr,s_ptr1和s_ptr3都只向同一块内存空间,计数为3
cout << "s_ptr3: count:" << s_ptr3.use_count() << endl;
//s_ptr4指向的和s_ptr3相同,那么两者的引用计数都应该加1
cout << "--------------------------------------------------" << endl;
Shared_ptr<int> s_ptr4 = s_ptr3;
cout << "s_ptr3: count:" << s_ptr3.use_count() << endl;
cout << "s_ptr4: count:" << s_ptr4.use_count() << endl;
//system("pause");
return 0;
}
至于unique_ptr,因为从定义上,unique_ptr两个指针指向同一个地址,所以它的复制构造函数,和等号的重载都需要设置为不可用。参考实现如下:
#include <iostream>
#include <string>
#include <vector>
#include <queue>
#include <map>
#include <unordered_map>
#include <set>
#include <unordered_set>
using namespace std;
template <typename T>
class unique_ptr{
public:
T *ptr;
unique_ptr(T *p){
ptr = p;
}
~unique_ptr(){
delete ptr;
}
T operator*(){
return *ptr;
}
// =delete表示禁止使用编译器默认生成的函数,也就是该函数不可用
unique_ptr(const unique_ptr<T> &p) = delete;
unique_ptr<T> &operator=(const unique_ptr<T> & p) = delete;
};
int main(){
int *p = new int(10);
unique_ptr<int> u_ptr = p;
cout << "u_ptr的值是: " << *u_ptr << endl;
cout << "--------------------------------------------------" << endl;
//unique_ptr<int> u_ptr2 = u_ptr;//报错,无法调用复制构造函数
//unique_ptr<int> u_ptr3(u_ptr);//报错,无法调用复制构造函数
system("pause");
return 0;
}