WC2019 数树
题意:
数据范围:
Analysis:
首先膜拜这题出的实在是过于神仙,蒟蒻瑟瑟发抖,完全不会做。
此题分为三个部分,我们分别解决。
先浅显的分析一些性质:
我们发现我们只需要两棵树边的交集,然后把它们拿出来。
会形成森林,那么对于这片森林的每一个联通块,数字必须一样。
设为第一棵树的边集,为第二颗树的边集。
由于个点的森林若一共有条边,那么一共会有个联通块。
所以答案即为:。
对于
两棵树都已确定,只要确定交集有多少条边即可,用随意维护一下。
对于
我们发现第一棵树没有确定,我们需要枚举。
这是一个集合枚举,方案数极难计算,不妨考虑容斥。
套上一个集合容斥,我们设:
这是子集反演的套路,要记住。
让,表示边集交集为的答案,那么我们可以得到:
其中表示包含边集的边集有多少个,对于括号里面用二项式定理,我们能推得:
我们先考虑如何算,首先我们把边集里的点全部连上,然后现在是一片森林,我们需要继续连边,既然若干连通块独立,我们不妨将一个连通块看成一个点,然后求树的形态个数,每次连边都是在连通块里随意选一个点去连。
因此设在边集中共有个连通块,且每个连通块点数为,随意化一下公式就可以得到:
我们接着化原式:
这样一个式子就简单很多了,既然子集数目很多,那我们考虑用去解决这个问题。
我们设,我们可以设表示包含的连通块大小为,且所在连通块的贡献还没计算,子树内的所有选择方案,的连通块的贡献之和,即式子中的最右边那部分,那么答案就是。
但这样还不够,我们需要继续优化。
我们考虑设这个的生成函数为,即。
我们再设。
那么。
然后我们观察中的形式,会发现类似于求导后的系数,于是我们可以得到。
。
后面那一部分是通过如下式子得到的:
再观察一下这个式子的形式,我们就可以发现可以递推了。
我们额外设,那么可以得到:
Task3:
此时我们按照上一个里面的方法推导,最终的式子大约如下:
一棵点数为的树,形态有,再分配标号方案为,然后再去重,除掉连通块个数的阶乘。
我们考虑生成函数。
。
会发现后面是一个多项式的形式,因此只需要求一次即可。
复杂度。
Code:
# include<cstdio>
# include<cstring>
# include<algorithm>
# include<map>
using namespace std;
const int N = 1e5 + 5;
const int mo = 998244353;
const int invg = (mo + 1) / 3;
typedef long long ll;
map <ll,int> Q;
int z[N << 3],ans[N << 3],rev[N << 3],F[N << 3],E[N << 3];
int A[N << 3],B[N << 3],C[N << 3],D[N << 3],inv[N << 3];
int f[N],g[N],fac[N],st[N],to[N << 1],nx[N << 1];
int n,Y,op,L,len,tot,X;
inline void add(int u,int v)
{
to[++tot] = v,nx[tot] = st[u],st[u] = tot;
to[++tot] = u,nx[tot] = st[v],st[v] = tot;
}
inline int pow(int x,int p)
{
int ret = 1;
for (; p ; p >>= 1,x = (ll)x * x % mo)
if (p & 1) ret = (ll)ret * x % mo;
return ret;
}
inline int inc(int x,int y) { return x + y >= mo ? x + y - mo : x + y; }
inline int dec(int x,int y) { return x - y < 0 ? x - y + mo : x - y; }
inline void dft(int *f,int n,int op)
{
for (len = 1,L = 0 ; len <= n ; len <<= 1,++L);
for (int i = 0 ; i < len ; ++i)
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
if (i < rev[i]) swap(f[i],f[rev[i]]);
}
for (int i = 1 ; i < len ; i <<= 1)
{
int wn = pow(~op ? 3 : invg,(mo - 1) / (i << 1));
for (int j = 0 ; j < len ; j += (i << 1))
{
int w = 1;
for (int k = 0 ; k < i ; ++k,w = (ll)w * wn % mo)
{
int x = f[j + k],y = (ll)w * f[i + j + k] % mo;
f[j + k] = inc(x,y),f[i + j + k] = dec(x,y);
}
}
}
if (op == -1)
{
int x = pow(len,mo - 2);
for (int i = 0 ; i < len ; ++i) f[i] = (ll)f[i] * x % mo;
}
}
inline void Ginv(int *a,int *b,int n)
{
if (n == 1) { b[0] = pow(a[0],mo - 2); return; }
Ginv(a,b,n >> 1);
for (int i = 0 ; i < n ; ++i) A[i] = a[i],B[i] = b[i];
dft(A,n,1),dft(B,n,1);
for (int i = 0 ; i < len ; ++i) A[i] = (ll)A[i] * B[i] % mo * B[i] % mo;
dft(A,n,-1);
for (int i = 0 ; i < n ; ++i) b[i] = dec(inc(b[i],b[i]),A[i]);
for (int i = 0 ; i < len ; ++i) A[i] = B[i] = 0;
}
inline void Gln(int *a,int *b,int n)
{
Ginv(a,C,n);
for (int i = 0 ; i < n - 1 ; ++i) D[i] = (ll)(i + 1) * a[i + 1] % mo;
dft(C,n,1),dft(D,n,1);
for (int i = 0 ; i < len ; ++i) C[i] = (ll)C[i] * D[i] % mo;
dft(C,n,-1),b[0] = 0;
for (int i = 1 ; i < n ; ++i) b[i] = (ll)inv[i] * C[i - 1] % mo;
for (int i = 0 ; i < len ; ++i) C[i] = D[i] = 0;
}
inline void Gexp(int *a,int *b,int n)
{
if (n == 1) { b[0] = 1; return; }
Gexp(a,b,n >> 1);
for (int i = 0 ; i < n ; ++i) E[i] = b[i];
Gln(b,F,n);
for (int i = 0 ; i < n ; ++i) F[i] = dec(a[i],F[i]);
F[0] = inc(F[0],1);
dft(E,n,1),dft(F,n,1);
for (int i = 0 ; i < len ; ++i) E[i] = (ll)E[i] * F[i] % mo;
dft(E,n,-1);
for (int i = 0 ; i < n ; ++i) b[i] = E[i];
for (int i = 0 ; i < len ; ++i) E[i] = F[i] = 0;
}
inline void dfs(int x,int F)
{
g[x] = 1,f[x] = X;
for (int i = st[x] ; i ; i = nx[i])
if (to[i] != F)
{
dfs(to[i],x),g[x] = (ll)g[x] * inc(f[to[i]],g[to[i]]) % mo;
f[x] = (f[x] + (ll)f[to[i]] * pow(inc(f[to[i]],g[to[i]]),mo - 2) % mo) % mo;
} f[x] = (ll)f[x] * g[x] % mo;
}
inline int calc()
{
for (int i = 1 ; i < n ; ++i)
{
int u,v; scanf("%d%d",&u,&v);
Q[(ll)u * (n - 1) + v] = Q[(ll)v * (n - 1) + u] = 1;
} int cnt = 0;
for (int i = 1 ; i < n ; ++i)
{
int u,v; scanf("%d%d",&u,&v);
if (Q.count((ll)u * (n - 1) + v) || Q.count((ll)v * (n - 1) + u)) ++cnt;
} return pow(Y,n - cnt);
}
inline int calc1()
{
if (Y == 1) return pow(n,n - 2);
for (int i = 1 ; i < n ; ++i)
{
int u,v; scanf("%d%d",&u,&v);
add(u,v);
} X = (ll)n * Y % mo * pow(dec(1,Y),mo - 2) % mo,dfs(1,0);
X = pow(n,mo - 2);
return (ll)pow(dec(1,Y),n) * X % mo * X % mo * f[1] % mo;
}
inline int calc2()
{
if (Y == 1) return pow(n,2 * n - 4);
X = (ll)n * n % mo * Y % mo * pow(dec(1,Y),mo - 2) % mo;
int zs = 1; for (; zs <= n ; zs <<= 1); inv[1] = 1;
for (int i = 2 ; i < zs ; ++i) inv[i] = (ll)(mo - mo / i) * inv[mo % i] % mo;
int c = 1;
for (int i = 1 ; i < zs ; ++i,c = (ll)c * i % mo)
z[i] = (ll)pow(i,i) * pow(c,mo - 2) % mo * X % mo;
Gexp(z,ans,zs); X = (ll)pow(dec(1,Y),n) * pow(inv[n],4) % mo;
c = 1; for (int i = 2 ; i <= n ; ++i) c = (ll)c * i % mo;
return (ll)ans[n] * X % mo * c % mo;
}
int main()
{
scanf("%d%d%d",&n,&Y,&op);
if (!op) printf("%d\n",calc());
else if (op == 1) printf("%d\n",calc1());
else printf("%d\n",calc2());
return 0;
}