点分治——学习笔记

点分治

  • 点分治基础框架(常用于树上路径问题)
    1. 找1个合适的分治中心rt
    2. 统计答案ans += solve(T,rt)
    3. 对所有rt的子节点v,递归调用work(v)
int work(u){//统计sub(u)中的合法路径
    rt = find_rt();//找重心
    ans = solve(rt);
    for v∈son[rt]:
        ans += work(v)
    return ans;
}
  • 模板例题:

一.

  • 题目链接: https://www.luogu.com.cn/problem/P4178
  • 解题思路: 本题是很常见的套路题:以每个点为 LCA,求出经过该点且距离小于等于 $K$ 的路径数量,再加上以每个点为根、距离小于等于 $K$ 的路径数量,就是最终结果。然后利用点分治框架(配合树状数组)将时间复杂度降低为 $O(N \log^2 N)$。

//#define LOCAL
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define mem(a, b) memset(a,b,sizeof(a))
#define sz(a) (int)a.size()
#define INF 0x3f3f3f3f
#define DNF 0x7f
#define DBG printf("this is a input\n")
#define fi first
#define se second
#define mk(a, b) make_pair(a,b)
#define pb push_back
#define LF putchar('\n')
#define SP putchar(' ')
#define p_queue priority_queue
#define CLOSE ios::sync_with_stdio(0); cin.tie(0)
#define sz(a) (int)a.size()
#define pii pair <int,int>
/// Reads one (optionally negative) decimal integer from stdin into `x`.
/// Characters before the first digit are skipped; every '-' among them flips the sign.
/// If the input ends before a digit is found, `x` is left at 0 instead of looping forever.
template<typename T>
void read(T &x) {
    x = 0;
    bool negative = false;
    // Keep the character as int: EOF stays distinguishable and isdigit() never
    // receives a negative value (which would be undefined behaviour).
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') negative = !negative;
        ch = getchar();
    }
    while (isdigit(ch)) {   // isdigit(EOF) is false, so this also stops at end of input
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
/// Reads several integers in argument order: read(a, b, c) fills a, then b, then c.
template<typename T, typename... Args>
void read(T &first, Args& ... args) {
    read(first);
    read(args...);
}
/// Prints the integer `arg` in decimal to stdout, without any trailing separator.
/// Negative values are emitted without negating the whole number, so the most
/// negative value of T (e.g. INT_MIN) is printed correctly instead of overflowing.
template<typename T>
void write(T arg) {
    if (arg < 0) {
        putchar('-');
        const T higher = arg / 10;                        // truncates toward zero, so -higher cannot overflow
        const int lowest = -static_cast<int>(arg % 10);   // remainder lies in [-9, 0]
        if (higher != 0) write(-higher);
        putchar('0' + lowest);
        return;
    }
    if (arg > 9) write(arg / 10);
    putchar(static_cast<int>(arg % 10) + '0');
}
/// Prints all arguments in order, separated by single spaces.
template<typename T, typename ... Ts>
void write(T arg, Ts ... args) {
    write(arg);
    if (sizeof...(args) != 0) {
        putchar(' ');
        write(args ...);
    }
}
using namespace std;

/// Greatest common divisor of a and b (Euclidean algorithm, iterative).
/// The result is non-negative for any mix of input signs, and gcd(0, 0) == 0.
/// Note: inputs equal to LLONG_MIN are not supported (their magnitude is not representable).
long long gcd(long long a, long long b) {
    while (b != 0) {
        const long long r = a % b;
        a = b;
        b = r;
    }
    return a < 0 ? -a : a;   // the remainder chain may end on a negative value
}

ll lcm(ll a, ll b) {
    return a / gcd(a, b) * b;
}

// Capacity of the per-vertex arrays below.
const int N = 5e4 + 5;
// Running total of counted paths; printed by main.
ll ans = 0;
// n: number of vertices read by main. head[v]: index in edge[] of the first edge of v (-1 when empty).
// cnt: next free slot in edge[]. k: distance limit read by main.
int n, head[N], cnt, k;
/// Fenwick tree (binary indexed tree) of int counters over positions 1..ed.
struct BIT
{
    int tree[1000005] , ed = 4e5 + 5;   // ed: largest position add() will touch

    /// Resets every counter to zero.
    void init()
    {
        std::memset(tree, 0, sizeof(tree));
    }
    /// Value of the lowest set bit of k (0 for k == 0).
    int lowbit(int k)
    {
        return k & -k;
    }
    /// Adds k at position x. Positions outside 1..ed are ignored.
    void add(int x , int k)
    {
        // lowbit(0) == 0 would never advance the loop, and negative x would index out of bounds.
        if (x <= 0) return;
        for (; x <= ed; x += lowbit(x))
            tree[x] += k;
    }
    /// Sum of positions 1..x. x <= 0 yields 0; x beyond ed is treated as ed.
    int sum(int x)
    {
        // Entries above ed are never updated, so walking from them would skip real data
        // (or run past the array); clamping makes "everything" the answer instead.
        if (x > ed) x = ed;
        int total = 0;
        for (; x > 0; x -= lowbit(x))
            total += tree[x];
        return total;
    }
    /// Sum of positions l..r (inclusive).
    int query(int l , int r)
    {
        return sum(r) - sum(l - 1) ;
    }
} bit ;
// One directed edge of the linked adjacency list.
struct node {
    int t;      // vertex this edge leads to
    int next;   // index in edge[] of the next edge of the same source vertex (-1 at the end)
    int w;      // edge weight
};
// Edge pool; main stores every input edge twice (once per direction).
node edge[N << 1];
void add (int f, int t, int w)
{
    edge[cnt].w = w;
    edge[cnt].t = t;
    edge[cnt].next = head[f];
    head[f] = cnt ++;
}
// rt: vertex selected by the most recent dfs_rt call. siz[v]: subtree size written by dfs_rt
// (solve overwrites it with the vertex count of the component). vis[v]: set to 1 by work once v was a centre.
int rt, siz[N], vis[N];
// dis[v]: distance from the centre currently being solved. tot: vertices visited by the current dfs1 walk.
int dis[N], tot;
// Fills siz[] for the unvisited part of the tree reachable from u (entered from fa) and
// stores in rt a vertex whose removal leaves no part larger than s / 2, where s is the
// number of vertices of the component being split.
void dfs_rt (int u , int fa, int s)
{
    siz[u] = 1;
    int heaviest = 0;   // largest part that remains once u is removed
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (to == fa || vis[to]) continue;
        dfs_rt (to, u, s);
        siz[u] += siz[to];
        if (siz[to] > heaviest) heaviest = siz[to];
    }
    const int above = s - siz[u];   // the part on the parent's side of u
    if (above > heaviest) heaviest = above;
    if (heaviest * 2 <= s) rt = u;
}
// Walks the subtree of u (entered from fa), filling dis[] for its children and counting
// visited vertices in tot. Each vertex with dis[u] <= k adds to ans: one path ending at the
// centre plus one per vertex already stored in bit with distance at most k - dis[u].
void dfs1 (int u , int fa)
{
    ++tot;
    const int budget = k - dis[u];   // distance still available on the other side of the centre
    if (budget >= 0)
        ans += 1 + bit.sum (budget);
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (to == fa || vis[to]) continue;
        dis[to] = dis[u] + edge[e].w;
        dfs1 (to, u);
    }
}
// Applies bit.add(dis[x], p) to every vertex x in the subtree of u (entered from fa);
// vertices with dis[x] == 0 are skipped. solve passes p = 1 to register a subtree and
// p = -1 to remove the registrations again.
void dfs2 (int u, int fa, int p)
{
    if (dis[u])
        bit.add (dis[u], p);
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (to == fa || vis[to]) continue;
        dfs2 (to, u, p);
    }
}
// Counts into ans every qualifying path that passes through the centre u.
// Each unvisited neighbouring subtree is first counted against the subtrees handled
// before it (dfs1) and only then registered in bit (dfs2), so no pair is taken from
// the same subtree. siz[child] receives the subtree's vertex count for later use in work.
void solve (int u)
{
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int child = edge[e].t;
        if (vis[child]) continue;
        tot = 0;
        dis[child] = edge[e].w;
        dfs1 (child, u);
        siz[child] = tot;
        dfs2 (child, u , 1);
    }
    // Walk the whole component once more with p = -1 to undo all registrations.
    dfs2 (u, 0 , -1);
}
// Handles the component of s vertices that contains u: picks its centre with dfs_rt,
// counts the paths through it with solve, then recurses into every remaining part.
void work (int u , int fa, int s)
{
    dfs_rt (u , fa, s);
    const int centre = rt;   // keep a copy: rt is overwritten by the nested calls below
    vis[centre] = 1;
    dis[centre] = 0;
    solve (centre);
    for (int e = head[centre] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (vis[to]) continue;
        work (to, centre, siz[to]);   // siz[to] was set to the part's vertex count by solve
    }
}
// Reads n, the n - 1 weighted edges and the limit k, runs the decomposition from
// vertex 1 and prints ans followed by a newline.
int main()
{
    mem (head, -1);   // -1 marks an empty adjacency list
    bit.init();
    read (n);
    for (int remaining = n - 1 ; remaining > 0 ; -- remaining)
    {
        int a , b, len;
        read (a, b, len);
        add (a, b, len);
        add (b, a, len);
    }
    read (k);
    work (1, 1, n);
    write (ans);
    LF;
    return 0;
}

