树上点分治

这几天一直想完成之前的点分治,但都在写其他题,今天终于写完了点分治,怕忘,就写一下博客。

【引言】

由于树具有一般的图没有的特点,所以在竞赛中的应用更广。

在一些树上路径问题中,暴力求解时间复杂度过高,往往需要一些更为高效的算法,点分治就是其中之一。

【思想】

在树上,取一个点,对于询问计算所有经过这个点的情况,再递归下去处理每一棵子树的情况。因为是递归处理,为了使递归的次数更少,一般都是要取使以其为根,所有子树中节点最多的子树拥有的节点最少,称这个点为“重心”。

重心求法如下:

(1)任取一个根,dfs一次求出每个节点的子树大小。

(2)记录每一个点最大子树的节点数。

(3)最大子树的节点最少的点便是重心

代码如下:


1
2
3
4
5
6
7
8
9
10
11
12
13
14
[cce_cpp]
void Getroot(int x, int fa)  {
    son[x] = 1;
    Max[x] = 0;
    int next;
    for (int i = last[x]; i ;i = e[i].last)
        if (!use[next = e[i].t] && next != fa)  {
            Getroot(next, x);
            son[x] += son[next];
            Max[x] = max(Max[x], son[next]);
        }
    Max[x] = max(Max[x], Son - son[x]);//因为这个点不一定是根,要计算它上的节点数,用树的大小减去(它的子节点数+1)即可
    if (Max[x] < Max[root]) root = x;
}
[/cce_cpp]

时间复杂度O(n)。

求出重心接下来就要解决问题了
(1)关于询问,以重心为根,求出所有经过当前重心的情况(即子节点1–>重心–>子节点2)。因为为了节省时间,是把所有子节点的情况一起算,如果两个子节点来自同一棵子树,那肯定是不行的,所以还要减去所以来自同一棵子树的子节点的情况。
(2)标记重心
(3)递归用相同的方法处理每一棵子树。
代码如下:


1
2
3
4
5
6
7
8
9
10
11
12
13
void sol(int x)  {
    root = 0;
    Getroot(x, 0);
    use[root] = 1;
    ans += Cal(root, 0);
    int next;
    for (int i = last[root]; i ;i = e[i].last)
        if (!use[next = e[i].t])  {
            ans -= Cal(next, e[i].val);
            Son = son[next];
            sol(next);
        }
}

点分治的思想基本上如上所述。

【例题】

poj1741(这题题库上好像没有)

最经典的点分治,求出每一棵子树,再计算符合题意的情况。

代码如下:


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<climits>

using namespace std;

const int N = 100005;

int Read()  {
    char ch;
    int val = 0, opt = 1;
    while ( !( isdigit( ch = getchar() ) || ch == '-') );
    if (ch == '-') opt = -1;
    else val = ch - '0';
    while (isdigit( ch = getchar() )) (val *= 10) += ch - '0';
    return val * opt;
}

struct node {
    int t, val, last;
}p[2 * N];

int last[N], n, k, use[N], Max[N], root, dis[N], son[N], ans, tot, v[N], num, Son;

void ins(int s, int t, int val)  {
    p[++tot].t = t;
    p[tot].val = val;
    p[tot].last = last[s];
    last[s] = tot;
}

void getroot(int x, int fa)  {
    int next;
    son[x] = 1, Max[x] = 0;
    for (int i = last[x]; i ; i = p[i].last)  {
        if ((next = p[i].t) != fa && !use[next])  {
            getroot(next, x);
            son[x] += son[next];
            Max[x] = max(Max[x], son[next]);
        }
    }
    Max[x] = max(Max[x], Son - son[x]);
    if (Max[x] < Max[root]) root = x;
}

void Find_dis(int x, int fa)  {
    int next;
    v[++num] = dis[x];
    for (int i = last[x]; i ;i = p[i].last)
        if (!use[next = p[i].t] && next != fa)  {
            dis[next] = dis[x] + p[i].val ;
            Find_dis(next, x);
        }
}

