tozangezan's diary

勝手にソースコードをコピペして利用しないでください。

AOJ 2377: ThreeRooks

包除原理で数える。
answer = (3つおく全ての方法の数)-(2個同じ線上にある方法の数)+(3つがL字に置かれる方法の数)+(3つが同じ線上にある方法の数)*2

あとはN<100000であることを使ってデータ構造系でがんばって数える。特にL字に置かれる方法の数を数えるのがややこしいが、JOI2011合宿のDragonみたいなジャンルで、難度もいい勝負。JOIの過去問で日ごろから慣れてればまあ解けはする。

オーバーフローに気をつけましょう。

#include<stdio.h>
#include<algorithm>
#include<set>
#include<vector>
using namespace std;
long long mod=1000000007;
long long os=166666668;
long long ot=500000004;
int p[110000];
int q[110000];
int xz[110000];
int yz[110000];
vector<int>Y[110000];
vector<int>X[110000];
int next[110000];
long long segtree[262144];
long long query(int a,int b,int c,int d,int e){
	if(c>d)return 0;
	if(d<a||b<c)return 0;
	if(c<=a&&b<=d)return segtree[e];
	return query(a,(a+b)/2,c,d,e*2)+query((a+b)/2+1,b,c,d,e*2+1);
}
void update(int a,int b){
	a+=131072;
	segtree[a]=b;
	a/=2;
	while(a){
		segtree[a]=segtree[a*2]+segtree[a*2+1];
		a/=2;
	}
}
int sum[110000];
int main(){
	int a,b,n;
	scanf("%d%d%d",&a,&b,&n);
	if((long long)a*b-n<3){
		printf("0\n");return 0;
	}
	long long ms=((long long)a*b-n)%mod;
	for(int i=0;i<n;i++){
		scanf("%d%d",p+i,q+i);
		xz[i]=p[i];yz[i]=q[i];
	}
	std::sort(xz,xz+n);
	std::sort(yz,yz+n);
	for(int i=0;i<n;i++){
		Y[lower_bound(xz,xz+n,p[i])-xz].push_back(q[i]);
		X[lower_bound(yz,yz+n,q[i])-yz].push_back(p[i]);
	}
	for(int i=0;i<n;i++){
		std::sort(X[i].begin(),X[i].end());
		std::sort(Y[i].begin(),Y[i].end());
	}
	int xs=0;int ys=0;
	for(int i=0;i<n;i++){
		if(i==0||xz[i]!=xz[i-1]){
			sum[i+1]=1;
			xs++;
		}
		if(i==0||yz[i]!=yz[i-1])ys++;
	}
	for(int i=1;i<=n;i++)sum[i]+=sum[i-1];
	
	long long ret=ms*(ms-1)%mod*(ms-2)%mod*os%mod;
	//printf("%lld\n",ret);
	long long nega=0;
	nega=(nega+(long long)(b-ys)*a%mod*(a-1)%mod*ot%mod*(ms-2))%mod;
	nega=(nega+(long long)(a-xs)*b%mod*(b-1)%mod*ot%mod*(ms-2))%mod;
	ret=(ret+(long long)(b-ys)*a%mod*(a-1)%mod*(a-2)%mod*os*2)%mod;
	ret=(ret+(long long)(a-xs)*b%mod*(b-1)%mod*(b-2)%mod*os*2)%mod;
	for(int i=0;i<n;i++){
		if(i&&xz[i]==xz[i-1])continue;
		int at=0;
		for(int j=0;j<Y[i].size();j++){
			int to=Y[i][j];
			if(to-at>1){
				long long len=to-at;
				nega=(nega+len*(len-1)%mod*ot%mod*(ms-2))%mod;
			}
			if(to-at>2){
				long long len=to-at;
				ret=(ret+len*(len-1)%mod*(len-2)%mod*os%mod*2)%mod;
			}
			at=to+1;
		}
		if(at<b-1){
			long long len=b-at;
			nega=(nega+len*(len-1)%mod*ot%mod*(ms-2))%mod;
		}
		if(at<b-2){
			long long len=b-at;
			ret=(ret+len*(len-1)%mod*(len-2)%mod*os%mod*2)%mod;
		}
	}
	for(int i=0;i<n;i++){
		if(i&&yz[i]==yz[i-1])continue;
		int at=0;
		for(int j=0;j<X[i].size();j++){
			int to=X[i][j];
			if(to-at>1){
				long long len=to-at;
				nega=(nega+len*(len-1)%mod*ot%mod*(ms-2))%mod;
			}
			if(to-at>2){
				long long len=to-at;
				ret=(ret+len*(len-1)%mod*(len-2)%mod*os%mod*2)%mod;
			}
			at=to+1;
		}
		if(at<a-1){
			long long len=a-at;
			nega=(nega+len*(len-1)%mod*ot%mod*(ms-2))%mod;
		}
		if(at<a-2){
			long long len=a-at;
			ret=(ret+len*(len-1)%mod*(len-2)%mod*os%mod*2)%mod;
		}
	}
	
	for(int i=0;i<n;i++){
		Y[i].push_back(b);
		if(i&&xz[i]==xz[i-1])continue;
		update(i,Y[i][0]-1);
	}
	int row=0;
	for(int i=0;i<n;i++){
		if(i&&yz[i]==yz[i-1])continue;
		long long hb=yz[i]-row;
		row=yz[i]+1;
		ret=(ret+hb*(a-1)%mod*((query(0,131071,0,n-1,1)%mod+(long long)(b-1)*(a-xs)%mod)%mod))%mod;
		int col=0;
		int xf=0;
		for(int j=0;j<X[i].size();j++){
			int xi=lower_bound(xz,xz+n,X[i][j])-xz;
			if(X[i][j]-col>=2){
				ret=(ret+(long long)(X[i][j]-col-1)*(query(0,131071,xf,xi-1,1)%mod+(long long)(b-1)*(X[i][j]-col-(sum[xi]-sum[xf]))%mod))%mod;
			}
			col=X[i][j]+1;
			xf=xi+1;
			next[xi]++;
			update(xi,Y[xi][next[xi]]-Y[xi][next[xi]-1]-2);
		}
		if(a-col>=2){
			ret=(ret+(long long)(a-col-1)*(query(0,131071,xf,n-1,1)%mod+(long long)(b-1)*(a-col-(sum[n]-sum[xf]))%mod))%mod;
		}
	}
	//printf("%lld %lld\n",ret,nega);
	ret=(ret+(long long)(b-row)*(a-1)%mod*(query(0,131071,0,n-1,1)%mod+(long long)(b-1)*(a-xs)%mod)%mod)%mod;
	ret=(ret+mod-nega)%mod;
	printf("%lld\n",ret);
}