jzoj 6077.【GDOI2019模拟2019.3.22】K 君的游戏 分治ntt

Description
jzoj 6077.【GDOI2019模拟2019.3.22】K 君的游戏 分治ntt
Input
jzoj 6077.【GDOI2019模拟2019.3.22】K 君的游戏 分治ntt
Output
jzoj 6077.【GDOI2019模拟2019.3.22】K 君的游戏 分治ntt
Sample Input

4
1
2
3
4

Sample Output

0
1
499122177
831870295

Data Constraint
jzoj 6077.【GDOI2019模拟2019.3.22】K 君的游戏 分治ntt
分析:
考虑f[i]f[i]表示ii个节点胜的概率,显然1f[i]1-f[i]是输的概率。
考虑大小为ii的树是通过一棵大小为jj的树增加一棵大小为iji-j的子树构成。
考虑当前输的概率,显然就是原本输,增加的子树是赢的(从根走过去就输了),那么
1f[i]=j=1i1f[j](1f[ij])1-f[i]=\sum_{j=1}^{i-1}f[j]*(1-f[i-j])
也就是f[i]=1j=1i1f[j](1f[ij])f[i]=1-\sum_{j=1}^{i-1}f[j]*(1-f[i-j])
可以分治ntt解决,但是分治ntt保证有一个数组是已知的,但是都未知要怎么做呢?
考虑当前分治区间为[l,r][l,r],如果rl>lr-l>l,说明有些数是未知的,直接把[l,mid][l,mid][1,mid][1,mid]卷起来。
否则把[l,mid][l,mid][1,rl][1,r-l]卷起。此时要注意,对于一对x+yx+y这对转移和y+xy+x这对转移是都没有算的。因为其中大的数在小的数的区间未知。

代码:

#include <iostream>
#include <cstdio>
#include <cmath>
#define LL long long

const int maxn=3e5+7;
const LL mod=998244353;
const LL G=3;

using namespace std;

int T,n,p,len;
LL f[maxn],a[maxn],b[maxn],x[maxn],y[maxn],r[maxn],w[maxn],inv[maxn];

LL ksm(LL x,LL y)
{
	if (y==0) return 1;
	LL c=ksm(x,y/2);
	c=c*c%mod;
	if (y&1) c=c*x%mod;
	return c;
}

void ntt(LL *a,int f)
{
    for (int i=0;i<len;i++)
    {
        if (i<r[i]) swap(a[i],a[r[i]]);
    }
    w[0]=1;
    for (int i=2;i<=len;i<<=1)
    {
        LL wn;
        if (f==1) wn=ksm(G,(mod-1)/i);
             else wn=ksm(G,(mod-1)-(mod-1)/i);
        for (int j=i/2-2;j>=0;j-=2) w[j]=w[j/2];
        for (int j=1;j<i/2;j+=2) w[j]=(w[j-1]*wn)%mod;
        for (int j=0;j<len;j+=i)
        {
            for (int k=0;k<i/2;k++)
            {
                LL u=a[j+k],v=a[j+k+i/2]*w[k]%mod;
                a[j+k]=(u+v)%mod;
                a[j+k+i/2]=(u+mod-v)%mod;
            }
        }
    }
    if (f==-1)
    {
        LL inv=ksm(len,mod-2);
        for (int i=0;i<len;i++) a[i]=a[i]*inv%mod;
    }
}

void NTT(LL *a,LL *b,LL *c,LL n,LL m)
{   
    len=1;
    int k=0;
    while (len<=(n+m)) len*=2,k++;
    for (int i=0;i<len;i++)
    {
    	r[i]=(r[i>>1]>>1)|((i&1)<<(k-1));
    }
    for (int i=0;i<len;i++)
    {
        if (i<n) x[i]=a[i]; else x[i]=0;
        if (i<m) y[i]=b[i]; else y[i]=0;
    }   
    ntt(x,1); ntt(y,1);
    for (LL i=0;i<len;i++) c[i]=x[i]*y[i]%mod;
    ntt(c,-1);
}

void solve(int l,int r)
{
    if (l==r)
    {
    	if (l==1) f[l]=0;
		     else f[l]=(1+mod-f[l])%mod;
    	return;
    }
    int mid=(l+r)/2;
    solve(l,mid);
    if (r-l<l)
    {
    	for (int i=l;i<=mid;i++) a[i-l]=f[i];
        for (int i=1;i<=r-l;i++) b[i-1]=(1+mod-f[i])%mod;
        int n=mid-l+1,m=r-l;
        NTT(a,b,b,n,m);
        for (int i=mid+1;i<=r;i++) f[i]=(f[i]+inv[i-1]*b[i-l-1]%mod)%mod;
        
        for (int i=l;i<=mid;i++) a[i-l]=(1+mod-f[i])%mod;
        for (int i=1;i<=r-l;i++) b[i-1]=f[i];
        n=mid-l+1,m=r-l;
        NTT(a,b,b,n,m);
        for (int i=mid+1;i<=r;i++) f[i]=(f[i]+inv[i-1]*b[i-l-1]%mod)%mod;
    }
    else
    {
    	for (int i=l;i<=mid;i++) a[i-l]=f[i];
        for (int i=1;i<=mid;i++) b[i-1]=(1+mod-f[i])%mod;
        int n=mid-l+1,m=mid;
        NTT(a,b,b,n,m);
        for (int i=mid+1;i<=r;i++) f[i]=(f[i]+inv[i-1]*b[i-l-1]%mod)%mod;
    }
    solve(mid+1,r);
}

int main()
{
	freopen("game.in","r",stdin);
	freopen("game.out","w",stdout);
	scanf("%d",&T);
	n=8e4;
	for (int i=1;i<=n;i++) inv[i]=ksm(i,mod-2);		
	solve(1,n);	
	for (int i=1;i<=T;i++)
	{
		scanf("%d",&p);
		printf("%lld\n",f[p]);
	} 
}