Codeforces 755 G. PolandBall and Many Other Balls -开发者知识库

Codeforces 755 G. PolandBall and Many Other Balls -开发者知识库,第1张

Description

\(n\)個球,每組一個或者相鄰的兩個,求分成\(k\)組的方案數。\(n\leqslant 10^9,k<2^{15}\)

Solution

DP FNT.

轉移\(f[i][j]=f[i-1][j] f[i-1][j-1] f[i-2][j-1]\)

這個不是很好維護...可以看成多項式來做...可惜我也不太會...

還有一個轉移就是折半來做,前一段為\(x\),后一段為\(y\),\(x y=i\)。

若恰好在\(x,y\)之間可以分開那么方案數就是

\(f[i][j]=\sum_{a=0}^{k}\sum_{b=0}^{k}[a b=j]f[x][a]\times f[y][b]\)

如果不能那么就是

\(f[i][j]=\sum_{a=0}^{k}\sum_{b=0}^{k}[a b=j-1]f[x-1][a]\times f[y-1][b]\)

然后就可以倍增了...這個轉移可以用FNT優化...

我寫的常數巨大...卡了一晚上常數 = =。最后吧合並展開了少了幾次DFT的操作...

Code

#include <bits/stdc  .h>
using namespace std;

#define debug(a) cout<<(#a)<<"="<<a<<" "
//#define lc(o) ch[o][0]
//#define rc(o) ch[o][1]
#define lc (o<<1)
#define rc (o<<1|1)
#define mid ((l r)>>1)

typedef long long LL;
typedef pair<int,int> pr;
typedef vector<int> vi;
typedef vector<LL> vl;
typedef vector<string> vs;
const int N = 1<<17;
const int M = 32;
const int oo = 0x3f3f3f3f;
const LL  OO = 1e18;

const int p = 998244353;
LL Pow(LL a,LL b,LL r=1) { for(;b;b>>=1,a=a*a%p) if(b&1) r=r*a%p;return r; }
LL Pow(LL a,LL b,LL p,LL r=1) { for(;b;b>>=1,a=a*a%p) if(b&1) r=r*a%p;return r; }
LL inv(LL x) { return Pow(x,p-2); }
void Add(int &x,LL y) { x=(x y%p)%p; }
void Sub(int &x,LL y) { x=(x-y%p p)%p; }
void Mul(int &x,LL y) { x=x*(y%p)%p; }
int chkmax(LL &x,LL y) { return x<y?x=y,1:0; }
int chkmin(LL &x,LL y) { return x>y?x=y,1:0; }

inline LL in(LL x=0,char ch=getchar(),int v=1) {
	while(ch>'9' || ch<'0') v=ch=='-'?-1:v,ch=getchar();
	while(ch>='0' && ch<='9') x=x*10 ch-'0',ch=getchar();
	return x*v;
}
/*end*/

namespace Pol {
	const int g = 3;
	int pn = 1<<15;
	int nn = pn<<1;
	
	int rev[N];
	int w[M][N];
	
	void init(int n) { for(pn=1;pn<n;pn<<=1);nn=pn<<1; }
	void pre(int n=nn) {
		for(int i=0,j=0;i<n;i  ) {
			rev[i]=j;
			for(int k=n>>1;(j^=k)<k;k>>=1);
		}
		for(int i=1,l=0;i<=n;i<<=1,  l){
			w[l][0]=Pow(g,(p-1)/i);
			w[l][1]=Pow(w[l][0],p-2);
		}
	}
	void Rev(int a[],int n=nn) {
		for(int i=0;i<n;i  ) if(i>rev[i]) swap(a[i],a[rev[i]]);
	}
	void DFT(int a[],int r=1,int n=nn) {
		Rev(a);
		for(int i=2,l=1;i<=n;i<<=1,l  ) {
			for(int j=0;j<n;j =i) {
				int wn=1,wi=w[l][(r==-1)];
				for(int k=j;k<j i/2;k  ) {
					int t1=a[k],t2=1LL*wn*a[k i/2]%p;
					a[k]=(t1 t2)%p,a[k i/2]=(t1-t2 p)%p;
					wn=1LL*wn*wi%p;
				}
			}
		}if(~r) return;
		int inv=Pow(n,p-2);
		for(int i=0;i<n;i  ) a[i]=1LL*a[i]*inv%p;
	}
	void FNT(int a[],int b[],int c[],int n=nn) {
		DFT(a,1),DFT(b,1);
		for(int i=0;i<n;i  ) c[i]=1LL*a[i]*b[i]%p;
		DFT(c,-1);
	}
}


LL n,k;
int f[M][3][N],g[M][3][N];
int t4[N],t5[N];
int ans[2][3][N];

