読者です 読者をやめる 読者になる 読者になる

AOJ 2450: Do use segment tree

Heavy-Light Decomposition + 遅延評価 Segment Tree (しかも、区間和、左端からの和の最大値、右端からの和の最大値、区間の中での連続する和の最大値の4種類を持たないといけない)。
めちゃくちゃ重い実装だが、特に大きなバグもなく通すことができた。Very good.
 

#include<stdio.h>
#include<algorithm>
#include<vector>
#include<queue>
// Runs the code between BEGIN_STACK_EXTEND(size) and END_STACK_EXTEND on a
// heap-allocated stack of `size` bytes, so deep recursion (dfs) cannot overflow
// the default process stack. GCC/Clang inline asm, x86-64 only (moves %rsp).
// BEGIN: mallocs the new stack, saves the current %rsp (through %rbx) into
// stack_extend_origin_memory_, then points %rsp 1024 bytes below the end of the
// malloc'd block. The small alloca'd dummy buffer is presumably there so the
// enclosing function keeps a frame pointer and does not address its locals
// relative to %rsp -- NOTE(review): confirm.
// Both macros must be used in the same scope, in this order, with no return in
// between; otherwise %rsp is never restored and the block is never freed.
// NOTE(review): the malloc result is not checked, and malloc/free/alloca rely
// on headers this file only includes transitively (<stdlib.h>/<alloca.h>).
#define BEGIN_STACK_EXTEND(size) void * stack_extend_memory_ = malloc(size);void * stack_extend_origin_memory_;char * stack_extend_dummy_memory_ = (char*)alloca((1+(int)(((long long)stack_extend_memory_)&127))*16);*stack_extend_dummy_memory_ = 0;asm volatile("mov %%rsp, %%rbx\nmov %%rax, %%rsp":"=b"(stack_extend_origin_memory_):"a"((char*)stack_extend_memory_+(size)-1024));
// Restores the %rsp saved by BEGIN_STACK_EXTEND and frees the heap stack.
#define END_STACK_EXTEND asm volatile("mov %%rax, %%rsp"::"a"(stack_extend_origin_memory_));free(stack_extend_memory_);
using namespace std;
// Capacity of every per-vertex array below (the tree may have at most MAXV vertices).
const int MAXV = 210000;
int c[MAXV];              // weight of each vertex, read in main()
int sz[MAXV];             // subtree size, filled by dfs()
int eu[410000];           // Euler tour (vertex on entry and after each child): 2*N-1 entries
int conv[MAXV];           // preorder index of each vertex
int rev[MAXV];            // inverse of conv: rev[conv[v]] == v
std::vector<int> g[MAXV]; // adjacency lists of the tree
int L[MAXV];              // first position of each vertex in eu
int par[MAXV];            // parent in the DFS tree (-1 for the root)
int str[MAXV];            // heavy child chosen by dfs(), or -1
int ww[MAXV];             // index in hl of the chain holding the vertex
int wn[MAXV];             // position of the vertex inside its chain
int cur;                  // next preorder index handed out by dfs()
int cur2;                 // next free slot in eu
// Recursive DFS from vertex `a`, whose parent is `b` (-1 for the root).
// For every vertex v in a's subtree it fills:
//   par[v]  - parent in the rooted tree
//   conv[v] - preorder index (rev[] is the inverse mapping)
//   L[v]    - first position of v in the Euler tour eu[]
//   sz[v]   - subtree size
//   str[v]  - heavy child: the largest child, kept only if it holds at least
//             half of v's subtree; -1 otherwise (always -1 for leaves)
// The tour appends `a` on entry and again after each child returns, so an
// N-vertex tree produces 2*N-1 entries. Recursion depth equals tree depth.
void dfs(int a,int b){
	par[a]=b;
	conv[a]=cur;
	rev[cur]=a;
	cur++;
	L[a]=cur2;
	eu[cur2++]=a;
	sz[a]=1;
	str[a]=-1;
	for(size_t i=0;i<g[a].size();i++){
		const int v=g[a][i];
		if(v==b)continue;
		dfs(v,a);
		eu[cur2++]=a;
		sz[a]+=sz[v];
		if(str[a]==-1||sz[str[a]]<sz[v]){
			str[a]=v;
		}
	}
	// Check str[a] first: a leaf has str[a]==-1, and sz[-1] would be an
	// out-of-bounds read.
	if(str[a]!=-1&&sz[str[a]]*2<sz[a])str[a]=-1;
}
// One heavy-light chain plus the four lazy "assign" segment trees built on it.
// segtree[t] / lazy[t] meaning by index t:
//   0: sum of the range
//   1: maximum prefix sum ("left-max")
//   2: maximum suffix sum ("right-max")
//   3: maximum sum of a non-empty contiguous subrange
// Trees are 1-indexed (children of e are 2e and 2e+1) over n leaves;
// lazy[t][e] == mod means node e has no pending assignment.
struct wolf{
	std::vector<int> node;              // vertices of the chain, top-most first
	std::vector<long long> segtree[4];
	std::vector<long long> lazy[4];
	int par;   // chain (index into hl) holding the parent of node[0]; -1 for the root chain
	int at;    // position of that parent inside chain `par`; -1 for the root chain
	int n;     // leaf count: smallest power of two >= node.size() (set in main)
	wolf() : par(-1), at(-1) {}
};
// 1e9+7: the "no pending lazy" sentinel and the neutral value returned by lcamin.
int mod = 1000000007;
// (1e9+7)^2, roughly 1e18; -inf is the identity for the max-based aggregates.
long long inf = static_cast<long long>(mod) * mod;
// Every heavy-light chain; hl[0] is the chain that starts at the root (vertex 0).
std::vector<wolf> hl;
// Pushes the pending assignment stored at node e of tree `type` in chain
// `node` down to both children (2e and 2e+1), then clears it (sets it to mod).
// A child spans half the leaves of its parent, so it gets half the parent's
// value. The exception is the max-style trees (type != 0) under a negative
// assignment: there the best non-empty sum is the assigned value itself.
// a, b, c and d are unused.
void proc(int a,int b,int c,int d,int e,int node,int type){
	std::vector<long long>& pending = hl[node].lazy[type];
	std::vector<long long>& value = hl[node].segtree[type];
	const long long f = pending[e];
	const int lc = 2 * e;
	const int rc = 2 * e + 1;
	pending[lc] = f;
	pending[rc] = f;
	pending[e] = mod;
	const long long childValue = (type != 0 && f < 0) ? f : value[e] / 2;
	value[lc] = childValue;
	value[rc] = childValue;
}
// Assigns value f to positions [c, d] of tree `type` (0 sum, 1 max prefix,
// 2 max suffix, 3 max subrange) in chain `node`; e is the tree index of the
// node spanning [a, b]. Root call: update(0, hl[node].n - 1, lo, hi, 1, f, node, type).
// main() calls it for types 0, 1, 2, 3 in that order. The order matters:
// recombining types 1/2 reads the children's sums (type 0), and type 3 reads
// types 1 and 2.
void update(int a,int b,int c,int d,int e,long long f,int node,int type){
	if (b < c || d < a) return;  // no overlap with [c, d]
	std::vector<long long>& val = hl[node].segtree[type];
	if (c <= a && b <= d) {
		// Fully covered: remember the assignment and set this node's aggregate.
		// With f >= 0 the whole span is optimal. In the max-style trees a
		// negative f is best taken as one element.
		hl[node].lazy[type][e] = f;
		const long long span = b - a + 1;
		val[e] = (type != 0 && f < 0) ? f : f * span;
		return;
	}
	if (hl[node].lazy[type][e] != mod) proc(a, b, c, d, e, node, type);
	const int mid = (a + b) / 2;
	const int lc = 2 * e;
	const int rc = 2 * e + 1;
	update(a, mid, c, d, lc, f, node, type);
	update(mid + 1, b, c, d, rc, f, node, type);
	const std::vector<long long>& sum = hl[node].segtree[0];
	const std::vector<long long>& pre = hl[node].segtree[1];
	const std::vector<long long>& suf = hl[node].segtree[2];
	const std::vector<long long>& best = hl[node].segtree[3];
	switch (type) {
	case 0:
		val[e] = sum[lc] + sum[rc];
		break;
	case 1:
		val[e] = std::max(pre[lc], sum[lc] + pre[rc]);
		break;
	case 2:
		val[e] = std::max(suf[rc], sum[rc] + suf[lc]);
		break;
	default:
		val[e] = std::max(std::max(best[lc], best[rc]), suf[lc] + pre[rc]);
		break;
	}
}
vector<long long> query(int a,int b,int c,int d,int e,int node){
//	printf("%d %d %d %d %d %d\n",a,b,c,d,e,node);
	vector<long long>ret(4);
	if(d<a||b<c){
		ret[0]=0;ret[1]=ret[2]=ret[3]=-inf;return ret;
	}
	if(c<=a&&b<=d){
		for(int i=0;i<4;i++)ret[i]=hl[node].segtree[i][e];
		return ret;
	}
	for(int i=0;i<4;i++){
		if(hl[node].lazy[i][e]!=mod){
			proc(a,b,c,d,e,node,i);
		}
	}
	vector<long long> left=query(a,(a+b)/2,c,d,e*2,node);
	vector<long long> right=query((a+b)/2+1,b,c,d,e*2+1,node);
	ret[0]=left[0]+right[0];
	ret[1]=max(left[1],right[1]+left[0]);
	ret[2]=max(right[2],left[2]+right[0]);
	ret[3]=max(max(left[3],right[3]),left[2]+right[1]);
	return ret;
}
// Bottom-up min segment tree used for LCA over the Euler tour.
// Leaf for position p is index 524288 + p; the root is index 1 and index 0
// is unused. main() fills it with mod before any lcaupd call.
int lcaseg[1048576];
// Lowers leaf `a` and every ancestor up to the root to at most `b`.
// Declared void: nothing is returned, and an `int` function that falls off
// its end without a return statement is undefined behavior.
void lcaupd(int a,int b){
	for(int x=a+524288;x>0;x/=2){
		lcaseg[x]=std::min(lcaseg[x],b);
	}
}
// Minimum of lcaseg over positions [c, d]; e is the tree index of the node
// spanning [a, b]. Returns mod (the neutral value) for a disjoint span.
// Root call: lcamin(0, 524287, lo, hi, 1).
int lcamin(int a,int b,int c,int d,int e){
	const bool disjoint = (b < c) || (d < a);
	if (disjoint) return mod;
	const bool covered = (c <= a) && (b <= d);
	if (covered) return lcaseg[e];
	const int mid = (a + b) / 2;
	const int fromLeft = lcamin(a, mid, c, d, 2 * e);
	const int fromRight = lcamin(mid + 1, b, c, d, 2 * e + 1);
	return fromLeft < fromRight ? fromLeft : fromRight;
}
// Program entry: reads a weighted tree and answers path queries.
// Input: N Q, then N vertex weights, then N-1 edges (1-based vertex ids),
// then Q lines "p u v w" with 1-based u, v:
//   p == 1:    assign w to every vertex on the path between u and v;
//   otherwise: print the maximum sum of a non-empty run of consecutive
//              vertices along that path (w is read but not used).
// Approach: root the tree at vertex 0 and split it into heavy-light chains
// (hl), each backed by four lazy segment trees. LCAs come from an Euler tour
// plus a min segment tree (lcaseg).
int main(){
	// dfs() recursion depth equals the tree depth, which can reach N, so run
	// everything on a 250 MiB heap-allocated stack.
	BEGIN_STACK_EXTEND(250*1024*1024)
	int a,b;  // a: number of vertices, b: number of queries
	scanf("%d%d",&a,&b);
	for(int i=0;i<a;i++){
		scanf("%d",c+i);
	}
	// Read the N-1 undirected edges, converting ids to 0-based.
	for(int i=0;i<a-1;i++){
		int p,q;
		scanf("%d%d",&p,&q);
		p--;q--;
		g[p].push_back(q);
		g[q].push_back(p);
	}
	dfs(0,-1);
	// Build the heavy-light chains with a BFS from the root. Queue entries are
	// (vertex, index in hl of its chain). The heavy child str[at] is appended
	// to its parent's chain. Any other child starts a new chain that records
	// the parent's chain (par) and the parent's position in it (at).
	hl.push_back(wolf());
	hl[0].node.push_back(0);
	queue<pair<int,int> >Q;
	Q.push(make_pair(0,0));
	ww[0]=wn[0]=0;
	while(Q.size()){
		int at=Q.front().first;
		int seq=Q.front().second;
		Q.pop();
		for(int i=0;i<g[at].size();i++){
			if(g[at][i]==par[at])continue;
			if(str[at]==g[at][i]){
				ww[g[at][i]]=seq;
				wn[g[at][i]]=hl[seq].node.size();
				hl[seq].node.push_back(g[at][i]);
				Q.push(make_pair(g[at][i],seq));
			}else{
				Q.push(make_pair(g[at][i],hl.size()));
				ww[g[at][i]]=hl.size();
				wn[g[at][i]]=0;
				hl.push_back(wolf());
				hl[hl.size()-1].node.push_back(g[at][i]);
				hl[hl.size()-1].par=seq;
				hl[hl.size()-1].at=wn[at];
			}
		}
	}
	
	// Min segment tree over the Euler tour, storing preorder indices: the
	// minimum over eu[L[u]..L[v]] is the preorder index of LCA(u, v).
	for(int i=0;i<1048576;i++)lcaseg[i]=mod;
	for(int i=0;i<cur2;i++){
		lcaupd(i,conv[eu[i]]);
	}
	// Size each chain's trees to a power of two, then load the vertex weights
	// one position at a time into all four trees (types 0..3 in order).
	for(int i=0;i<hl.size();i++){
		hl[i].n=1;
		while(hl[i].n<hl[i].node.size())hl[i].n*=2;
		for(int j=0;j<4;j++){
			hl[i].segtree[j]=vector<long long>(hl[i].n*2);
			hl[i].lazy[j]=vector<long long>(hl[i].n*2);
		}
		for(int j=0;j<hl[i].node.size();j++){
			for(int k=0;k<4;k++){
				update(0,hl[i].n-1,j,j,1,c[hl[i].node[j]],i,k);
			}
		}
	}
	
	while(b--){
		int p,q,r,s;
		scanf("%d%d%d%d",&p,&q,&r,&s);
		q--;r--;
		// Order by Euler position so [L[q], L[r]] is a valid range. This can
		// reverse the path, which changes neither operation's result.
		if(L[q]>L[r])swap(q,r);
		int lca=rev[lcamin(0,524287,L[q],L[r],1)];
		int goal=ww[lca];
		// Path pieces as (chain, (from, to)) positions, listed in walk order.
		// ql: from q up to and including lca (positions shrink going up).
		// qr: from r up to, but excluding, lca.
		vector<pair<int,pair<int,int> > >ql;
		vector<pair<int,pair<int,int> > >qr;
		if(goal==ww[q]){
			ql.push_back(make_pair(goal,make_pair(wn[q],wn[lca])));
		}else{
			ql.push_back(make_pair(ww[q],make_pair(wn[q],0)));
			int now=hl[ww[q]].par;
			int at=hl[ww[q]].at;
			while(now!=goal){
				ql.push_back(make_pair(now,make_pair(at,0)));
				at=hl[now].at;
				now=hl[now].par;
			}
			ql.push_back(make_pair(goal,make_pair(at,wn[lca])));
		}
		if(goal==ww[r]){
			if(r!=lca)qr.push_back(make_pair(goal,make_pair(wn[r],wn[lca]+1)));
		}else{
			qr.push_back(make_pair(ww[r],make_pair(wn[r],0)));
			int now=hl[ww[r]].par;
			int at=hl[ww[r]].at;
			while(now!=goal){
				qr.push_back(make_pair(now,make_pair(at,0)));
				at=hl[now].at;
				now=hl[now].par;
			}
			if(at!=wn[lca])qr.push_back(make_pair(goal,make_pair(at,wn[lca]+1)));
		}
		// Append r's pieces reversed, each flipped, so the full list walks
		// q -> lca -> r.
		if(qr.size()){
			for(int i=qr.size()-1;i>=0;i--){
				swap(qr[i].second.first,qr[i].second.second);
				ql.push_back(qr[i]);
			}
		}
	//	printf("\n");
	//	for(int i=0;i<ql.size();i++){
	//		printf("%d (%d, %d)\n",ql[i].first,ql[i].second.first,ql[i].second.second);
	//	}
		if(p==1){
			// Assignment: direction does not matter, so normalize each piece to
			// an increasing range and assign s in all four trees (types 0..3).
			for(int i=0;i<ql.size();i++){
				if(ql[i].second.first>ql[i].second.second)swap(ql[i].second.first,ql[i].second.second);
				for(int j=0;j<4;j++){
					update(0,hl[ql[i].first].n-1,ql[i].second.first,ql[i].second.second,1,s,ql[i].first,j);
				}
			}
		}else{
			long long ret=-inf;
			vector<long long>res[4];
			for(int i=0;i<ql.size();i++){
				vector<long long>t=query(0,hl[ql[i].first].n-1,min(ql[i].second.first,ql[i].second.second),max(ql[i].second.first,ql[i].second.second),1,ql[i].first);
				// Walking toward smaller positions: the run starting at the walk's
				// start is the segment's suffix, so swap prefix and suffix maxima.
				if(ql[i].second.first>ql[i].second.second)swap(t[1],t[2]);
				for(int j=0;j<4;j++)res[j].push_back(t[j]);
			}
			// Combine the pieces in walk order. cur (shadowing the global cur) is
			// the best sum of a non-empty run ending at the end of the pieces
			// processed so far.
			long long cur=-inf;
			for(int i=0;i<ql.size();i++){
		//		printf("%lld %lld %lld %lld\n",res[0][i],res[1][i],res[2][i],res[3][i]);
				ret=max(ret,res[3][i]);
				ret=max(ret,res[1][i]+cur);
				cur+=res[0][i];
				cur=max(cur,res[2][i]);
			}
			printf("%lld\n",ret);
		}
	}
	/*for(int i=0;i<4;i++){
		for(int j=0;j<hl[4].n*2;j++){
			printf("%lld ",hl[4].segtree[i][j]);
		}
		printf("\n");
	}*/
	// Switch back to the original stack and free the heap stack.
	END_STACK_EXTEND
}