二.

  • 题目链接: https://www.luogu.com.cn/problem/P4149
  • 解题思路: 本题和上一题类似,记录一下深度即可

//#define LOCAL
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define mem(a, b) memset(a,b,sizeof(a))
#define sz(a) (int)a.size()
#define INF 0x3f3f3f3f
#define DNF 0x7f
#define DBG printf("this is a input\n")
#define fi first
#define se second
#define mk(a, b) make_pair(a,b)
#define pb push_back
#define LF putchar('\n')
#define SP putchar(' ')
#define p_queue priority_queue
#define CLOSE ios::sync_with_stdio(0); cin.tie(0)
#define sz(a) (int)a.size()
#define pii pair <int,int>
/// Reads one (optionally negative) decimal integer from stdin into `x`.
/// Characters before the first digit are skipped; every '-' among them flips the sign.
/// If the input ends before a digit is found, `x` is left at 0 instead of looping forever.
template<typename T>
void read(T &x) {
    x = 0;
    bool negative = false;
    // Keep the character as int: EOF stays distinguishable and isdigit() never
    // receives a negative value (which would be undefined behaviour).
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') negative = !negative;
        ch = getchar();
    }
    while (isdigit(ch)) {   // isdigit(EOF) is false, so this also stops at end of input
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
/// Reads several integers in argument order: read(a, b, c) fills a, then b, then c.
template<typename T, typename... Args>
void read(T &first, Args& ... args) {
    read(first);
    read(args...);
}
/// Prints the integer `arg` in decimal to stdout, without any trailing separator.
/// Negative values are emitted without negating the whole number, so the most
/// negative value of T (e.g. INT_MIN) is printed correctly instead of overflowing.
template<typename T>
void write(T arg) {
    if (arg < 0) {
        putchar('-');
        const T higher = arg / 10;                        // truncates toward zero, so -higher cannot overflow
        const int lowest = -static_cast<int>(arg % 10);   // remainder lies in [-9, 0]
        if (higher != 0) write(-higher);
        putchar('0' + lowest);
        return;
    }
    if (arg > 9) write(arg / 10);
    putchar(static_cast<int>(arg % 10) + '0');
}
/// Prints all arguments in order, separated by single spaces.
template<typename T, typename ... Ts>
void write(T arg, Ts ... args) {
    write(arg);
    if (sizeof...(args) != 0) {
        putchar(' ');
        write(args ...);
    }
}
using namespace std;

/// Greatest common divisor of a and b (Euclidean algorithm, iterative).
/// The result is non-negative for any mix of input signs, and gcd(0, 0) == 0.
/// Note: inputs equal to LLONG_MIN are not supported (their magnitude is not representable).
long long gcd(long long a, long long b) {
    while (b != 0) {
        const long long r = a % b;
        a = b;
        b = r;
    }
    return a < 0 ? -a : a;   // the remainder chain may end on a negative value
}

ll lcm(ll a, ll b) {
    return a / gcd(a, b) * b;
}

// Capacity of the per-vertex arrays below.
const int N = 2e5 + 5;
// Smallest edge count found so far for a path of total weight k; INF means none found yet.
int ans = INF;
// n: number of vertices read by main. head[v]: index in edge[] of the first edge of v (-1 when empty).
// cnt: next free slot in edge[]. k: required path weight read by main.
int n, head[N], cnt, k;
// One directed edge of the linked adjacency list.
struct node {
    int t;      // vertex this edge leads to
    int next;   // index in edge[] of the next edge of the same source vertex (-1 at the end)
    int w;      // edge weight
};
// Edge pool; main stores every input edge twice (once per direction).
node edge[N << 1];
void add (int f, int t, int w)
{
    edge[cnt].w = w;
    edge[cnt].t = t;
    edge[cnt].next = head[f];
    head[f] = cnt ++;
}
// t[d]: smallest dep among the vertices registered by dfs2 at distance d from the current centre
// (INF when none); main presets t[0] = 0.
int t[2000000];
// rt: vertex selected by the most recent dfs_rt call. siz[v]: subtree size written by dfs_rt
// (solve overwrites it with the vertex count of the component). vis[v]: set to 1 by work once v was a centre.
int rt, siz[N], vis[N];
// dis[v]: summed edge weight from the centre currently being solved.
ll dis[N];
// tot: vertices visited by the current dfs1 walk. dep[v]: number of edges from the current centre.
int tot, dep[N];
// Fills siz[] for the unvisited part of the tree reachable from u (entered from fa) and
// stores in rt a vertex whose removal leaves no part larger than s / 2, where s is the
// number of vertices of the component being split.
void dfs_rt (int u , int fa, int s)
{
    siz[u] = 1;
    int heaviest = 0;   // largest part that remains once u is removed
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (to == fa || vis[to]) continue;
        dfs_rt (to, u , s);
        siz[u] += siz[to];
        if (siz[to] > heaviest) heaviest = siz[to];
    }
    const int above = s - siz[u];   // the part on the parent's side of u
    if (above > heaviest) heaviest = above;
    if (heaviest * 2 <= s) rt = u;
}
// Walks the subtree of u (entered from fa), filling dis[] and dep[] for its children and
// counting visited vertices in tot. For each vertex with dis[u] <= k it combines dep[u]
// with the table entry for the missing weight k - dis[u] and keeps the minimum in ans.
void dfs1 (int u , int fa)
{
    ++ tot;
    if (dis[u] <= k)
    {
        const int candidate = t[k - dis[u]] + dep[u];
        if (candidate < ans) ans = candidate;
    }
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (vis[to] || to == fa) continue;
        dep[to] = dep[u] + 1;
        dis[to] = dis[u] + edge[e].w;
        dfs1 (to, u);
    }
}
// Updates the table t for every vertex in the subtree of u (entered from fa) whose
// distance does not exceed k: with p == 1 the entry t[dis] is lowered to dep when that
// is smaller, with any other p the entry is reset to INF.
void dfs2 (int u , int fa, int p)
{
    if (dis[u] <= k)
    {
        int &slot = t[dis[u]];
        if (p != 1)
            slot = INF;
        else if (dep[u] < slot)
            slot = dep[u];
    }
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (vis[to] || to == fa) continue;
        dfs2 (to, u, p);
    }
}
void solve (int u)
{
    for (int i = head[u] ; i != -1 ; i = edge[i].next)
    {
        int v = edge[i].t , w = edge[i].w;
        if (!vis[v])
        {
            tot = 0;
            dep[v] = 1;
            dis[v] = w;
            dfs1 (v, u);
            siz[v] = tot;
            dfs2 (v, u, 1);
        }
    }
    dfs2 (u, 0 , -1);
}
// Handles the component of s vertices that contains u: picks its centre with dfs_rt,
// evaluates the paths through it with solve, then recurses into every remaining part.
void work (int u , int fa, int s)
{
    dfs_rt (u , fa, s);
    const int centre = rt;   // keep a copy: rt is overwritten by the nested calls below
    dis[centre] = 0;
    vis[centre] = 1;
    solve (centre);
    for (int e = head[centre] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (vis[to]) continue;
        work (to, centre, siz[to]);   // siz[to] was set to the part's vertex count by solve
    }
}
// Reads n, k and the n - 1 weighted edges (input vertex ids are shifted by +1),
// runs the decomposition from vertex 1 and prints ans, or -1 when ans is still INF.
int main()
{
    mem (head, -1);   // -1 marks an empty adjacency list
    mem (t, INF);     // every byte of INF is 0x3f, so each int becomes INF
    t[0] = 0;
    read (n, k);
    for (int remaining = n - 1 ; remaining > 0 ; -- remaining)
    {
        int a , b, len;
        read (a, b, len);
        add (a + 1, b + 1, len);
        add (b + 1, a + 1, len);
    }
    work (1, 1, n);
    if (ans == INF)
        ans = -1;
    write (ans);
    LF;
    return 0;
}

