[BZOJ4499][线性函数][线段树]
题目大意:
小C最近在学习线性函数,线性函数可以表示为:f(x) = kx + b。现在小C面前有n个线性函数fi(x)=kix+bi ,他对这n个线性函数执行m次操作,每次可以:
1.M i K B 代表把第i个线性函数改为:fi(x)=kx+b 。
2.Q l r x 返回fr(fr-1(…fl(x))) mod 10^9+7 。
思路:
可以发现两个线性函数 f1(x)=k1x+b1 和 f2(x)=k2x+b2 ,求 f1(f2(x)) 就等于 (k1k2)x+(k1b2+b1) 。
也就是两个线性函数套在一起还是一个线性函数,而且这个东西满足结合律,求 f1(f2(f3(f4(x)))) 的时候可以先分别求 f1(f2(x)) 和 f3(f4(x)) ,然后套在一起就好了。
这样可以直接上线段树。
代码:
#include <bits/stdc++.h>
using namespace std;
const int Maxn = 200010;
typedef long long ll;
const ll Mod = 1000000007;
typedef pair<ll, ll> pl;
inline char get(void) {
static char buf[100000], *p1 = buf, *p2 = buf;
if (p1 == p2) {
p2 = (p1 = buf) + fread(buf, 1, 100000, stdin);
if (p1 == p2) return EOF;
}
return *p1++;
}
inline void read(ll &x) {
x = 0; bool f = 0; static char c;
for (; !(c >= '0' && c <= '9'); c = get()) if (c == '-') f = 1;
for (; c >= '0' && c <= '9'; x = x * 10 + c - '0', c = get()); if (f) x = -x;
}
inline void read(char &x) {
x = get();
while (!(x >= 'A' && x <= 'Z')) x = get();
}
inline void write(ll x) {
if (!x) return (void) (puts("0"));
if (x < 0) putchar('-'), x = -x;
static short st[20], top;
while (x) st[++top] = x % 10, x /= 10;
while (top) putchar('0' + st[top--]);
putchar('\n');
}
ll n, m, k[Maxn], b[Maxn];
pl t[Maxn << 2];
inline pl pushUp(pl l, pl r) {
pl ret;
ret.first = (l.first * r.first) % Mod;
ret.second = (r.second + (l.second * r.first % Mod)) % Mod;
return ret;
}
pl build(int o, int l, int r) {
if (l == r) {
return t[o] = make_pair(k[l], b[l]);
}
int mid = (l + r) >> 1;
return t[o] = pushUp(build(o << 1, l, mid), build(o << 1 | 1, mid + 1, r));
}
inline void modify(int o, int l, int r, int p, int K, int B) {
if (l == r) {
t[o] = make_pair(K, B);
return ;
}
int mid = (l + r) >> 1;
if (p <= mid) modify(o << 1, l, mid, p, K, B);
else modify(o << 1 | 1, mid + 1, r, p, K, B);
t[o] = pushUp(t[o << 1], t[o << 1 | 1]);
}
inline pl query(int o, int l, int r, int L, int R) {
if (l >= L && r <= R) {
return t[o];
}
int mid = (l + r) >> 1;
pl lpl = make_pair(1, 0), rpl = make_pair(1, 0);
if (mid >= L) lpl = query(o << 1, l, mid, L, R);
if (mid < R) rpl = query(o << 1 | 1, mid + 1, r, L, R);
return pushUp(lpl, rpl);
}
inline ll calc(pl p, ll x) {
return (p.first * x % Mod + p.second) % Mod;
}
int main(void) {
//freopen("in.txt", "r", stdin);
read(n), read(m);
for (int i = 1; i <= n; i++) read(k[i]), read(b[i]);
build(1, 1, n);
char op; ll x, y, o;
for (int i = 1; i <= m; i++) {
read(op);
if (op == 'M') {
read(o), read(x), read(y);
modify(1, 1, n, o, x, y);
} else {
read(x), read(y), read(o);
write(calc(query(1, 1, n, x, y), o));
}
}
return 0;
}
完。
By g1n0st