P3349 [ZJOI2016]小星星 卡常实例

题目:P3349 [ZJOI2016]小星星

原理:设三维状态,一个当前节点,一个对应的节点,一个当前子树所包含对应节点,进行状压DP转移

第零版代码(部分):


1
2
const int N=20
int dp[N][N][1<<N];

严重超过内存限制时,会编译失败。

考虑到对于一个固定的子树大小,有效的状态很少,可以压缩存储。

第一版代码(50pts):


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=j;i<=k;++i)
#define drp(i,j,k) for(int i=j;i>=k;--i)
#define zt(i,j) for(unsigned i=1;i<=cnt[j];++i)
#define int long long
using namespace std;
inline int read(){
    int x=0,f=0;
    char c=getchar();
    while(!isdigit(c)){
        f|=c=='-';
        c=getchar();
    }
    while(isdigit(c)){
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    }
    return f?-x:x;
}
const int N=17,M=2e4;
int n,m,ans;
int g[N][N];
int dp[N][N][M],f[N][M];
int one[1<<N],id[1<<N];
int tpye[N+1][M],cnt[N+1];

int fir[N],nex[N<<1],to[N<<1],pos;
void add(int x,int y){
    to[++pos]=y;nex[pos]=fir[x];fir[x]=pos;
    swap(x,y);
    to[++pos]=y;nex[pos]=fir[x];fir[x]=pos;
}
//
int siz[N];
void dfs(int u,int fa){
    siz[u]=1;
    rep(i,0,n-1) dp[u][i][id[1<<i]]=1;
    for(int i=fir[u];i;i=nex[i]){
        int v=to[i];
        if(v==fa) continue;
        dfs(v,u);
        int now=siz[u]+siz[v];
        rep(h,0,n-1)
          zt(p,siz[u]){
            int s=tpye[siz[u]][p];
            f[h][id[s]]=dp[u][h][id[s]];
            dp[u][h][id[s]]=0;
          }
        zt(p,now){
            int s=tpye[now][p];
            rep(h1,0,n-1){
                if(((1<<h1)&s)==0) continue;
                rep(h2,0,n-1){
                if(((1<<h2)&s)==0) continue;
                if(h1==h2) continue;
                if(!g[h1][h2]) continue;
                for(int st=(s-1)&s;st;st=(st-1)&s){
                    if(((1<<h1)&st)==0) continue;
                    if(((1<<h2)&(s^st))==0) continue;
                    if(one[st]!=siz[u]) continue;
                    if(one[s^st]!=siz[v]) assert(0);
                    dp[u][h1][id[s]]+=f[h1][id[st]]*dp[v][h2][id[s^st]];
                }
              }
            }
        }
        siz[u]+=siz[v];
    }
}

signed main(){
    n=read(),m=read();
   
    int c=0;
    rep(i,0,(1<<n)-1){
        int num=0;
        rep(j,0,n-1)
          num+=((1<<j)&i?1:0);
        one[i]=num;
        id[i]=++cnt[num];
        tpye[num][cnt[num]]=i;
        c=max(cnt[num],c);
    }
   
    rep(i,1,m){
        int x=read()-1,y=read()-1;
        g[x][y]=g[y][x]=1;
    }
    rep(i,1,n-1)
      add(read()-1,read()-1);
    dfs(0,-1);
    rep(i,0,n-1) ans+=dp[0][i][cnt[n]];
    printf("%lld\n",ans);
    return 0;
}

存在的一些问题:

① 大常数的数据类型


1
#define int long long

② 超大常数的枚举方式


1
zt(p,siz[u]+siz[v]

③ 死活要用时间换空间的拖拉id数组,让指针一直跳


1
id[i]=++cnt[num];

④ 经常换的[对应编号]一维不放在最后面


1
int dp[N][N][M];

第二版代码(80pts):


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=j;i<=k;++i)
#define drp(i,j,k) for(int i=j;i>=k;--i)
#define zt(i,j) for(int i=1;i<=cnt[j];++i)
#define int long long
using namespace std;
inline int read(){
    int x=0,f=0;
    char c=getchar();
    while(!isdigit(c)){
        f|=c=='-';
        c=getchar();
    }
    while(isdigit(c)){
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    }
    return f?-x:x;
}
const int N=17,M=1<<18;
int n,m;
int g[N][N];
int one[1<<N];
int tpye[N+1][M],cnt[N+1];
long long ans,dp[N][M][N],f[M][N];

int fir[N],nex[N<<1],to[N<<1],pos;
void add(int x,int y){
    to[++pos]=y;nex[pos]=fir[x];fir[x]=pos;
    swap(x,y);
    to[++pos]=y;nex[pos]=fir[x];fir[x]=pos;
}