int cal(int x, int va){
    dis[x] = va;
    num = 0;
    Find_dis(x, 0);
    sort(v + 1, v + 1 + num);
    int ti = 1, sum = 0, w = num;
    while (w > ti)  {
        if (v[w] + v[ti] <= k)  {
            sum += w - ti;
            ti++;
        }
        else w--;
    }
    return sum;
}

void sol(int x)  {
    root = 0;
    getroot(x, 0);
    use[root] = 1;
    ans += cal(root, 0);
    int next;
    for (int i = last[root]; i ;i = p[i].last)
        if (!use[next = p[i].t])  {
            ans -= cal(next, p[i].val);
            Son = son[next];
            sol(next);
        }
}

int main(void)  {
    //freopen("tree.in", "r", stdin);
    //freopen("tree.out", "w", stdout);
   
    int x, y, z;
    Max[root] = INT_MAX;
    while (( n = Read() ))  {
        k = Read();
        ans = tot = 0;
        memset(last, 0, sizeof last);
        memset(use, 0, sizeof use);
        for (int i = 1;i < n; i++)  {
            x = Read(), y = Read(), z = Read();
            ins(x, y, z);
            ins(y, x, z);
        }
        Son = n;
        sol(1);
        cout << ans << endl;
    }
   
    fclose(stdin);
    fclose(stdout);
    return 0;
}

聪聪可可

这题与上题略有不同,主要是计算答案有点不一样。

代码如下:


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<climits>
#include<cstring>
 
using namespace std;
 
const int N = 20010;
 
int Read()  {
    char ch;
    int val = 0, opt = 1;
    while (!( isdigit( ch = getchar() ) || ch == '-'));
    if (ch == '-') opt = -1;
    else val = ch - '0';
    while (isdigit(ch = getchar())) (val *= 10) += ch - '0';
    return val * opt;
}
 
struct node  {
    int t, val, last;
}e[2 * N];
 
int n, last[N], tot, root = 0, Max[N], son[N], Son, use[N], Num[3], ans, dis[N];
 
void Ins(int s, int t, int val)  {
    e[++tot].t = t;
    e[tot].val = val;
    e[tot].last = last[s];
    last[s] = tot;
}
 
void Getroot(int x, int fa)  {
    son[x] = 1;
    Max[x] = 0;
    int next;
    for (int i = last[x]; i ;i = e[i].last)
        if (!use[next = e[i].t] && next != fa)  {
            Getroot(next, x);
            son[x] += son[next];
            Max[x] = max(Max[x], son[next]);
        }
    Max[x] = max(Max[x], Son - son[x]);
    if (Max[x] < Max[root]) root = x;
}
 
void Find_dis(int x, int fa)  {
    int next;
    ++Num[dis[x] % 3];
    for (int i = last[x]; i ;i = e[i].last)
        if (!use[next = e[i].t] && next != fa)  {
            dis[next] = (dis[x] + e[i].val) % 3;
            Find_dis(next, x);
        }
}
 
int Cal(int x, int v)  {
    dis[x] = v;
    memset(Num, 0, sizeof Num);
    Find_dis(x, 0);
    return Num[0] * Num[0] + Num[1] * Num[2] * 2;
}
 
void sol(int x)  {
    root = 0;
    Getroot(x, 0);
    use[root] = 1;
    ans += Cal(root, 0);
    int next;
    for (int i = last[root]; i ;i = e[i].last)
        if (!use[next = e[i].t])  {
            ans -= Cal(next, e[i].val);
            Son = son[next];
            sol(next);
        }
}
 
int Gcd(int a, int b)  {
    return b ? Gcd(b, a % b) : a;
}
 
int main()  {
    //freopen("cckk.in", "r", stdin);
    //freopen("cckk.out", "w", stdout);
   
    Son = n = Read();
    for (int i = 1;i < n; i++)  {
        int x = Read(), y = Read(), z = Read();
        Ins(x, y, z);
        Ins(y, x, z);
    }
    Max[0] = INT_MAX;
    sol(1);
    cout << ans / Gcd(ans, n * n) << "/" << n*n / Gcd(ans, n*n);
   
    fclose(stdin);
    fclose(stdout);
    return 0;
}

发表评论

邮箱地址不会被公开。 必填项已用*标注