inline void get_2(int f[3][N]) {
	f[2][0]=1;
	for(int i=1;i<Pol::pn;i  ) f[2][i]=(0LL f[1][i] f[1][i-1] f[0][i-1])%p;
}
/*
inline void merge_p(int a[],int b[],int c[],int t) {
	memset(t1,0,sizeof(t1)),memset(t2,0,sizeof(t2));
	for(int i=0;i<Pol::pn;i  ) t1[i]=a[i],t2[i]=b[i];
	Pol::FNT(t1,t2,t3);
	for(int i=0;i<Pol::pn;i  ) if(i-t>=0) c[i]=(c[i] t3[i-t])%p;
}*/


int main() {
	ios::sync_with_stdio(false);
//	cout<<(sizeof(f) sizeof(ans) sizeof(t1)*3)/1024.0/1024.0<<endl;
	cin>>n>>k;
	Pol::init(max(8LL,k));
	Pol::pre();
	
{
	f[0][0][0]=0;
	f[0][1][0]=1;
	get_2(f[0]);
	int i=0;
	int *t1=g[i][0],*t2=g[i][1],*t3=g[i][2];
	for(int j=0;j<Pol::pn;j  )
		t1[j]=f[i][0][j],t2[j]=f[i][1][j],t3[j]=f[i][2][j];
	Pol::DFT(t1,1),Pol::DFT(t2,1),Pol::DFT(t3,1);
}	
	for(int i=1;(1LL<<i)<=n;i  ) {
		int *t1=g[i-1][0],*t2=g[i-1][1],*t3=g[i-1][2];
		for(int j=0;j<Pol::nn;j  ) {
			t4[j]=1LL*t1[j]*t1[j]%p;
			t5[j]=1LL*t1[j]*t2[j]%p;
			f[i][0][j]=1LL*t2[j]*t2[j]%p;
			f[i][1][j]=1LL*t2[j]*t3[j]%p;
		}
		Pol::DFT(f[i][0],-1);
		Pol::DFT(f[i][1],-1);
		Pol::DFT(t4,-1);
		Pol::DFT(t5,-1);
		for(int j=1;j<Pol::pn;j  ) {
			Add(f[i][0][j],t4[j-1]);
			Add(f[i][1][j],t5[j-1]);
		}
		for(int j=Pol::pn;j<Pol::nn;j  ) f[i][0][j]=f[i][1][j]=0;
		get_2(f[i]);
		t1=g[i][0],t2=g[i][1],t3=g[i][2];
		for(int j=0;j<Pol::pn;j  )
			t1[j]=f[i][0][j],t2[j]=f[i][1][j],t3[j]=f[i][2][j];
		Pol::DFT(t1,1),Pol::DFT(t2,1),Pol::DFT(t3,1);
	}
	int cur=0,fst=0;
	for(int i=0;i<M;i  ) if((n>>i)&1) {
		cur^=1;
		if(!fst) {
			for(int j=0;j<Pol::pn;j  )
	ans[cur][0][j]=f[i][0][j],ans[cur][1][j]=f[i][1][j],ans[cur][2][j]=f[i][2][j];
			fst=1;continue;
		}
		memset(ans[cur],0,sizeof(ans[cur]));
		int *t1=ans[cur^1][0],*t2=ans[cur^1][1],*t3=ans[cur^1][2];
		Pol::DFT(t1,1),Pol::DFT(t2,1),Pol::DFT(t3,1);
		int *ta=g[i][0],*tb=g[i][1],*tc=g[i][2];
		for(int j=0;j<Pol::nn;j  ) {
			t4[j]=1LL*t1[j]*ta[j]%p;
			t5[j]=1LL*t2[j]*ta[j]%p;
			ans[cur][0][j]=1LL*t2[j]*tb[j]%p;
			ans[cur][1][j]=1LL*t3[j]*tb[j]%p;
		}
		Pol::DFT(ans[cur][0],-1);
		Pol::DFT(ans[cur][1],-1);
		Pol::DFT(t4,-1);
		Pol::DFT(t5,-1);
		for(int j=1;j<Pol::pn;j  ) {
			Add(ans[cur][0][j],t4[j-1]);
			Add(ans[cur][1][j],t5[j-1]);
		}
		for(int j=Pol::pn;j<Pol::nn;j  ) ans[cur][0][j]=ans[cur][1][j]=0;
//		merge_p(ans[cur^1][0],f[i][0],ans[cur][0],1);
//		merge_p(ans[cur^1][1],f[i][1],ans[cur][0],0);
//		merge_p(ans[cur^1][1],f[i][0],ans[cur][1],1);
//		merge_p(ans[cur^1][2],f[i][1],ans[cur][1],0);
		get_2(ans[cur]);
	}
//	cout<<ans[cur][2][k]<<endl;
	for(int i=1;i<=k;i  ) cout<<ans[cur][2][i]<<" ";cout<<endl;
	return 0;
}

最佳答案:

本文经用户投稿或网站收集转载,如有侵权请联系本站。

发表评论

0条回复