一个和前缀和思想非常类似的算法技巧「差分数组」,差分数组的主要适用场景是频繁对原始数组的某个区间的元素进行增减。
比如说,我给你输入一个数组 nums
,然后又要求给区间 nums[2..6]
全部加 1,再给 nums[3..9]
全部减 3,再给 nums[0..4]
全部加 2,再给…
一通操作猛如虎,然后问你,最后 nums
数组的值是什么?
常规的思路很容易,你让我给区间 nums[i..j]
加上 val
,那我就一个 for 循环给它们都加上呗,还能咋样?这种思路的时间复杂度是 O(N),由于这个场景下对 nums
的修改非常频繁,所以效率会很低下。
这里就需要差分数组的技巧,类似前缀和技巧构造的 prefix
数组,我们先对 nums
数组构造一个 diff
差分数组,diff[i]
就是 nums[i]
和 nums[i-1]
之差:
int[] diff = new int[nums.length];
// 构造差分数组
diff[0] = nums[0];
for (int i = 1; i < nums.length; i++) {
diff[i] = nums[i] - nums[i - 1];
}
通过这个 diff
差分数组是可以反推出原始数组 nums
的,代码逻辑如下:
int[] res = new int[diff.length];
// 根据差分数组构造结果数组
res[0] = diff[0];
for (int i = 1; i < diff.length; i++) {
res[i] = res[i - 1] + diff[i];
}
这样构造差分数组 diff
,就可以快速进行区间增减的操作,如果你想对区间 nums[i..j]
的元素全部加 3,那么只需要让 diff[i] += 3
,然后再让 diff[j+1] -= 3
即可:
原理很简单,回想 diff
数组反推 nums
数组的过程,diff[i] += 3
意味着给 nums[i..]
所有的元素都加了 3,然后 diff[j+1] -= 3
又意味着对于 nums[j+1..]
所有元素再减 3,那综合起来,是不是就是对 nums[i..j]
中的所有元素都加 3 了?
只要花费 O(1) 的时间修改 diff
数组,就相当于给 nums
的整个区间做了修改。多次修改 diff
,然后通过 diff
数组反推,即可得到 nums
修改后的结果。
现在我们把差分数组抽象成一个类,包含 increment
方法和 result
方法
// 差分数组工具类
class Difference {
// 差分数组
private int[] diff;
/* 输入一个初始数组,区间操作将在这个数组上进行 */
public Difference(int[] nums) {
assert nums.length > 0;
diff = new int[nums.length];
// 根据初始数组构造差分数组
diff[0] = nums[0];
for (int i = 1; i < nums.length; i++) {
diff[i] = nums[i] - nums[i - 1];
}
}
/* 给闭区间 [i, j] 增加 val(可以是负数)*/
public void increment(int i, int j, int val) {
diff[i] += val;
if (j + 1 < diff.length) {
diff[j + 1] -= val;
}
}
/* 返回结果数组 */
public int[] result() {
int[] res = new int[diff.length];
// 根据差分数组构造结果数组
res[0] = diff[0];
for (int i = 1; i < diff.length; i++) {
res[i] = res[i - 1] + diff[i];
}
return res;
}
}
算法实践
那么我们直接复用刚才实现的 Difference
类就能把这道题解决掉
int[] getModifiedArray(int length, int[][] updates) {
// nums 初始化为全 0
int[] nums = new int[length];
// 构造差分解法
Difference df = new Difference(nums);
for (int[] update : updates) {
int i = update[0];
int j = update[1];
int val = update[2];
df.increment(i, j, val);
}
return df.result();
}
int[] corpFlightBookings(int[][] bookings, int n) {
// nums 初始化为全 0
int[] nums = new int[n];
// 构造差分解法
Difference df = new Difference(nums);
for (int[] booking : bookings) {
// 注意转成数组索引要减一哦
int i = booking[0] - 1;
int j = booking[1] - 1;
int val = booking[2];
// 对区间 nums[i..j] 增加 val
df.increment(i, j, val);
}
// 返回最终的结果数组
return df.result();
}
你是一个开公交车的司机,公交车的最大载客量为 capacity
,沿途要经过若干车站,给你一份乘客行程表 int[][] trips
,其中 trips[i] = [num, start, end]
代表着有 num
个旅客要从站点 start
上车,到站点 end
下车,请你计算是否能够一次把所有旅客运送完毕(不能超过最大载客量 capacity
)
相信你已经能够联想到差分数组技巧了:trips[i]
代表着一组区间操作,旅客的上车和下车就相当于数组的区间加减;只要结果数组中的元素都小于 capacity
,就说明可以不超载运输所有旅客。
但问题是,差分数组的长度(车站的个数)应该是多少呢?题目没有直接给,但给出了数据取值范围:
0 <= trips[i][1] < trips[i][2] <= 1000
boolean carPooling(int[][] trips, int capacity) {
// 最多有 1001 个车站
int[] nums = new int[1001];
// 构造差分解法
Difference df = new Difference(nums);
for (int[] trip : trips) {
// 乘客数量
int val = trip[0];
// 第 trip[1] 站乘客上车
int i = trip[1];
// 第 trip[2] 站乘客已经下车,
// 即乘客在车上的区间是 [trip[1], trip[2] - 1]
int j = trip[2] - 1;
// 进行区间操作
df.increment(i, j, val);
}
int[] res = df.result();
// 客车自始至终都不应该超载
for (int i = 0; i < res.length; i++) {
if (capacity < res[i]) {
return false;
}
}
return true;
}