tozangezan's diary

勝手にソースコードをコピペして利用しないでください。

AOJ 2644: Longest Match

Suffix Array+Segment Tree。
Suffix Arrayでlower_boundってどうするんだっけとか一瞬思ってしまった。

#include<stdio.h>
#include<algorithm>
#include<string.h>
using namespace std;
char str[210000];
char in[210000];
int q[910000];
int n;
int sa_k;
int rank[910000];
int tmp[910000];
int sa[910000];
bool compare_sa(int i,int j){
	if(rank[i]!=rank[j])return rank[i]<rank[j];
	else{
		int ri=i+sa_k<=n?rank[i+sa_k]:-1;
		int rj=j+sa_k<=n?rank[j+sa_k]:-1;
		return ri<rj;
	}
}
void construct_sa(){
	for(int i=0;i<=n;i++){
		sa[i]=i;
		rank[i]=i<n?str[i]:-1;
	}
	for(sa_k=1;sa_k<=n;sa_k*=2){
		sort(sa,sa+n+1,compare_sa);
		tmp[sa[0]]=0;
		for(int i=1;i<=n;i++){
			tmp[sa[i]]=tmp[sa[i-1]]+(compare_sa(sa[i-1],sa[i])?1:0);
		}
		for(int i=0;i<=n;i++){
			rank[i]=tmp[i];
		}
	}
}
int K;
int segtree[2][524288];
void update(int a,int b){
	a+=262144;
	while(a){
		segtree[0][a]=min(segtree[0][a],b);
		segtree[1][a]=max(segtree[1][a],b);
		a/=2;
	}
}
int lb(){
	int left=0;
	int right=n+1;
	while(left+1<right){
		int M=(left+right)/2;
		bool ok=true;
		int at=sa[M];
		while(at<n){
			if(at-sa[M]>=K){
				ok=false;break;
			}
			if(str[at]!=in[at-sa[M]]){
				if(str[at]>in[at-sa[M]])ok=false;
				break;
			}
			at++;
		}
		if(ok)left=M;
		else right=M;
	}
	return right;
}
int query(int a,int b,int c,int d,int e,int f){
	if(d<a||b<c||c>d){
		if(f)return -1;
		else return 999999999;
	}
	if(c<=a&&b<=d){
		return segtree[f][e];
	}
	if(f)return max(query(a,(a+b)/2,c,d,e*2,f),query((a+b)/2+1,b,c,d,e*2+1,f));
	else return min(query(a,(a+b)/2,c,d,e*2,f),query((a+b)/2+1,b,c,d,e*2+1,f));
}
int main(){
	scanf("%s",in);
	n=strlen(in);
	for(int i=0;in[i];i++)str[i]=in[i];
	str[n++]='~';
	construct_sa();
	int a;scanf("%d",&a);
	for(int i=0;i<524288;i++){
		segtree[0][i]=999999999;
		segtree[1][i]=-999999999;
	}
	for(int i=0;i<=n;i++){
		update(i,sa[i]);
	}
	for(int i=0;i<a;i++){
		scanf("%s",in);
		int len=strlen(in);
		int hz=len;
		K=len;
		int L1=lb();
		in[len-1]++;
		int R1=lb();
		scanf("%s",in);
		len=strlen(in);
		K=len;
		int L2=lb();
		in[len-1]++;
		int R2=lb();
		int lm=query(0,262143,L1,R1-1,1,0);
		int rm=query(0,262143,L2,R2-1,1,1);
		if(lm>rm||lm+hz>rm+len)printf("0\n");
		else printf("%d\n",rm-lm+len);
	}
}