WC2019 数树

题意:

WC2019 数树

数据范围:

WC2019 数树

Analysis:

首先膜拜这题出的实在是过于神仙,蒟蒻瑟瑟发抖,完全不会做。
此题分为三个部分,我们分别解决。
先浅显的分析一些性质:
我们发现我们只需要两棵树边的交集,然后把它们拿出来。
会形成森林,那么对于这片森林的每一个联通块,数字必须一样。
E1E1为第一棵树的边集,E2E2为第二颗树的边集。
由于nn个点的森林若一共有kk条边,那么一共会有nkn-k个联通块。
所以答案即为:ynE1E2y^{n-|E1\bigcap E2|}
对于Task1Task 1:
两棵树都已确定,只要确定交集有多少条边即可,用mapmap随意维护一下。
对于Task2Task 2:
我们发现第一棵树没有确定,我们需要枚举E1E1
这是一个集合枚举,方案数极难计算,不妨考虑容斥。
套上一个集合容斥,我们设:FS=TST1T(1)TT1FT1F_S=\sum_{T \subseteq S}\sum_{T1 \subseteq T}(-1)^{|T|-|T1|}F_{T1}
这是子集反演的套路,要记住。
FS=ynSF_S=y^{n-|S|},表示边集交集为SS的答案,那么我们可以得到:
ans=E1SE1E2TS(1)STynTans=\sum_{E1}\sum_{S \subseteq E1 \bigcap E2}\sum_{T \subseteq S} (-1)^{|S|-|T|}*y^{n-|T|} =SE2ynS(TS(1)STyST)GS=\sum_{S \subseteq E2}y^{n-|S|}*(\sum_{T \subseteq S}(-1)^{|S|-|T|}*y^{|S|-|T|})*G_S
其中GSG_S表示包含边集SS的边集E1E1有多少个,对于括号里面用二项式定理,我们能推得:ans=SE2ynS(1y)SGSans=\sum_{S \subseteq E2}y^{n-|S|}*(1-y)^{|S|}*G_S
我们先考虑GSG_S如何算,首先我们把边集SS里的点全部连上,然后现在是一片森林,我们需要继续连边,既然若干连通块独立,我们不妨将一个连通块看成一个点,然后求树的形态个数,每次连边都是在连通块里随意选一个点去连。
因此设在边集SS中共有kk个连通块,且每个连通块点数为aia_i,随意化一下公式就可以得到:GS=nk2i=1kaiG_S=n^{k-2}*\prod_{i=1}^ka_i
我们接着化原式:
ans=SE2yk(1y)nknk2i=1kaians=\sum_{S \subseteq E2}y^{k}*(1-y)^{n-k}*n^{k-2}*\prod_{i=1}^ka_i =(1y)nn2SE2i=1kny1yai=\frac{(1-y)^n}{n^2}\sum_{S \subseteq E2}\prod_{i=1}^k\frac{ny}{1-y}a_i
这样一个式子就简单很多了,既然子集数目很多,那我们考虑用DPDP去解决这个问题。
我们设k=ny1yk=\frac{ny}{1-y},我们可以设fi,jf_{i,j}表示包含ii的连通块大小为jj,且ii所在连通块的贡献还没计算,子树内的所有选择方案,的连通块的贡献之和,即式子中的最右边那部分,那么答案就是kjjf1,jk\sum_jj*f_{1,j}
但这样还不够,我们需要继续优化。
我们考虑设这个DPDP的生成函数为F(x)F(x),即F(x)=i>0fixiF(x)=\sum_{i>0} f_ix^i
我们再设Zi=kjfi,jZ_i=k\sum j*f_{i,j}
那么Fi(x)=xyson(Zy+Fy(x))F_i(x)=x\prod_{y \in son}(Z_y+F_y(x))
然后我们观察ZiZ_i中的形式,会发现类似于求导后的系数,于是我们可以得到。
Zx=kFx(1)=k(xyson(Zy+Fy(1)))=kyson(Zy+Fy(1))+(yson(Zy+Fy(1)))ysonkFy(1)Zy+Fy(1)Z_x=kF_x'(1)=k(x\prod_{y \in son}(Z_y+F_y(1)))'=k\prod_{y \in son}(Z_y+F_y(1))+(\prod_{y \in son}(Z_y+F_y(1)))\sum_{y \in son} \frac{kF_y'(1)}{Z_y+F_y(1)}
后面那一部分是通过如下式子得到的:
(i=1nF(x))=i=1nF(i)jiF(j)=i=1nF(i)j=1nF(j)F(j)(\prod_{i=1}^nF(x))'=\sum_{i=1}^nF'(i)\prod_{j \neq i}F(j)=\prod_{i=1}^nF(i)\sum_{j=1}^n \frac{F'(j)}{F(j)}
再观察一下这个式子的形式,我们就可以发现可以O(n)O(n)递推了。
我们额外设ti=Fi(1)t_i=F_i(1),那么可以得到:
Zi=ti(k+yZyZy+ty)Z_i=t_i(k+\sum_y \frac{Z_y}{Z_y+t_y}) ti=y(Zy+ty)t_i=\prod_y(Z_y+t_y)

Task3:

此时我们按照上一个TaskTask里面的方法推导,最终的式子大约如下:
ans=SynS(1y)SGS2ans=\sum_Sy^{n-|S|}*(1-y)^{|S|}*G^2_S =(1y)nn4Si=1kn2y1yai2=\frac{(1-y)^n}{n^4}\sum_S\prod_{i=1}^k\frac{n^2y}{1-y}a_i^2
一棵点数为ai2a_i-2的树,形态有aiai2a_i^{a_i-2}种,再分配标号方案为n!iai!\frac{n!}{\prod_ia_i!},然后再去重,除掉连通块个数的阶乘。
我们考虑生成函数F(x)=i>0n2y1yiii!F(x)=\sum_{i>0}\frac{n^2y}{1-y}*\frac{i^i}{i!}
ans=(1y)nn!n4i=1n[xn]f(x)ians=\frac{(1-y)^nn!}{n^4}\sum_{i=1}^n[x^n]f(x)^i
会发现后面是一个多项式expexp的形式,因此只需要求一次expexp即可。
复杂度O(nlogn)O(n \log n)

Code:

# include<cstdio>
# include<cstring>
# include<algorithm>
# include<map>
using namespace std;
const int N = 1e5 + 5;
const int mo = 998244353;
const int invg = (mo + 1) / 3;
typedef long long ll;
map <ll,int> Q;
int z[N << 3],ans[N << 3],rev[N << 3],F[N << 3],E[N << 3];
int A[N << 3],B[N << 3],C[N << 3],D[N << 3],inv[N << 3];
int f[N],g[N],fac[N],st[N],to[N << 1],nx[N << 1];
int n,Y,op,L,len,tot,X;
inline void add(int u,int v)
{
	to[++tot] = v,nx[tot] = st[u],st[u] = tot;
	to[++tot] = u,nx[tot] = st[v],st[v] = tot;
}
inline int pow(int x,int p)
{
	int ret = 1;
	for (; p ; p >>= 1,x = (ll)x * x % mo)
	if (p & 1) ret = (ll)ret * x % mo;
	return ret;
}
inline int inc(int x,int y) { return x + y >= mo ? x + y - mo : x + y; }
inline int dec(int x,int y) { return x - y < 0 ? x - y + mo : x - y; }
inline void dft(int *f,int n,int op)
{
	for (len = 1,L = 0 ; len <= n ; len <<= 1,++L);
	for (int i = 0 ; i < len ; ++i)
	{
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
		if (i < rev[i]) swap(f[i],f[rev[i]]);
	}
	for (int i = 1 ; i < len ; i <<= 1)
	{
		int wn = pow(~op ? 3 : invg,(mo - 1) / (i << 1));
		for (int j = 0 ; j < len ; j += (i << 1))
		{
			int w = 1;
			for (int k = 0 ; k < i ; ++k,w = (ll)w * wn % mo)
			{
				int x = f[j + k],y = (ll)w * f[i + j + k] % mo;
				f[j + k] = inc(x,y),f[i + j + k] = dec(x,y);
			}
		}
	}
	if (op == -1)
	{
		int x = pow(len,mo - 2);
		for (int i = 0 ; i < len ; ++i) f[i] = (ll)f[i] * x % mo;
	}
}
inline void Ginv(int *a,int *b,int n)
{
	if (n == 1) { b[0] = pow(a[0],mo - 2); return; }
	Ginv(a,b,n >> 1);
	for (int i = 0 ; i < n ; ++i) A[i] = a[i],B[i] = b[i];
	dft(A,n,1),dft(B,n,1);
	for (int i = 0 ; i < len ; ++i) A[i] = (ll)A[i] * B[i] % mo * B[i] % mo;
	dft(A,n,-1);
	for (int i = 0 ; i < n ; ++i) b[i] = dec(inc(b[i],b[i]),A[i]);
	for (int i = 0 ; i < len ; ++i) A[i] = B[i] = 0;
}
inline void Gln(int *a,int *b,int n)
{
	Ginv(a,C,n);
	for (int i = 0 ; i < n - 1 ; ++i) D[i] = (ll)(i + 1) * a[i + 1] % mo;
	dft(C,n,1),dft(D,n,1);
	for (int i = 0 ; i < len ; ++i) C[i] = (ll)C[i] * D[i] % mo;
	dft(C,n,-1),b[0] = 0;
	for (int i = 1 ; i < n ; ++i) b[i] = (ll)inv[i] * C[i - 1] % mo;
	for (int i = 0 ; i < len ; ++i) C[i] = D[i] = 0;
}
inline void Gexp(int *a,int *b,int n)
{
	if (n == 1) { b[0] = 1; return; }
	Gexp(a,b,n >> 1);
	for (int i = 0 ; i < n ; ++i) E[i] = b[i];
	Gln(b,F,n);
	for (int i = 0 ; i < n ; ++i) F[i] = dec(a[i],F[i]);
	F[0] = inc(F[0],1);
	dft(E,n,1),dft(F,n,1);
	for (int i = 0 ; i < len ; ++i) E[i] = (ll)E[i] * F[i] % mo;
	dft(E,n,-1);
	for (int i = 0 ; i < n ; ++i) b[i] = E[i];
	for (int i = 0 ; i < len ; ++i) E[i] = F[i] = 0;
}
inline void dfs(int x,int F)
{
	g[x] = 1,f[x] = X;
	for (int i = st[x] ; i ; i = nx[i])
	if (to[i] != F)
	{
		dfs(to[i],x),g[x] = (ll)g[x] * inc(f[to[i]],g[to[i]]) % mo;
		f[x] = (f[x] + (ll)f[to[i]] * pow(inc(f[to[i]],g[to[i]]),mo - 2) % mo) % mo;
	} f[x] = (ll)f[x] * g[x] % mo;
}
inline int calc()
{
	for (int i = 1 ; i < n ; ++i)
	{
		int u,v; scanf("%d%d",&u,&v);
		Q[(ll)u * (n - 1) + v] = Q[(ll)v * (n - 1) + u] = 1;
	} int cnt = 0;
	for (int i = 1 ; i < n ; ++i)
	{
		int u,v; scanf("%d%d",&u,&v);
		if (Q.count((ll)u * (n - 1) + v) || Q.count((ll)v * (n - 1) + u)) ++cnt;
	} return pow(Y,n - cnt);
}
inline int calc1()
{
	if (Y == 1) return pow(n,n - 2);
	for (int i = 1 ; i < n ; ++i)
	{
		int u,v; scanf("%d%d",&u,&v);
		add(u,v);
	} X = (ll)n * Y % mo * pow(dec(1,Y),mo - 2) % mo,dfs(1,0);
	X = pow(n,mo - 2);
	return (ll)pow(dec(1,Y),n) * X % mo * X % mo * f[1] % mo;
}
inline int calc2()
{
	if (Y == 1) return pow(n,2 * n - 4);
	X = (ll)n * n % mo * Y % mo * pow(dec(1,Y),mo - 2) % mo;
	int zs = 1; for (; zs <= n ; zs <<= 1); inv[1] = 1;
	for (int i = 2 ; i < zs ; ++i) inv[i] = (ll)(mo - mo / i) * inv[mo % i] % mo;
	int c = 1;
	for (int i = 1 ; i < zs ; ++i,c = (ll)c * i % mo)
		z[i] = (ll)pow(i,i) * pow(c,mo - 2) % mo * X % mo;
	Gexp(z,ans,zs); X = (ll)pow(dec(1,Y),n) * pow(inv[n],4) % mo;
	c = 1; for (int i = 2 ; i <= n ; ++i) c = (ll)c * i % mo;
	return (ll)ans[n] * X % mo * c % mo;
}
int main()
{
	scanf("%d%d%d",&n,&Y,&op);
	if (!op) printf("%d\n",calc());
	else if (op == 1) printf("%d\n",calc1());
	else printf("%d\n",calc2());
	return 0;
}