输入 n(≤2e5) 和 q(≤2e5)。 初始有一个长为 n 的字符串 s,所有字符都是 1,s 的下标从 1 开始。 然后输入 q
个替换操作,每个操作输入 L,R (1≤L≤R≤n) 和 d (1≤d≤9)。 你需要把 s 的 [L,R] 内的所有字符替换为 d。
对每个操作,把替换后的 s 看成一个十进制数,输出这个数模 998244353 的结果。
输入
8 5
3 6 2
1 4 7
3 8 3
2 2 2
4 5 1
输出
11222211
77772211
77333333
72333333
72311333
输入
200000 1
123 456 7
输出
641437905
#include <bits/stdc++.h>
using namespace std;
/**
* OTD 解决快速推平一段区间的问题
* 1. 将一段连续的相同的元素集合整合为,(L, R, val) 这样一个三元组, set维护
* 2. set大小快速下降,趋于logn
* 3. 复杂度近似 mlogn
*/
struct Node {
int l, r;
mutable int v; // mutable 多变的这样const方法就可以修改
Node(int _l, int _r = -1, int _v = 0) : l(_l), r(_r), v(_v) {}
bool operator < (const Node &that) const {
return l < that.l;
}
};
const int mod = 998244353, N = 2e5 + 10;
int sums[N], res = 0; // 11111111
set<Node> s;
// 将pos所在三元组分成(l, pos - 1, val), (pos, r, val), 并返回pos所在的区间的迭代器
set<Node>::iterator split(int pos) {
auto it = s.lower_bound(Node(pos)); // 找到第一个去区间的l大于等于pos的区间
if (it != s.end() && it->l == pos) return it; // pos本就是一个区间的头部
--it;
if (pos > it->r) return s.end(); // pos所在的区间不存在
int l = it->l, r = it->r, v = it->v;
s.erase(it);
s.insert(Node(l, pos - 1, v));
return s.insert(Node(pos, r, v)).first; // pair<iterator, bool> insert(const value_type& val)
}
void add(int l, int r, int val) { // 区间加
split(l);
auto itr = split(r + 1), itl = split(l);
for (; itl != itr; ++itl) itl->v += val;
}
void assign(int l, int r, int val) { // 推平区间
split(l); // split(r + 1)可能会导致当前迭代器被删除,所以先获取迭代器
auto itr = split(r + 1), itl = split(l);
s.erase(itl, itr); // 删除区间
s.insert(Node(l, r, val));
}
int calc(int L, int R, int d) {
split(L);
auto itr = split(R + 1), itl = split(L);
for (; itl != itr; ++itl) {
int l = itl->l, r = itl->r, v = itl->v;
int x = (sums[r] + mod - sums[l - 1]) % mod; // 1111...000000000...
int y = (d + mod - v) % mod; // diff
res = (res + 1LL * x * y) % mod;
}
assign(L, R, d);
return res;
}
int main()
{
freopen("in.txt", "r", stdin);
cin.tie(nullptr)->sync_with_stdio(false);
int n, m;
cin >> n >> m;
s.clear();
s.insert(Node(1, n, 1));
for (int i = 1; i <= n; ++i) sums[i] = (sums[i - 1] * 10LL + 1) % mod;
res = sums[n];
while (m--) {
int l, r, d;
cin >> l >> r >> d;
cout << calc(n - r + 1, n - l + 1, d) << '\n';
}
return 0;
}