三.

  • 题目链接: https://www.luogu.com.cn/problem/P2634
  • 解题思路: 统计路径长度模 3 余 0、1、2 的数量即可,框架不变;注意 `dfs2` 函数中哪些点该加、哪些点该减

//#define LOCAL
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define mem(a, b) memset(a,b,sizeof(a))
#define sz(a) (int)a.size()
#define INF 0x3f3f3f3f
#define DNF 0x7f
#define DBG printf("this is a input\n")
#define fi first
#define se second
#define mk(a, b) make_pair(a,b)
#define pb push_back
#define LF putchar('\n')
#define SP putchar(' ')
#define p_queue priority_queue
#define CLOSE ios::sync_with_stdio(0); cin.tie(0)
#define sz(a) (int)a.size()
#define pii pair <int,int>
/// Reads one (optionally negative) decimal integer from stdin into `x`.
/// Characters before the first digit are skipped; every '-' among them flips the sign.
/// If the input ends before a digit is found, `x` is left at 0 instead of looping forever.
template<typename T>
void read(T &x) {
    x = 0;
    bool negative = false;
    // Keep the character as int: EOF stays distinguishable and isdigit() never
    // receives a negative value (which would be undefined behaviour).
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') negative = !negative;
        ch = getchar();
    }
    while (isdigit(ch)) {   // isdigit(EOF) is false, so this also stops at end of input
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
/// Reads several integers in argument order: read(a, b, c) fills a, then b, then c.
template<typename T, typename... Args>
void read(T &first, Args& ... args) {
    read(first);
    read(args...);
}
/// Prints the integer `arg` in decimal to stdout, without any trailing separator.
/// Negative values are emitted without negating the whole number, so the most
/// negative value of T (e.g. INT_MIN) is printed correctly instead of overflowing.
template<typename T>
void write(T arg) {
    if (arg < 0) {
        putchar('-');
        const T higher = arg / 10;                        // truncates toward zero, so -higher cannot overflow
        const int lowest = -static_cast<int>(arg % 10);   // remainder lies in [-9, 0]
        if (higher != 0) write(-higher);
        putchar('0' + lowest);
        return;
    }
    if (arg > 9) write(arg / 10);
    putchar(static_cast<int>(arg % 10) + '0');
}
/// Prints all arguments in order, separated by single spaces.
template<typename T, typename ... Ts>
void write(T arg, Ts ... args) {
    write(arg);
    if (sizeof...(args) != 0) {
        putchar(' ');
        write(args ...);
    }
}
using namespace std;

/// Greatest common divisor of a and b (Euclidean algorithm, iterative).
/// The result is non-negative for any mix of input signs, and gcd(0, 0) == 0.
/// Note: inputs equal to LLONG_MIN are not supported (their magnitude is not representable).
long long gcd(long long a, long long b) {
    while (b != 0) {
        const long long r = a % b;
        a = b;
        b = r;
    }
    return a < 0 ? -a : a;   // the remainder chain may end on a negative value
}

ll lcm(ll a, ll b) {
    return a / gcd(a, b) * b;
}

// Capacity of the per-vertex arrays below.
const int N = 100005;
// n: number of vertices read by main. head[v]: index in edge[] of the first edge of v (-1 when empty).
// cnt: next free slot in edge[]. ans: running count of paths whose weight is a multiple of 3.
int n, head[N], cnt, ans;
// One directed edge of the linked adjacency list.
struct node {
    int t;      // vertex this edge leads to
    int w;      // edge weight
    int next;   // index in edge[] of the next edge of the same source vertex (-1 at the end)
};
// Edge pool; main stores every input edge twice (once per direction).
node edge[N << 1];
void add (int f, int t, int w)
{
    edge[cnt].t = t;
    edge[cnt].w = w;
    edge[cnt].next = head[f];
    head[f] = cnt ++;
}
// vis[v]: set to 1 by work once v was a centre. dis[v]: distance from the current centre, reduced mod 3.
// siz[v]: subtree size written by dfs_rt (solve overwrites it with the vertex count of the component).
// rt: vertex selected by the most recent dfs_rt call. tot: vertices visited by the current dfs1 walk.
int vis[N], dis[N], siz[N], rt, tot;
// t[r]: number of vertices currently registered by dfs2 whose dis value is r.
int t[10];
// Fills siz[] for the unvisited part of the tree reachable from u (entered from fa) and
// stores in rt a vertex whose removal leaves no part larger than s / 2, where s is the
// number of vertices of the component being split.
void dfs_rt (int u ,int fa, int s)
{
    siz[u] = 1;
    int heaviest = 0;   // largest part that remains once u is removed
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (to == fa || vis[to]) continue;
        dfs_rt (to, u, s);
        siz[u] += siz[to];
        if (siz[to] > heaviest) heaviest = siz[to];
    }
    const int above = s - siz[u];   // the part on the parent's side of u
    if (above > heaviest) heaviest = above;
    if (heaviest * 2 <= s) rt = u;
}
// Walks the subtree of u (entered from fa), filling dis[] (mod 3) for its children and
// counting visited vertices in tot. A vertex with residue 0 adds the path to the centre
// plus one per registered vertex with residue 0; every vertex also adds t[3 - residue].
void dfs1 (int u , int fa)
{
    ++ tot;
    const int residue = dis[u];
    if (residue == 0)
        ans += 1 + t[0];
    // For residue 0 this reads t[3], which dfs2 never modifies (dis values are 0..2).
    ans += t[3 - residue];

    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (to == fa || vis[to]) continue;
        dis[to] = (dis[u] + edge[e].w) % 3;
        dfs1 (to, u);
    }
}
// Adds p to t[dis[x]] for every vertex x in the subtree of u (entered from fa).
// When fa == -1 the starting vertex itself is skipped; solve uses that for the
// final removal walk that begins at the centre.
void dfs2 (int u, int fa, int p)
{
    const bool startsAtCentre = (fa == -1);
    if (!startsAtCentre)
        t[dis[u]] += p;
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (to == fa || vis[to]) continue;
        dfs2 (to, u, p);
    }
}
// Counts into ans every qualifying path that passes through the centre u.
// Each unvisited neighbouring subtree is first counted against the subtrees handled
// before it (dfs1) and only then registered in t (dfs2), so no pair is taken from
// the same subtree. siz[child] receives the subtree's vertex count for later use in work.
void solve (int u)
{
    for (int e = head[u] ; e != -1 ; e = edge[e].next)
    {
        const int child = edge[e].t;
        if (vis[child]) continue;
        tot = 0;
        dis[child] = edge[e].w % 3;
        dfs1 (child, u);
        siz[child] = tot;
        dfs2 (child, u, 1);
    }
    // Walk the whole component once more with p = -1 to undo all registrations.
    dfs2 (u, -1 , -1);
}
// Handles the component of s vertices that contains u: picks its centre with dfs_rt,
// counts the paths through it with solve, then recurses into every remaining part.
void work (int u , int fa, int s)
{
    dfs_rt (u , fa, s);
    const int centre = rt;   // keep a copy: rt is overwritten by the nested calls below
    dis[centre] = 0;
    vis[centre] = 1;
    solve (centre);
    for (int e = head[centre] ; e != -1 ; e = edge[e].next)
    {
        const int to = edge[e].t;
        if (vis[to]) continue;
        work (to, centre , siz[to]);   // siz[to] was set to the part's vertex count by solve
    }
}
// Reads n and the n - 1 weighted edges, runs the decomposition from vertex 1 and prints
// the reduced fraction (2 * ans + n) / (n * n): each counted path is taken in both
// directions and the n single-vertex pairs are added.
int main()
{
    mem (head, -1);   // -1 marks an empty adjacency list
    read (n);
    for (int remaining = n - 1 ; remaining > 0 ; -- remaining)
    {
        int a , b, len;
        read (a, b , len);
        add (a, b, len);
        add (b, a, len);
    }
    work (1, 1, n);
    ans = ans * 2 + n;
    int p = gcd (ans, n*n);   // common factor used to reduce the fraction
    printf ("%d/%d",ans/p,n*n/p);
    return 0;
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值