Sort a singly linked list in O(n log n) time using only constant extra space.
java
/**
 * Definition for singly-linked list.
 * public class ListNode {
 *     int val;
 *     ListNode next;
 *     ListNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Sorts a singly-linked list in ascending order of {@code val}.
     *
     * <p>Iterative bottom-up merge sort. Sorted runs of width 1, 2, 4, ... are merged pairwise
     * until a single run spans the whole list. This is O(n log n) time in the worst case and
     * O(1) extra space. There is no recursion, and the existing nodes are relinked in place.
     * Equal values keep their original relative order (the sort is stable).
     *
     * @param head first node of the list, or {@code null} for an empty list
     * @return first node of the sorted list (the same nodes, relinked), or {@code null}
     */
    public ListNode sortList(ListNode head) {
        if (head == null || head.next == null) {
            return head;
        }
        int length = countNodes(head);
        // The sentinel lets every merged run be appended the same way, including the first one.
        ListNode dummy = new ListNode(0);
        dummy.next = head;
        for (int width = 1; width < length; width *= 2) {
            ListNode tail = dummy;      // last node of the already-merged part of this pass
            ListNode rest = dummy.next; // first node not yet consumed in this pass
            while (rest != null) {
                ListNode left = rest;
                ListNode right = split(left, width);
                rest = split(right, width);
                tail = merge(left, right, tail);
            }
            // Runs now have length 2 * width. Stop once that covers the list. The check is
            // written as a subtraction so the doubling above can never overflow an int.
            if (width >= length - width) {
                break;
            }
        }
        return dummy.next;
    }

    /** Returns the number of nodes reachable from {@code head} (0 for {@code null}). */
    private int countNodes(ListNode head) {
        int count = 0;
        for (ListNode node = head; node != null; node = node.next) {
            count++;
        }
        return count;
    }

    /**
     * Detaches the first {@code count} nodes (or fewer, if the list is shorter) from the rest.
     *
     * @param head  first node of the list to cut; may be {@code null}
     * @param count maximum number of nodes to keep in front of the cut (at least 1)
     * @return the first node after the cut, or {@code null} if nothing follows it
     */
    private ListNode split(ListNode head, int count) {
        for (int i = 1; head != null && i < count; i++) {
            head = head.next;
        }
        if (head == null) {
            return null;
        }
        ListNode rest = head.next;
        head.next = null;
        return rest;
    }

    /**
     * Merges two sorted, {@code null}-terminated runs and links the result after {@code tail}.
     *
     * @param left  first sorted run; may be {@code null}
     * @param right second sorted run; may be {@code null}
     * @param tail  node to append the merged run to; must not be {@code null}
     * @return the last node of the merged run (the new tail)
     */
    private ListNode merge(ListNode left, ListNode right, ListNode tail) {
        while (left != null && right != null) {
            // <= takes from the left run on ties, which keeps the sort stable.
            if (left.val <= right.val) {
                tail.next = left;
                left = left.next;
            } else {
                tail.next = right;
                right = right.next;
            }
            tail = tail.next;
        }
        tail.next = (left != null) ? left : right;
        return getTail(tail);
    }

    /** Returns the last node of the list starting at {@code head}, or {@code null} if empty. */
    private ListNode getTail(ListNode head) {
        if (head == null) {
            return head;
        }
        while (head.next != null) {
            head = head.next;
        }
        return head;
    }
}
python
"""
Definition of ListNode
class ListNode(object):
def __init__(self, val, next=None):
self.val = val
self.next = next
"""
class Solution:
"""
@param: head: The head of linked list.
@return: You should return the head of the sorted linked list, using constant space complexity.
"""
def sortList(self, head):
# write your code here
if head == None or head.next == None:
return head
leftDummy, midDummy, rightDummy = ListNode(0), ListNode(0), ListNode(0)
left, mid, right = leftDummy, midDummy, rightDummy
node = self.getMid(head)
while head is not None:
if head.val < node.val:
left.next = head
left = left.next
elif head.val > node.val:
right.next = head
right = right.next
else:
mid.next = head
mid = mid.next
head = head.next
left.next, right.next, mid.next = None, None, None
start = self.sortList(leftDummy.next)
end = self.sortList(rightDummy.next)
result = self.combine(start, midDummy.next, end)
return result
def getMid(self, head):
if head == None or head.next == None:
return head
slow, fast = head, head.next
while fast != None and fast.next != None:
slow = slow.next
fast = fast.next.next
return slow
def combine(self, head, mid, end):
dummy = ListNode(0)
pre = dummy
if head is not None:
pre.next = head
pre = self.getTail(pre)
if mid is not None:
pre.next = mid
pre = self.getTail(pre)
if end is not None:
pre.next = end
return dummy.next
def getTail(self, head):
if head == None or head.next == None:
return head
while head.next != None:
head = head.next
return head