给你一个已经 排好序 的整数数组 nums
和整数 a
、 b
、 c
。对于数组中的每一个元素 nums[i]
,计算函数值 f(x) = ax2 + bx + c
,请 按升序返回数组 。
示例 1:
输入: nums = [-4,-2,2,4], a = 1, b = 3, c = 5 输出: [3,9,15,33]
示例 2:
输入: nums = [-4,-2,2,4], a = -1, b = 3, c = 5 输出: [-23,-5,1,7]
提示:
1 <= nums.length <= 200
-100 <= nums[i], a, b, c <= 100
nums
按照 升序排列
提示 1
x^2 + x will form a parabola.
提示 2
Parameter A in: A * x^2 + B * x + C dictates the shape of the parabola.
Positive A means the parabola remains concave (high-low-high), but negative A inverts the parabola to be convex (low-high-low).
解法1:排序
Java版:
class Solution {
public int[] sortTransformedArray(int[] nums, int a, int b, int c) {
int n = nums.length;
for (int i = 0; i < n; i++) {
nums[i] = a * nums[i] * nums[i] + b * nums[i] + c;
}
Arrays.sort(nums);
return nums;
}
}
Python3版:
class Solution:
def sortTransformedArray(self, nums: List[int], a: int, b: int, c: int) -> List[int]:
nums = [a * x * x + b * x + c for x in nums]
nums.sort()
return nums
复杂度分析
- 时间复杂度:O(nlogn),n为数组的长度。排序所用的时间复杂度为O(nlogn),一次for循环的时间复杂度为O(n),总的时间复杂度为O(nlogn + n) = O(nlogn)。
- 空间复杂度:O(logn),排序所用的空间复杂度为O(logn)。
解法2:数学法
Java版:
class Solution {
public int[] sortTransformedArray(int[] nums, int a, int b, int c) {
int n = nums.length;
// 对称轴
double mid = a == 0 ? 0 : (0.0 - b) / (2 * a);
// 函数在数组内单调递增
if ((a == 0 && b >= 0) || (a > 0 && mid <= (double) nums[0]) || (a < 0 && mid >= (double) nums[n - 1])) {
for (int i = 0; i < n; i++) {
nums[i] = a* nums[i] * nums[i] + b * nums[i] + c;
}
return nums;
}
// 函数在数组内单调递减
if ((a == 0 && b < 0) || (a < 0 && mid <= (double) nums[0]) || (a > 0 && mid >= (double) nums[n - 1])) {
int l = 0;
int r = n -1;
while (l <= r) {
nums[l] = b * nums[l] + c;
nums[r] = b * nums[r] + c;
int temp = nums[r];
nums[r] = nums[l];
nums[l] = temp;
l++;
r--;
}
return nums;
}
int[] ans = new int[n];
int k = 0;
// 抛物线开口向上,有极小值,越靠近对称轴,值越小
if (a > 0) {
int l = binarySearch(nums, mid);
int r = l + 1;
while (l >= 0 || r < n) {
if (l < 0) {
while (r < n) {
ans[k++] = a * nums[r] * nums[r] + b * nums[r] + c;
r++;
}
} else if (r >= n) {
while (l >= 0) {
ans[k++] = a * nums[l] * nums[l] + b * nums[l] + c;
l--;
}
} else {
if (mid - (double)nums[l] < (double)nums[r] - mid) {
ans[k++] = a * nums[l] * nums[l] + b * nums[l] + c;
l--;
} else {
ans[k++] = a * nums[r] * nums[r] + b * nums[r] + c;
r++;
}
}
}
return ans;
}
// 抛物线开口向下,有极大值,越远离对称轴,值越小
if (a < 0) {
int l = 0;
int r = n - 1;
while (l <= r) {
if (mid - (double)nums[l] > (double)nums[r] - mid) {
ans[k++] = a * nums[l] * nums[l] + b * nums[l] + c;
l++;
} else {
ans[k++] = a * nums[r] * nums[r] + b * nums[r] + c;
r--;
}
}
return ans;
}
return nums;
}
// 找到最大的<=target的值的下标
private int binarySearch(int[] nums, double target) {
int l = 0;
int r = nums.length - 1;
while (l <= r) {
int mid = l + (r - l) / 2;
if ((double) nums[mid] <= target) {
l = mid + 1;
} else {
r = mid - 1;
}
}
return r;
}
}
Python3版:
class Solution:
def sortTransformedArray(self, nums: List[int], a: int, b: int, c: int) -> List[int]:
n = len(nums)
mid = (0 - b) / (2 * a) if a != 0 else 0
# 单调递增
if (a == 0 and b >= 0) or (a > 0 and mid <= nums[0]) or (a < 0 and mid >= nums[n - 1]):
for i in range(n):
nums[i] = a * nums[i] * nums[i] + b * nums[i] + c
return nums
# 单调递减
if (a == 0 and b < 0) or (a > 0 and mid >= nums[n - 1]) or (a < 0 and mid <= nums[0]):
l = 0
r = n - 1
while l <= r:
nums[l] = a * nums[l] * nums[l] + b * nums[l] + c
nums[r] = a * nums[r] * nums[r] + b * nums[r] + c
nums[l], nums[r] = nums[r], nums[l]
l += 1
r -= 1
return nums
ans = []
if a > 0:
l = self.binarySearch(nums, mid)
r = l + 1
while l >= 0 or r < n:
if l < 0:
while r < n:
ans.append( a * nums[r] * nums[r] + b * nums[r] + c )
r += 1
elif r >= n:
while l >= 0:
ans.append( a * nums[l] * nums[l] + b * nums[l] + c )
l -= 1
elif mid - nums[l] < nums[r] - mid:
ans.append( a * nums[l] * nums[l] + b * nums[l] + c )
l -= 1
else:
ans.append( a * nums[r] * nums[r] + b * nums[r] + c )
r += 1
return ans
if a < 0:
l = 0
r = n - 1
while l <= r:
if mid - nums[l] > nums[r] - mid:
ans.append( a * nums[l] * nums[l] + b * nums[l] + c )
l += 1
else:
ans.append( a * nums[r] * nums[r] + b * nums[r] + c )
r -= 1
return ans
def binarySearch(self, nums: List[int], target: float) -> int:
l = 0
r = len(nums)
while l <= r:
mid = l + (r - l) // 2
if nums[mid] <= target:
l = mid + 1
else:
r = mid - 1
return r
复杂度分析
- 时间复杂度:O(n),n为数组的长度。
- 空间复杂度:O(n),n为数组的长度。