AOJ 1293: Common Polynomial

構文解析+多項式ユークリッドの互除法

美しすぎるソースコード

#include<stdio.h>
#include<algorithm>
using namespace std;
long long ABS(long long a){return max(a,-a);}
long long gcd(long long a,long long b){
	a=ABS(a);b=ABS(b);
	while(a){
		b%=a;
		swap(a,b);
	}
	return b;
}
long long lcm(long long a,long long b){
	a=ABS(a);b=ABS(b);
	return a/gcd(a,b)*b;
}
struct Q{
	long long a,b;
	Q(){}
	Q(long long bunshi,long long bunbo){a=bunshi;b=bunbo;}
	Q operator+(Q c)const{
		long long s=lcm(b,c.b);
		return Q(s/b*a+s/c.b*c.a,s);
	}
	Q operator-(Q c)const{
		long long s=lcm(b,c.b);
		return Q(s/b*a-s/c.b*c.a,s);
	}
	Q operator*(Q c)const{
		return Q(a*c.a,b*c.b);
	}
};
struct P{
	Q p[10];
	P(){
		for(int i=0;i<10;i++)p[i]=Q(0,1);
	}
	P operator+(P a)const{
		P ret=P();
		for(int i=0;i<10;i++)ret.p[i]=p[i]+a.p[i];
		return ret;
	}
	P operator-(P a)const{
		P ret=P();
		for(int i=0;i<10;i++)ret.p[i]=p[i]-a.p[i];
		return ret;
	}
	P operator*(P a)const{
		P ret=P();
		for(int i=0;i<10;i++)
			for(int j=0;i+j<10;j++)
				ret.p[i+j]=ret.p[i+j]+p[i]*a.p[j];
		return ret;
	}
	int deg()const{
		for(int i=9;i>=0;i--)if(p[i].a!=0LL)return i;
		return -1;
	}
};
char in[2][110000];
int cur;
int ind;
P expr();
long long numb(){
	long long ret=0;
	while('0'<=in[ind][cur]&&in[ind][cur]<='9'){
		ret*=10;
		ret+=in[ind][cur]-'0';
		cur++;
	}
	//printf("%d %lld\n",cur,ret);
	return ret;
}
P term(){
	P ret=P();
	if(in[ind][cur]=='('){
		cur++;
		ret=expr();
		cur++;
	}else if(in[ind][cur]=='x'){
		ret.p[1]=Q(1,1);
		cur++;
	}else{
		ret.p[0]=Q(numb(),1);
	}
	if(in[ind][cur]=='^'){
		cur++;
		long long r=numb();
		P tmp=P();
		tmp.p[0]=Q(1,1);
		for(int i=0;i<r;i++)tmp=tmp*ret;
		ret=tmp;
	}
//	for(int i=0;i<10;i++)if(ret.p[i].a)printf("[%lld/%lld)x^%d ",ret.p[i].a,ret.p[i].b,i);
//	printf("\n");

	return ret;
}
P fact(){
	P ret=term();
//	printf(":%d %d %c\n",ind,cur,in[ind][cur]);
	while(in[ind][cur]&&in[ind][cur]!='+'&&in[ind][cur]!='-'&&in[ind][cur]!=')'){
		P tmp=term();
		ret=ret*tmp;
	//	printf("%d %d %c\n",ind,cur,in[ind][cur]);
	}
	return ret;
}
P expr(){
	P ret;
	if(in[ind][cur]=='-'){
		cur++;
		ret=P()-fact();
	}else ret=fact();
	while(in[ind][cur]=='+'||in[ind][cur]=='-'){
		char ch=in[ind][cur];
		cur++;
		P tmp=fact();
		if(ch=='+')ret=ret+tmp;
		else ret=ret-tmp;
	}
	return ret;
}
long long ans[10];
int main(){
	while(1){
		scanf("%s",in[0]);
		if(in[0][0]=='.')return 0;
		scanf("%s",in[1]);
		cur=ind=0;
		P s=expr();
		cur=0;ind=1;
		P t=expr();
//		for(int i=0;i<10;i++)if(s.p[i].a)printf("(%lld/%lld)x^%d ",s.p[i].a,s.p[i].b,i);
//		printf("\n");
//		for(int i=0;i<10;i++)if(t.p[i].a)printf("(%lld/%lld)x^%d ",t.p[i].a,t.p[i].b,i);
//		printf("\n");
		
		if(s.deg()<t.deg())swap(s,t);
		while(t.deg()!=-1){
			int D=t.deg();
			Q tmp=t.p[D];
			swap(tmp.a,tmp.b);
			for(int i=0;i<10;i++){
				t.p[i]=t.p[i]*tmp;
				long long tg=gcd(t.p[i].a,t.p[i].b);
				t.p[i].a/=tg;
				t.p[i].b/=tg;
			}
			int E=s.deg();
			for(int i=E-D;i>=0;i--){
				Q ks=s.p[i+D];
				for(int j=0;j<=D;j++){
					s.p[i+j]=s.p[i+j]-ks*t.p[j];
				}
			}
			for(int i=0;i<10;i++){
				long long tg=gcd(s.p[i].a,s.p[i].b);
				s.p[i].a/=tg;
				s.p[i].b/=tg;
			}
			swap(s,t);
		}
		long long lv=1;
		for(int i=0;i<10;i++)lv=lcm(lv,s.p[i].b);
		for(int i=0;i<10;i++)ans[i]=lv/s.p[i].b*s.p[i].a;
		long long gv=0;
		for(int i=0;i<10;i++)gv=gcd(gv,ans[i]);
		for(int i=0;i<10;i++)ans[i]/=gv;
		if(ans[s.deg()]<0)for(int i=0;i<10;i++)ans[i]=-ans[i];
		bool fi=true;
		for(int i=9;i>=0;i--){
			if(ans[i]==0)continue;
			if(!fi&&ans[i]>0)printf("+");
			if(ans[i]<0)printf("-");
			fi=false;
			if(ABS(ans[i])>1||i==0)printf("%lld",ABS(ans[i]));
			if(i)printf("x");
			if(i>1)printf("^%d",i);
		}
		printf("\n");
	}
}