树上莫队算法 (Mo's algorithm on a tree, with colour modifications):
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>
#include <stdbool.h>
typedef long long int64;
typedef int (*cmp_t) (const void *, const void *);
#ifdef WIN32
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
#define oo 0x13131313
#define swap(x, y) ({ int _ = x; x = y; y = _; })
/*
 * Read the next unsigned decimal integer from stdin into *x.
 * Leading non-digit characters are skipped (a sign is NOT handled).
 * The character following the number is consumed.
 * If end-of-file is reached before any digit, *x is set to 0 instead of
 * spinning forever waiting for a digit.
 */
void get(int *x)
{
    /* int, not char: the variable must be able to hold EOF */
    int c = getchar();

    while (c != EOF && (c < '0' || c > '9'))
        c = getchar();
    *x = 0;
    for (; c >= '0' && c <= '9'; c = getchar())
        *x = *x * 10 + (c - '0');
}
/* Adjacency-list edge: target vertex t, next edge n. */
typedef struct edge { int t; struct edge *n; } edge;
/*
 * Mo item. For a modification: x = vertex, y = new colour, t = input index.
 * For a query: [x, y] = bracket-sequence interval, t = input index.
 * n links the queries bucketed in ask[][].
 */
typedef struct elem { int x, y, t; struct elem *n; } elem;
/* Pending LCA request: other endpoint t, request id f (index into lca[]). */
typedef struct info { int t, f; struct info *n; } info;
typedef int arr32[200100];
/* Edge pool (two directed entries per tree edge) and per-vertex list heads. */
edge mem[200005], *adjc = mem, *adj[100005];
/*
 * memo: storage for all items; modify / que: modifications and queries in
 * input order; ask[a][b]: queries whose ends fall in blocks a and b.
 * NOTE(review): ask is 60x60, so the block count bt must stay below 60.
 */
elem memo[100005], *me = memo, *modify[100005], *que[100005], *ask[60][60];
/* Pool and per-vertex lists of LCA requests consumed by tarjan(). */
info Mem[200005], *memt = Mem, *list[100005];
/*
 * v: value per colour; w: weight per occurrence count; c / cc: current /
 * initial colour per vertex; ufs: union-find parent; br: bracket sequence;
 * ll / rr: entry / exit position of a vertex in br; lca: LCA per request id;
 * bel: block of each br position; sum: occurrences of each colour in the
 * window; type / xx / yy: per-input-line query type and ordered endpoints.
 */
arr32 v, w, c, cc, ufs, br, ll, rr, lca, bel, sum, type, xx, yy;
/* appr[p]: vertex p is counted in the window; ans: running window value;
 * anss: answer per input line. */
bool appr[100005]; int64 ans, anss[100005];
/* tot: bracket-sequence length; B: block size; bt: number of blocks;
 * mt / qt: number of modifications / queries.
 * NOTE(review): hehe and haha appear unused in this program. */
int n, m, Q, B, tot, mt, qt, bt, hehe, haha;
/* Union-find lookup with path compression: returns the representative of x. */
int find(int x)
{
    if (ufs[x] != x)
        ufs[x] = find(ufs[x]);
    return ufs[x];
}
/*
 * Build the bracket (Euler) sequence for the subtree rooted at u.
 * u is written into br[] at position ll[u] on entry and at rr[u] on exit,
 * using the global counter tot. fa is u's parent (0 for the root).
 */
void dfs(int u, int fa)
{
    edge *it;

    ll[u] = ++tot;
    br[ll[u]] = u;
    for (it = adj[u]; it != NULL; it = it->n) {
        if (it->t == fa)
            continue;
        dfs(it->t, u);
    }
    rr[u] = ++tot;
    br[rr[u]] = u;
}
/*
 * Offline Tarjan LCA over the subtree rooted at u (parent fa).
 * Each finished child subtree is merged into u in the union-find forest.
 * Afterwards, every pending request (u, t) whose other endpoint t has
 * already been visited (ufs[t] != 0) is answered as find(t) and stored
 * in lca[] under the request id f.
 */
void tarjan(int u, int fa)
{
    edge *e;
    info *q;

    ufs[u] = u;
    for (e = adj[u]; e != NULL; e = e->n) {
        if (e->t == fa)
            continue;
        tarjan(e->t, u);
        ufs[e->t] = u;
    }
    for (q = list[u]; q != NULL; q = q->n) {
        if (ufs[q->t] != 0)
            lca[q->f] = find(q->t);
    }
}
void trans(int p)
{
if (appr[p])
appr[p] = 0, ans -= (int64) v[c[p]] * w[sum[c[p]]--];
else
appr[p] = 1, ans += (int64) v[c[p]] * w[++sum[c[p]]];
}
int main()
{
freopen("park.in", "r", stdin);
freopen("park.out", "w", stdout);
int i, j, k;
get(&n), get(&m), get(&Q);
for (i = 1; i <= m; ++i)
get(v + i);
for (i = 1; i <= n; ++i)
get(w + i);
for (i = 1; i < n; ++i) {
int a, b;
get(&a), get(&b);
*adjc = (edge) {b, adj[a]}, adj[a] = adjc++;
*adjc = (edge) {a, adj[b]}, adj[b] = adjc++;
}
for (i = 1; i <= n; ++i)
get(c + i), cc[i] = c[i];
/* get brackets sequence */
dfs(1, 0);
/* n ^ (2 / 3) */
for (B = 1; B * B * B < tot; ++B); B *= B;
for (i = j = bt = 1; i <= tot; ++i, ++j) {
if (j > B) j = 1, ++bt;
bel[i] = bt;
}
for (i = 0; i < Q; ++i) {
int x, y;
get(type + i), get(&x), get(&y);
if (!type[i])
*me = (elem) {x, y, i}, modify[mt++] = me++;
else {
/* for LCA queries */
*memt = (info) {y, me - memo, list[x]}, list[x] = memt++;
*memt = (info) {x, me - memo, list[y]}, list[y] = memt++;
/* asked intervals */
if (ll[x] > ll[y]) swap(x, y);
xx[i] = x, yy[i] = y;
x = rr[x] < rr[y] ? rr[x] : ll[x]; y = ll[y];
*me = (elem) {x, y, i}; que[qt++] = me++;
}
}
/* divide queries into groups */
for (i = qt; i--; ) {
elem *j = que[i];
j->n = ask[bel[j->x]][bel[j->y]], ask[bel[j->x]][bel[j->y]] = j;
}
/* LCA */
tarjan(1, 0);
/* answer */
for (i = 1; i <= bt; ++i)
for (j = i; j <= bt; ++j)
if (ask[i][j]) {
int l = (i - 1) * B + 1, r = l, f;
elem *e, *d;
memset(sum, 0, (m + 1) << 2);
memset(appr, 0, (n + 1));
memcpy(c, cc, (n + 1) << 2);
ans = k = 0; trans(br[l]);
for (e = ask[i][j]; e; e = e->n) {
if (l < e->x) for (; l != e->x; ++l) trans(br[l]);
else if (l > e->x) do --l, trans(br[l]); while (l != e->x);
if (r > e->y) for (; r != e->y; --r) trans(br[r]);
else if (r < e->y) do ++r, trans(br[r]); while (r != e->y);
for (; k < mt && modify[k]->t < e->t; ++k) {
d = modify[k];
bool flag = (l <= ll[d->x] && ll[d->x] <= r) ^ (l <= rr[d->x] && rr[d->x] <= r);
if (flag) trans(d->x);
c[d->x] = d->y;
if (flag) trans(d->x);
}
anss[e->t] = ans;
if (f = lca[e - memo], f != xx[e->t] && f != yy[e->t])
anss[e->t] += (int64) v[c[f]] * w[sum[c[f]] + 1];
}
}
for (i = 0; i < Q; ++i)
if (type[i])
printf(fmt64 "\n", anss[i]);
return 0;
}
70分树上莫队算法 (70-point version: Mo's algorithm on a tree; the large-input path ignores modifications):
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <ctime>
#include <algorithm>
/* Shorthand for unsigned. NOTE(review): appears unused in this program. */
#define uns unsigned
/* 64-bit integer type used for answers. */
#define int64 long long
/* printf conversion for int64; Windows builds use the %I64d spelling. */
#ifdef WIN32
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
/* NOTE(review): appears unused in this program. */
#define oo 0x13131313
/* Loop i over 0 .. n-1. */
#define REP(i, n) for (i = 0; i < (n); ++i)
/* Capacity of the per-vertex / per-position / per-query arrays below. */
#define maxn 200005
using namespace std;
/* n vertices, m colours, Q queries; Br: bracket-sequence length. */
int n, m, Q, Br;
/* Edge pool (two entries per tree edge) and per-vertex adjacency heads lst. */
struct edge { int t; edge *n; } edges[maxn * 2], *adj = edges, *lst[maxn];
/* NOTE(review): under `using namespace std` this name may become ambiguous
 * with std::array if <array> is pulled in transitively — confirm. */
typedef int array[maxn];
/*
 * v: value per colour; w: weight per occurrence count; c: colour per vertex;
 * x / y: query endpoints; p: query order for Mo; xx / yy: Mo interval per
 * query; br: bracket sequence; ll / rr: entry / exit positions; hehe: block
 * id per position; ufs: union-find parent; lca: LCA per query; sum: colour
 * counts in the window; pa: parent (brute-force path).
 */
array v, w, c, x, y, p, xx, yy, br, ll, rr, hehe, ufs, lca, sum, pa;
/* LCA request pool and per-vertex request lists (stk) used by tarjan(). */
struct elem { int t, f; elem *n; } elems[maxn * 2], *eptr = elems, *stk[maxn];
/* result: running window value; ans: answer per query; appr: vertex in window. */
int64 result, ans[maxn]; bool appr[maxn];
/* Insert undirected edge a-b as two directed entries at the list heads. */
void link(int a, int b)
{
	adj->t = b;
	adj->n = lst[a];
	lst[a] = adj;
	++adj;

	adj->t = a;
	adj->n = lst[b];
	lst[b] = adj;
	++adj;
}
/* Bracket sequence: u is written at ll[u] on entry and rr[u] on exit (counter Br). */
void dfs(int u, int fa)
{
	ll[u] = ++Br;
	br[ll[u]] = u;
	edge *e = lst[u];
	while (e) {
		if (e->t != fa)
			dfs(e->t, u);
		e = e->n;
	}
	rr[u] = ++Br;
	br[rr[u]] = u;
}
void add(int a, int b, int c)
{
*eptr = (elem){b, c, stk[a]}, stk[a] = eptr++;
*eptr = (elem){a, c, stk[b]}, stk[b] = eptr++;
}
/* Representative of x in the union-find forest, compressing the path. */
int find(int x)
{
	int &parent = ufs[x];
	if (parent != x)
		parent = find(parent);
	return parent;
}
void tarjan(int u, int fa)
{
ufs[u] = u;
for (elem *e = stk[u]; e; e = e->n)
if (ufs[e->t])
lca[e->f] = find(e->t);
for (edge *e = lst[u]; e; e = e->n)
if (e->t != fa)
tarjan(e->t, u), ufs[e->t] = u;
}
bool nicer(int a, int b)
{
return hehe[xx[a]] < hehe[xx[b]] || (hehe[xx[a]] == hehe[xx[b]] && yy[a] < yy[b]);
}
void trans(int x)
{
if (appr[x]) {
appr[x] = 0;
result -= (int64)v[c[x]] * w[sum[c[x]]--];
} else {
appr[x] = 1;
result += (int64)v[c[x]] * w[++sum[c[x]]];
}
}
/*
 * Brute-force answer for the path x..y (used only for small inputs).
 * Collects the colour of every vertex on the path, sorts them, and sums
 * v[col] * w[k] for the k-th occurrence of each colour.
 * Requires pa[] to have been filled by dfs1(1, 0).
 */
int64 ask(int x, int y)
{
	static int Mark, tot;
	static array mark, a;
	int i, j, f;
	int64 res = 0;

	++Mark, tot = 0;
	// mark x and all of its ancestors
	for (i = x; i; i = pa[i])
		mark[i] = Mark;
	// climb from y to the first marked vertex, which is LCA(x, y)
	for (f = y; mark[f] < Mark; f = pa[f])
		a[++tot] = c[f];
	a[++tot] = c[f];  // the LCA itself
	for (i = x; i != f; i = pa[i])
		a[++tot] = c[i];
	sort(a + 1, a + tot + 1);
	// walk runs of equal colours; test the bound BEFORE reading a[i]
	for (i = 1; i <= tot; ) {
		f = a[i];
		for (j = 0; i <= tot && a[i] == f; ++i)
			res += (int64)v[f] * w[++j];
	}
	return res;
}
/* Record pa[] (parent) for every vertex in the subtree rooted at u. */
void dfs1(int u, int fa)
{
	pa[u] = fa;
	edge *e = lst[u];
	while (e != NULL) {
		if (e->t != fa)
			dfs1(e->t, u);
		e = e->n;
	}
}
/*
 * Entry point: reads park.in and writes the answers to park.out.
 * Small inputs (n, m <= 20000) are answered online by brute force (ask),
 * which also applies type-0 colour changes. Larger inputs use Mo's
 * algorithm on the bracket sequence.
 * NOTE(review): the Mo path reads but ignores the query type, so it is only
 * correct when there are no modifications.
 */
int main()
{
	if (!freopen("park.in", "r", stdin) || !freopen("park.out", "w", stdout)) {
		perror("freopen");
		return 1;
	}
	int i, j, k;
	scanf("%d%d%d", &n, &m, &Q);
	for (i = 1; i <= m; ++i)
		scanf("%d", v + i);
	for (i = 1; i <= n; ++i)
		scanf("%d", w + i);
	for (i = 1; i < n; ++i) {
		int a, b;
		scanf("%d%d", &a, &b);
		link(a, b);
	}
	for (i = 1; i <= n; ++i)
		scanf("%d", c + i);
	if (n <= 20000 && m <= 20000) {
		for (dfs1(1, 0); Q--; ) {
			int type, a, b;
			scanf("%d%d%d", &type, &a, &b);
			if (type)
				printf(fmt64"\n", ask(a, b));
			else
				c[a] = b;
		}
		exit(0);
	}
	dfs(1, 0);
	// split the bracket sequence into blocks of 400 positions
	for (i = 1, j = k = 0; i <= Br; ++i) {
		if (i > j) j += 400, ++k;
		hehe[i] = k;
	}
	REP(i, Q) {
		int a, b;
		scanf("%d%d%d", &j, &a, &b);  // j receives the query type (ignored here)
		if (ll[a] > ll[b]) swap(a, b);
		if (rr[a] < ll[b])
			xx[i] = rr[a], yy[i] = ll[b];  // disjoint: LCA excluded, added below
		else
			xx[i] = ll[a], yy[i] = ll[b];  // a is an ancestor of b
		x[i] = a, y[i] = b, p[i] = i;
		add(x[i], y[i], i);
	}
	tarjan(1, 0);
	sort(p, p + Q, nicer);
	// Window is [curL, curR]; start it empty so no phantom position 0 /
	// vertex 0 is ever toggled in.
	int curL = 1, curR = 0;
	REP(i, Q) {
		int u = p[i];
		int l = xx[u], r = yy[u];
		if (curL < l)
			for (j = curL; j < l; ++j) trans(br[j]);
		else if (curL > l)
			for (j = l; j < curL; ++j) trans(br[j]);
		if (curR > r)
			for (j = curR; j > r; --j) trans(br[j]);
		else if (curR < r)
			for (j = r; j > curR; --j) trans(br[j]);
		ans[u] = result;
		if (lca[u] != x[u] && lca[u] != y[u])
			ans[u] += (int64)v[c[lca[u]]] * w[sum[c[lca[u]]] + 1];
		curL = l, curR = r;
	}
	REP(i, Q)
		printf(fmt64"\n", ans[i]);
	return 0;
}
70分树分块 (70-point version: tree block decomposition):
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <ctime>
#include <algorithm>
/* Loop i over 0 .. n-1. */
#define REP(i, n) for (i = 0; i < (n); ++i)
/* Iterate edge pointer i over the adjacency list of vertex j. */
#define FER(i, j) for (i = lst[j]; i; i = i->n)
/* 64-bit integer type used for answers and colour values. */
#define int64 long long
/* printf conversion for int64; Windows builds use the %I64d spelling. */
#ifdef WIN32
#define fmt64 "%I64d"
#else
#define fmt64 "%lld"
#endif
/* NOTE(review): appears unused in this program. */
#define oo 0x13131313
/* Capacity of the per-vertex / per-colour arrays below. */
#define maxn 90002
/* Group size for the tree partition and chunk size of the colour-count arrays. */
#define BLOCK 300
/* now: clock() value at start-up, used for the timing lines on stderr. */
using namespace std; double now;
/*
 * Read the next unsigned decimal integer from stdin into x.
 * Leading non-digit characters are skipped (a sign is not handled), and the
 * character following the number is consumed.
 * On end-of-file before any digit, x is set to 0 instead of looping forever.
 */
template<class T> void read(T &x)
{
	int c = getchar();  // int, not char, so that EOF can be recognised
	while (c != EOF && (c < '0' || c > '9'))
		c = getchar();
	x = 0;
	for (; '0' <= c && c <= '9'; c = getchar())
		x = x * 10 + (c - '0');
}
/*
 * n vertices, m colours, Q queries; w[k]: weight of the k-th occurrence of
 * a colour; c[u]: colour of vertex u; Mod[k]: k % BLOCK lookup table
 * (filled in input(), indexed by colour - 1); v[k]: value of colour k
 * (64-bit, so the v * w products below are computed in 64 bits).
 */
int n, m, Q, w[maxn], c[maxn]; short Mod[maxn]; int64 v[maxn];
/*
 * pa: parent; size: group size while partitioning (dfs); ufs: union-find
 * forest of the partition; dep: depth; dfn / Dfn: DFS order; ca: vertex
 * recorded for each group in dfs().
 * NOTE(review): a global named `size` may be ambiguous with std::size
 * (C++17) under `using namespace std` — confirm with the target compiler.
 */
int pa[maxn], size[maxn], ufs[maxn], dep[maxn], dfn[maxn], Dfn, ca[maxn];
/* F, f: table origin and current ancestor used by init() / bfs();
 * mark, tot, a, Mark: scratch for the brute-force ask(). */
int F, f, mark[maxn], tot, a[maxn], Mark;
/* buf: backing store for the per-origin tables ans[F][...]; Ans: table being filled. */
int64 buf[maxn * BLOCK], *buft = buf, *ans[maxn], *Ans;
/* Edge pool, adjacency heads lst, and fr[v]: the entry in pa[v]'s list that leads to v. */
struct edge { int t; edge *n; } edges[maxn * 2], *adj = edges, *lst[maxn], *fr[maxn];
/* Chunk of BLOCK colour counters; blocks is the chunk pool, btot its next free slot. */
struct block { int a[BLOCK]; } blocks[maxn + BLOCK], *btot = blocks;
/* Persistent colour-count array: BLOCK chunk pointers, read through operator[]. */
struct array { block *a[BLOCK]; int operator[](int); } sum[maxn];
// Counter stored for colour b (1-based): colour b lives at index b - 1,
// in chunk (b - 1) / BLOCK at offset Mod[b - 1] (expected to be (b - 1) % BLOCK).
int array::operator[](int b)
{
	const int idx = b - 1;
	const block *chunk = a[idx / BLOCK];
	return chunk->a[Mod[idx]];
}
/*
 * Make a equal to b with the counter for colour pos (1-based) incremented.
 * Only the chunk containing that counter is duplicated (taken from the
 * btot pool); every other chunk pointer is shared with b.
 */
void inherit(array &a, array &b, int pos)
{
/* Reference to a's chunk slot for colour pos (pos becomes 0-based here).
 * It is taken BEFORE the memcpy below, which then fills it with b's pointer. */
block *&p = a.a[--pos / BLOCK];
/* Copy all of b's chunk pointers into a, then copy b's chunk into fresh storage. */
memcpy(&a, &b, sizeof(array)), memcpy(btot, p, sizeof(block));
/* Point a's slot at the private copy and bump the counter. */
p = btot++, ++p->a[Mod[pos]];
}
// Non-recursive union-find lookup: locate the root, then point every vertex
// on the walked path directly at it.
int find(int x)
{
	int root = x;
	while (ufs[root] != root)
		root = ufs[root];
	while (ufs[x] != x) {
		int next = ufs[x];
		ufs[x] = root;
		x = next;
	}
	return root;
}
/*
 * Single DFS from u (parent fa). It does three things:
 *  - builds u's persistent colour counts: sum[u] = sum[fa] plus one
 *    occurrence of c[u], so sum[v][k] counts colour k on the root..v path;
 *  - records dep[u] and the DFS order dfn[u], plus fr[] / pa[] of children;
 *  - greedily groups vertices through the ufs[] forest.
 * Grouping: each finished child's leftover group (size[child]) is merged
 * into the current group f while the combined size stays <= 2 * BLOCK
 * (recording ca[f] = u); otherwise the child becomes the new current group.
 * Finally u starts its own group of size 1 and absorbs the last current
 * group if that group is still smaller than BLOCK.
 * Recursion depth equals the tree height.
 */
void dfs(int u, int fa)
{
edge *e; int f = -1;
inherit(sum[u], sum[fa], c[u]), dep[u] = dep[fa] + 1, dfn[u] = ++Dfn;
FER(e, u) if (e->t != fa)
{
dfs(e->t, u), fr[e->t] = e, pa[e->t] = u;
/* f == -1: no current group yet; also start a new group if merging would exceed 2 * BLOCK */
if (!~f || size[f] + size[e->t] > BLOCK << 1)
f = e->t;
else
size[f] += size[e->t], ca[f] = u, ufs[e->t] = f;
}
ca[u] = ufs[u] = u, size[u] = 1;
/* absorb the last, still-small group into u's own group */
if (~f && size[f] < BLOCK)
size[u] += size[f], ufs[f] = u;
}
/*
 * Extend the table being filled (Ans, see init) over the subtree entered at
 * S, never stepping back to a vertex's parent. For every vertex h:
 *   Ans[h] = Ans[pa[h]] + v[c[h]] * w[cnt]
 * where cnt = sum[F][c[h]] - 2 * sum[f][c[h]] + sum[h][c[h]] + (c[f] == c[h])
 * is the number of vertices of colour c[h] on the path F..h, assuming the
 * global f is LCA(F, h) when called (F and f are set by init — confirm for
 * any other caller).
 * q[] is an intrusive FIFO: q[h] is the vertex queued after h, 0 ends the
 * list, and t is the current tail.
 */
void bfs(int S)
{
static int q[maxn]; int h, t; edge *e;
for (q[h = t = S] = 0; h; h = q[h])
{
Ans[h] = Ans[pa[h]] + v[c[h]] * w[sum[F][c[h]] - (sum[f][c[h]] << 1) + sum[h][c[h]] + (c[f] == c[h])];
/* append each child at the tail: q[old tail] = child, then q[child] = 0 */
FER(e, h) if (e->t != pa[h]) q[t = q[t] = e->t] = 0;
}
}
/*
 * Precomputation.
 * Gives sum[0] BLOCK zeroed chunks and runs dfs(1, 0). Then, for every
 * union-find root r, it builds once, for F = ca[r], the table ans[F], where
 * ans[F][y] is the answer for the path F..y:
 *   - Ans[F] is F alone, and every child subtree of F is filled by bfs();
 *   - then, walking up j -> f = pa[j], Ans[f] appends ancestor f, and bfs()
 *     fills the subtrees of f's children that appear AFTER j in f's
 *     adjacency list.
 * NOTE(review): children listed before j are skipped, so ans[F][y] is only
 * filled for part of the tree. main() orders x, y by dfn, presumably to
 * stay inside the filled part — confirm.
 * Prints the elapsed time to stderr after dfs.
 */
void init()/*pretreat for the answers between blocks*/
{
int i, j; edge *e;
REP(i, BLOCK) sum->a[i] = btot++;
dfs(1, 0);
fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
REP(i, n) if (find(i + 1) == i + 1)
{
F = f = ca[i + 1];
if (ans[f]) continue;
/* reserve n slots and offset Ans by one so it is indexed 1..n.
 * NOTE(review): for the first table this forms buf - 1, a pointer before
 * the array start (technically undefined behaviour). */
buft += n, Ans = ans[f] = buft - n - 1;
Ans[f] = v[c[f]] * w[1];
FER(e, f) if (e->t != pa[f]) bfs(e->t);
for (j = f; fr[j]; j = f)
{
f = pa[j];
/* ancestor f of F: colour c[f] occurs sum[F] - sum[f] + 1 times on F..f */
Ans[f] = Ans[j] + v[c[f]] * w[sum[F][c[f]] - sum[f][c[f]] + 1];
for (e = fr[j]->n; e; e = e->n)
if (e->t != pa[f]) bfs(e->t);
}
}
}
/*
 * Brute-force answer for the path x..y (used only for small inputs).
 * Collects the colour of every vertex on the path into a[], sorts them,
 * and sums v[col] * w[k] for the k-th occurrence of each colour.
 * Uses the global scratch Mark / mark / tot / a; requires pa[] from Dfs(1, 0).
 */
int64 ask(int x, int y)
{
	int i, j, f;
	int64 res = 0;

	++Mark, tot = 0;
	// mark x and all of its ancestors
	for (i = x; i; i = pa[i]) mark[i] = Mark;
	// climb from y to the first marked vertex, which is LCA(x, y)
	for (f = y; mark[f] < Mark; f = pa[f]) a[++tot] = c[f];
	a[++tot] = c[f];  // the LCA itself
	for (i = x; i != f; i = pa[i]) a[++tot] = c[i];
	sort(a + 1, a + tot + 1);
	// walk runs of equal colours; test the bound BEFORE reading a[i]
	for (i = 1; i <= tot; ) {
		f = a[i];
		for (j = 0; i <= tot && a[i] == f; ++i)
			res += v[f] * w[++j];
	}
	return res;
}
// Fill pa[] for the subtree rooted at u (parent fa); used by the brute-force path.
void Dfs(int u, int fa)
{
	pa[u] = fa;
	for (edge *e = lst[u]; e != NULL; e = e->n) {
		if (e->t == fa)
			continue;
		Dfs(e->t, u);
	}
}
void input()
{
int i; scanf("%d%d%d", &n, &m, &Q);
REP(i, m) read(v[i + 1]);
REP(i, n) read(w[i + 1]), Mod[i] = i % BLOCK;
REP(i, n - 1)
{
int a, b; read(a), read(b);
*adj = (edge){b, lst[a]}, lst[a] = adj++;
*adj = (edge){a, lst[b]}, lst[b] = adj++;
}
REP(i, n) read(c[i + 1]);
if (n <= 20000 && m <= 20000)
{
for (Dfs(1, 0); Q--; )
{
int t, x, y; read(t), read(x), read(y);
t ? printf(fmt64"\n", ask(x, y)) : c[x] = y;
}
exit(0);
}
}
// LCA of x and y using the partition built by dfs()/init(): after init()
// every ufs[v] points straight at v's group representative. While x and y
// lie in different groups, lift the one whose representative is deeper to
// that representative's parent; inside one group, step the deeper vertex
// up one edge at a time.
int LCA(int x, int y)
{
	while (x != y) {
		if (ufs[x] != ufs[y]) {
			if (dep[ufs[x]] > dep[ufs[y]])
				x = pa[ufs[x]];
			else
				y = pa[ufs[y]];
		} else if (dep[x] > dep[y]) {
			x = pa[x];
		} else {
			y = pa[y];
		}
	}
	return x;
}
/*
 * Entry point: redirects stdin/stdout to park.in / park.out, runs input()
 * and init(), then answers each query from the precomputed tables.
 * Elapsed times are printed to stderr.
 * For a query (x, y), reordered so that dfn[x] <= dfn[y]:
 *   f = ca[] of x's group representative, g = LCA(x, y), start from ans[f][y];
 *   if f lies above g, subtract the vertices from pa[g] up to and including f;
 *   then add the vertices from x up to (excluding) f, or g if f was above g.
 * NOTE(review): a type-0 (modification) query makes the program exit
 * without answering the remaining queries — modifications are not
 * supported on this path.
 */
int main()
{
freopen("park.in", "r", stdin);
freopen("park.out", "w", stdout);
now = clock();
input();
init();
fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
for (; Q--; )
{
int t, x, y; read(t), read(x), read(y);
/* stop at the first modification; otherwise make y the later vertex in DFS order */
if (!t) exit(0); if (dfn[y] < dfn[x]) swap(x, y);
int f = ca[ufs[x]], g = LCA(x, y);
int64 Ans = ans[f][y];
if (dep[f] < dep[g])
{
/* remove t = pa[g] .. f; colour c[t] occurs sum[y] - sum[t] + 1 times on t..y */
for (t = pa[g]; t != pa[f]; t = pa[t])
Ans -= v[c[t]] * w[sum[y][c[t]] - sum[t][c[t]] + 1];
f = g;
}
/* add x .. (exclusive) f; cnt = occurrences of c[t] on the path t..y through g */
for (t = x; t != f; t = pa[t])
Ans += v[c[t]] * w[sum[y][c[t]] - (sum[g][c[t]] << 1) + sum[t][c[t]] + (c[g] == c[t])];
printf(fmt64"\n", Ans);
}
fprintf(stderr, "%.2lf\n", (clock() - now) / CLOCKS_PER_SEC);
}