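题目描述:LeetCode 355「设计推特」。实现 `postTweet`(发推)、`getNewsFeed`(返回用户本人及其关注者最近的 10 条推文 ID,按从新到旧排序)、`follow`(关注)、`unfollow`(取消关注)四个操作。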
解法:哈希表+单链表+堆(C++)
其实设计类问题都不难解,关键是想清楚采用什么样的数据结构:这里用哈希表按 ID 找到用户,每个用户的推文按时间从新到旧存成单链表,获取动态时再用大顶堆对本人及其关注者的链表做多路归并。
// Global post counter: each new Tweet takes the current value as its
// timestamp and then increments it, so a larger `time` means a more recent tweet.
int global_Time = 0;

// Tweet: one node of a user's singly linked tweet list (newest first).
class Tweet{
public:
    int id;         // tweet id supplied by the caller
    int time;       // post order, taken from global_Time
    Tweet* next;    // next (older) tweet of the same user, or nullptr

    // `explicit` so that a plain int can never silently convert into a Tweet.
    explicit Tweet(int id) : id(id), time(global_Time++), next(nullptr) {}
};
// 用户类
class User{
public:
int id;
Tweet* tweet; // 该用户发送的推文
unordered_set<int> follows; // 该用户关注的用户
User(int id){
this->id = id;
tweet = nullptr;
}
void follow(int followed){
if(followed==id) return;
follows.insert(followed);
}
void unfollow(int followed){
if(!follows.count(followed) || followed==id) return;
follows.erase(followed);
}
void post(int tweetId){
Tweet* newTweet = new Tweet(tweetId);
newTweet->next = tweet;
tweet = newTweet;
}
};
class Twitter{
private:
unordered_map<int, User*> user_map;
bool contain(int id){
return user_map.find(id) != user_map.end();
}
public:
Twitter(){
user_map.clear();
}
void postTweet(int userId, int tweetId){
if(!contain(userId))
user_map[userId] = new User(userId);
user_map[userId]->post(tweetId);
}
vector<int> getNewsFeed(int userId){
if(!contain(userId)) return {};
struct cmp{
bool operator()(const Tweet* a, const Tweet* b){
return a->time < b->time;
}
};
// 构造大顶堆,时间最大的排在上面
priority_queue<Tweet*, vector<Tweet*>, cmp> q;
// 自己的推文链表
if(user_map[userId]->tweet)
q.push(user_map[userId]->tweet);
// 关注的推文链表
for(int followeeId: user_map[userId]->follows){
if(!contain(followeeId)) continue;
Tweet* head = user_map[followeeId]->tweet;
if(head==nullptr) continue;
q.push(head);
}
vector<int> res;
while(!q.empty()){
Tweet* t = q.top();
q.pop();
res.push_back(t->id);
if(res.size()==10) return res;
if(t->next) q.push(t->next);
}
return res;
}
void follow(int followerId, int followeeId){
if(!contain(followerId))
user_map[followerId] = new User(followerId);
user_map[followerId]->follow(followeeId);
}
void unfollow(int followerId, int followeeId){
if(!contain(followerId)) return;
user_map[followerId]->unfollow(followeeId);
}
};
当然,在 Python 中链表也很容易用列表代替;不过下面的 Python 版本仍沿用了与 C++ 相同的"单链表 + 堆"思路:
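说明两点:
1. `self.tweets = defaultdict(lambda: None)`:如果直接声明为 `self.tweets = defaultdict(None)`,`default_factory` 为 `None`,它就和普通 `dict` 一样,访问不存在的键时会抛出 `KeyError`;所以这里用一个返回 `None` 的匿名函数作为默认值工厂。
2. `def __lt__(self, other): return self.timestamp > other.timestamp` 是运算符重载,Python 里称为魔法方法(magic method)。`heapq` 是小顶堆,这里故意把"小于"定义为"时间戳更大",这样堆顶就是最新的推文。Python 2 里还可以用 `__cmp__` 方法,Python 3 已移除 `__cmp__`,改用 `__lt__`、`__le__` 等一组富比较方法。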
import heapq
from collections import defaultdict
from typing import List
class Tweet:
    """One node of a user's tweet list, ordered so that ``heapq`` pops the newest first.

    Attributes (all set in ``__init__``):
        id: the tweet id.
        timestamp: post order; a larger value means a more recent tweet.
        next: the same user's next (older) Tweet, or None.
    """

    def __init__(self, tweetId, timestamp):
        self.id, self.timestamp, self.next = tweetId, timestamp, None

    def __lt__(self, other):
        # Deliberately inverted: "less than" means "more recent", which turns
        # heapq's min-heap into a max-heap on timestamp.
        newer = self.timestamp > other.timestamp
        return newer
class Twitter:
    """Simplified Twitter: post tweets, follow/unfollow users, fetch a news feed.

    Each user's tweets form a singly linked list of ``Tweet`` nodes, newest
    first. ``getNewsFeed`` merges the lists of the user and their followees
    with a heap, using ``Tweet.__lt__`` so the newest tweet is popped first.
    """

    # Maximum number of tweet ids returned by getNewsFeed.
    FEED_SIZE = 10

    def __init__(self):
        # followerId -> set of followeeIds
        self.followings = defaultdict(set)
        # userId -> head of that user's tweet list (newest first), or None.
        # defaultdict(None) would raise KeyError, hence the lambda.
        self.timestamp = 0
        self.tweets = defaultdict(lambda: None)

    def postTweet(self, userId: int, tweetId: int) -> None:
        """Prepend a new tweet to ``userId``'s list, stamped with the next counter value."""
        self.timestamp += 1
        tweet = Tweet(tweetId, self.timestamp)
        tweet.next = self.tweets[userId]  # None for a user's first tweet
        self.tweets[userId] = tweet

    def getNewsFeed(self, userId: int) -> List[int]:
        """Return up to FEED_SIZE most recent tweet ids by ``userId`` or their followees, newest first."""
        # .get() instead of [] so that querying a user does not insert empty
        # entries into the defaultdicts.
        heads = [self.tweets.get(userId)]
        heads.extend(self.tweets.get(user) for user in self.followings.get(userId, ()))
        heap = [head for head in heads if head is not None]
        heapq.heapify(heap)

        feed = []
        while heap and len(feed) < self.FEED_SIZE:
            newest = heapq.heappop(heap)
            feed.append(newest.id)
            if newest.next is not None:
                heapq.heappush(heap, newest.next)
        return feed

    def follow(self, followerId: int, followeeId: int) -> None:
        """``followerId`` starts following ``followeeId``; following yourself is ignored."""
        if followerId == followeeId:
            return
        self.followings[followerId].add(followeeId)

    def unfollow(self, followerId: int, followeeId: int) -> None:
        """``followerId`` stops following ``followeeId``; a no-op if not followed."""
        # follow() never adds a user to their own set, so no self-check is needed;
        # .get() avoids creating an empty set for an unknown follower.
        followees = self.followings.get(followerId)
        if followees is not None:
            followees.discard(followeeId)