int siz[N];
void dfs(int u,int fa){
    siz[u]=1;
    rep(i,0,n-1) dp[u][1<<i][i]=1;
    for(int i=fir[u];i;i=nex[i]){
        int v=to[i];
        if(v==fa) continue;
        dfs(v,u);
        int now=siz[u]+siz[v];
        zt(p1,siz[u]) zt(p2,siz[v]) {
            int s1=tpye[siz[u]][p1];
            int s2=tpye[siz[v]][p2];
            if(s1&s2) continue;
            rep(h1,0,n-1) rep(h2,0,n-1){
                if(((1<<h1)&s1)==0) continue;
                if(((1<<h2)&s2)==0) continue;
                if(!g[h1][h2]) continue;
                dp[u][s1|s2][h1]+=dp[u][s1][h1]*dp[v][s2][h2];
            }
        }
        siz[u]+=siz[v];
    }
}

signed main(){
//  freopen("in.txt","r",stdin);
    n=read(),m=read();
    rep(i,0,(1<<n)-1){
        int num=0;
        rep(j,0,n-1)
          num+=((1<<j)&i?1:0);
        one[i]=num;
        tpye[num][++cnt[num]]=i;
    }
    rep(i,1,m){
        int x=read()-1,y=read()-1;
        g[x][y]=g[y][x]=1;
    }
    rep(i,1,n-1)
      add(read()-1,read()-1);
    int rt=0;
    dfs(rt,-1);
    rep(i,0,n-1) ans+=dp[rt][(1<<n)-1][i];
    printf("%lld\n",ans);
    return 0;
}

美观

很美观,尤其是dfs里的DP转移,错落有致,严谨而不啰嗦

但是效率和美观没有关系,还是有一些无用的情况可以舍弃

① 考虑到文中已经有了的四个 continue ,每当其变量刚出现时就可以进行减值。

② dp数组中有大量的地方权值为0,siz越小的地方越明显,可以把dp[v]放在前面求出来然后减值。

第三版代码(+O2 100pts):


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
void dfs(int u,int fa){
    siz[u]=1;
    rep(i,0,n-1) dp[u][1<<i][i]=1;
    for(int i=fir[u];i;i=nex[i]){
        int v=to[i];
        if(v==fa) continue;
        dfs(v,u);
        int su=siz[u],sv=siz[v];
        zt(p1,su) {
            int s1=tpye[su][p1];
            zt(p2,sv){
                int s2=tpye[sv][p2];
                if(s1&s2) continue;
                rep(h2,0,n-1) {
                    if(((1<<h2)&s2)==0) continue;
                    if(!dp[v][s2][h2]) continue;
                    rep(h1,0,n-1){
                        if(((1<<h1)&s1)==0) continue;
                        if(!g[h1][h2]) continue;
                        dp[u][s1|s2][h1]+=dp[u][s1][h1]*dp[v][s2][h2];
                    }
                }
            }
        }
        siz[u]+=siz[v];
    }
}

最后附上一个容斥的代码(跑得比暴力还慢/cy)


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=j;i<=k;++i)
#define drp(i,j,k) for(int i=j;i>=k;--i)
#define int unsigned long long
using namespace std;
inline int read(){
    int x=0,f=0;
    char c=getchar();
    while(!isdigit(c)){
        f|=c=='-';
        c=getchar();
    }
    while(isdigit(c)){
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    }
    return f?-x:x;
}
const int N=20;
int n,m,ans;
int g[N][N];

int fir[N],nex[N<<1],to[N<<1],pos;
void add(int x,int y){
    to[++pos]=y;nex[pos]=fir[x];fir[x]=pos;
    swap(x,y);
    to[++pos]=y;nex[pos]=fir[x];fir[x]=pos;
}

int vis[N],tot,dp[N][N],f[N];

void dfs(int u,int fa){
    rep(i,1,n) if(vis[i]) dp[u][i]=1;
    for(int i=fir[u];i;i=nex[i]){
        int v=to[i];
        if(v==fa) continue;
        dfs(v,u);
        rep(i,1,n) f[i]=dp[u][i],dp[u][i]=0;
        rep(h1,1,n) rep(h2,1,n){
            if(!vis[h1]||!vis[h2]) continue;
            if(!g[h1][h2]) continue;
            dp[u][h1]+=f[h1]*dp[v][h2];
        }
    }
}

int sou(){
    dfs(1,0);
    int sum=0;
    rep(i,1,n)
      sum+=dp[1][i];
    return sum;
}

void rc(int u){
    if(u==n+1) {
        ans+=((n-tot)%2?-1:1)*sou();
        return ;
    };
    vis[u]=1;tot++;rc(u+1);
    vis[u]=0;tot--;rc(u+1);
}

signed main(){
   freopen("in.txt","r",stdin);
    n=read(),m=read();
    rep(i,1,m){
        int x=read(),y=read();
        g[x][y]=g[y][x]=1;
    }
    rep(i,1,n-1)
      add(read(),read());
    rc(1);
    printf("%lld\n",ans);
    return 0;
}