【题目描述】
给定一个链表,再给定一个整数 pivot,请将链表调整为左部分都是值小于 pivot 的节点,中间部分都是值等于 pivot 的节点, 右边部分都是大于 pivot 的节点。
除此之外,对调整后的节点顺序没有更多要求。
【输入描述】
第一行两个整数 n 和 pivot,n 表示链表的长度。
第二行 n 个整数表示链表的节点值val。
【输出描述】
请在给定的函数内返回链表的头指针。
【示例1】
输入
5 3
9 0 4 5 1
输出
1 0 4 5 9
【备注】
1<= n <= 1000000
-1000000<= val, pivot <= 1000000
【代码实现 - CPP版】
# include <bits/stdc++.h>
using namespace std;
struct list_node{
int val;
struct list_node *next;
};
int pivot;
// create a link list
list_node* input_list(void)
{
int n, val;
list_node * phead = new list_node();
list_node * cur_pnode = phead;
cin >> n >> pivot;
for (int i = 1; i <= n; ++i) {
cin >> val;
if (i == 1) {
cur_pnode->val = val;
cur_pnode->next = NULL;
}
else {
list_node * new_pnode = new list_node();
new_pnode->val = val;
new_pnode->next = NULL;
cur_pnode->next = new_pnode;
cur_pnode = new_pnode;
}
}
return phead;
}
// print all elements of a link list
void print_list(list_node *head) {
while(NULL != head) {
cout << head->val << " ";
head = head->next;
}
}
// array partition function
void arr_partition(vector<list_node*>& vec, const int& pivot) {
int left = -1; // left section [0,left]
int index = 0; // middle section [left+1, index]
int right = vec.size(); // right section [right, vec.size()-1]
while(index != right) {
if(vec[index]->val < pivot) { // extend left section to index
left++;
swap(vec[index]->val, vec[left]->val);
index++;
} else if(vec[index]->val > pivot){ // extend right section to index
right--;
swap(vec[index]->val, vec[right]->val);
} else { // extend the middle section
index++;
}
}
}
void list_partition(list_node *head, int pivot)
{
// scan the link list to obtain its length
list_node *cur = head;
int len = 0;
int i = 0;
while(NULL != cur) {
len++;
cur = cur->next;
}
// record link list elements to a vector
vector<list_node*> nodes(len, NULL);
cur = head;
for(i=0; i<len; i++) {
nodes[i] = cur;
cur = cur->next;
}
// partition the vector
arr_partition(nodes, pivot);
//
for(i=1; i<len; i++) {
nodes[i-1]->next = nodes[i];
}
// set last element nodes[len-1] = NULL
nodes[i-1]->next = NULL;
// assign pointer nodes[0] to point head
head = nodes[0];
}
int main ()
{
list_node * head = input_list();
list_partition(head, pivot);
print_list(head);
return 0;
}