豪华升级版同余类最短路……
主要写几个小trick:
\(1.O(nm)\)实现同余类最短路:
设某一条边长度为\(x\),那么我们选择一个点,在同余类上不断跳\(x\),可以形成一个环。
显然只有在同一个环上的两点之间才可能通过\(x\)进行转移。我们选择环上答案最小的点,它一定不会在当次更新时被更新答案,所以直接从这个点开始依次遍历环上的所有点,每一个点尝试从前面的一个点更新答案。
\(2.\)将\(\mod n\)的同余类最短路变为\(\mod d\)的同余类最短路:
令新的同余类最短路为\(g_x\),原同余类最短路为\(f_x\),那么首先令\(g_{f_i \mod d} \leftarrow f_i\),但是可能会有一些\(g\)没有被正确更新。在\(f_x\)中实际上还有默认的长度为\(n\)的边,那么在\(g\)中用长度为\(n\)的边在\(g\)上更新一次同余类最短路就可以得到正确的答案了。
\(3.\)更新\(border\)长度为等差数列的一段数的操作过程:
设这一个等差数列的首项为\(x\),公差为\(y\),有\(t+1\)项,先将原最短路变为\(\mod x\)的同余类最短路,那么对于每一个环上的点,可以从前面\(t\)个点进行转移,代价为距离\(\times y + x\),本质是一个多重背包,使用单调队列优化转移。
PS:UOJ EX5好毒瘤啊……
#include<bits/stdc++.h>
#define INF 0x7fffffff
#define ll long long
#define int long long
#define PIL pair < int , ll >
#define st first
#define nd second
//This code is written by Itst
using namespace std;
const int MAXN = 5e5 + 10;
char s[MAXN];
int nxt[MAXN] , N , cur;
deque < PIL > q;
ll dis[MAXN] , pot[MAXN] , W;
bool vis[MAXN];
inline void calc(int base , int t , int K){
memset(vis , 0 , sizeof(bool) * cur);
for(int i = 0 ; i < cur ; ++i)
if(!vis[i]){
q.clear();
vis[i] = 1;
int p = (i + t) % cur , minN = i , cnt = 0;
while(!vis[p]){
if(dis[minN] > dis[p])
minN = p;
vis[p] = 1;
p = (t + p) % cur;
}
if(dis[minN] > W)
continue;
p = (minN + t) % cur;
q.push_back(PIL(0 , dis[minN]));
while(p != minN){
if(++cnt - q.front().st > K)
q.pop_front();
dis[p] = min(dis[p] , q.front().nd + cnt * t + base);
while(!q.empty() && q.back().nd >= dis[p] - cnt * t)
q.pop_back();
q.push_back(PIL(cnt , dis[p] - cnt * t));
p = (t + p) % cur;
}
}
}
inline void change(int to){
memset(pot , 0x3f , sizeof(pot));
for(int i = 0 ; i < cur ; ++i)
pot[dis[i] % to] = min(pot[dis[i] % to] , dis[i]);
int t = cur;
cur = to;
memcpy(dis , pot , sizeof(ll) * cur);
calc(0 , t , 1);
}
signed main(){
int T;
nxt[0] = -1;
for(scanf("%lld" , &T) ; T ; --T){
memset(dis , 0x3f , sizeof(dis));
scanf("%lld %lld %s" , &N , &W , s + 1);
if(N > W){
puts("0");
continue;
}
W -= N;
dis[0] = 0;
cur = N;
for(int i = 1 ; i <= N ; ++i){
int t = nxt[i - 1];
while(t != -1 && s[t + 1] != s[i])
t = nxt[t];
nxt[i] = t + 1;
}
int t = nxt[N] , L = N;
while(t > 0){
if(t <= (L >> 1))
calc(0 , N - t , 1);
else{
int l = t % (L - t) , cnt = t - nxt[t];
change(N - t);
calc(cur , L - t , (t - l) / (L - t));
while(t && t - nxt[t] == cnt)
t = nxt[t];
}
L = t;
t = nxt[t];
}
ll ans = 0;
for(int i = 0 ; i < cur ; ++i)
if(W >= dis[i])
ans += (W - dis[i]) / cur + 1;
cout << ans << '\n';
}
return 0;
}