AOJ 2450: Do use segment tree
Heavy-Light Decomposition + 遅延評価 Segment Tree (しかも、区間和、左端からの和の最大値、右端からの和の最大値、区間の中での連続する和の最大値の4種類を持たないといけない)。
めちゃくちゃ重いが、特に大きなバグもなく通すことができた。Very good.
// AOJ 2450 "Do use segment tree".
// Heavy-light decomposition of the tree; each heavy path owns a lazy
// "assign" segment tree that keeps four aggregates per node (see wolf), so
// that the maximum sum of a non-empty contiguous run can be combined along
// an arbitrary tree path.
#include <stdio.h>
#include <stdlib.h>
#include <alloca.h>
#include <algorithm>
#include <utility>
#include <vector>
#include <queue>

// Saves %rsp and switches it to near the end of a `size`-byte malloc'd buffer
// (x86-64, GCC/Clang inline asm only) — presumably so the deep recursion in
// dfs cannot overflow the default stack.
#define BEGIN_STACK_EXTEND(size) void * stack_extend_memory_ = malloc(size);void * stack_extend_origin_memory_;char * stack_extend_dummy_memory_ = (char*)alloca((1+(int)(((long long)stack_extend_memory_)&127))*16);*stack_extend_dummy_memory_ = 0;asm volatile("mov %%rsp, %%rbx\nmov %%rax, %%rsp":"=b"(stack_extend_origin_memory_):"a"((char*)stack_extend_memory_+(size)-1024));
// Restores the %rsp saved by BEGIN_STACK_EXTEND and frees the buffer.
#define END_STACK_EXTEND asm volatile("mov %%rax, %%rsp"::"a"(stack_extend_origin_memory_));free(stack_extend_memory_);

using namespace std;

int c[210000];          // vertex weights
int sz[210000];         // subtree sizes
int eu[410000];         // Euler tour: vertex at each tour position
int conv[210000];       // vertex -> preorder index
int rev[210000];        // preorder index -> vertex
vector<int> g[210000];  // adjacency lists
int L[210000];          // first Euler-tour position of each vertex
int par[210000];        // parent in the tree rooted at 0 (-1 for the root)
int str[210000];        // heavy child of each vertex, or -1
int ww[210000];         // index into hl of the heavy path containing each vertex
int wn[210000];         // position of each vertex inside its heavy path
int cur;                // next preorder index to hand out
int cur2;               // next Euler-tour position to fill

// Visits the subtree of a (entered from parent b) and fills par, conv/rev,
// the Euler tour (eu, L), subtree sizes and the heavy child str[a].
void dfs(int a,int b){
    par[a]=b;
    conv[a]=cur;
    rev[cur]=a;
    cur++;
    L[a]=cur2;
    eu[cur2++]=a;
    sz[a]=1;
    str[a]=-1;
    for(int i=0;i<(int)g[a].size();i++){
        int to=g[a][i];
        if(to==b)continue;
        dfs(to,a);
        eu[cur2++]=a;
        sz[a]+=sz[to];
        if(str[a]==-1||sz[str[a]]<sz[to]){
            str[a]=to;
        }
    }
    // Keep the largest child as heavy only if it holds at least half of the
    // subtree. The str[a]!=-1 guard avoids reading sz[-1] at leaves.
    if(str[a]!=-1&&sz[str[a]]*2<sz[a])str[a]=-1;
}

// One heavy path plus its segment trees.
// Aggregate index k used for segtree[k] / lazy[k]:
//   0: sum
//   1: left-max  (best prefix sum)
//   2: right-max (best suffix sum)
//   3: max       (best contiguous sum)
struct wolf{
    vector<int> node;              // vertices of the path, top (nearest the root) first
    vector<long long> segtree[4];
    vector<long long> lazy[4];     // pending assigned value, or mod when none
    int par;                       // path this one hangs off (-1 for the root path)
    int at;                        // position, in path `par`, of the parent of node[0]
    int n;                         // leaf count: smallest power of two >= node.size()
    wolf(){
        par=at=-1;
    }
};

int mod=1000000007;                // also the "no pending assignment" sentinel
long long inf=(long long)mod*mod;
vector<wolf> hl;                   // all heavy paths; hl[0] starts at the root

// Pushes the pending assignment of aggregate `type` at tree node e of path
// `node` down to its two children. Only e, node and type are used.
void proc(int a,int b,int c,int d,int e,int node,int type){
    vector<long long>&seg=hl[node].segtree[type];
    vector<long long>&lz=hl[node].lazy[type];
    long long f=lz[e];
    lz[e*2]=lz[e*2+1]=f;
    lz[e]=mod;
    if(type!=0&&f<0){
        // All elements negative: the best non-empty run is a single element.
        seg[e*2]=seg[e*2+1]=f;
    }else{
        // Sum, or a non-negative value: each child covers exactly half of
        // the node (ranges are powers of two), so it holds half the total.
        seg[e*2]=seg[e*2+1]=seg[e]/2;
    }
}

// Assigns f to positions [c,d] of aggregate `type` in path `node`'s tree;
// [a,b] is the range covered by tree node e. Recombining types 1-3 reads the
// children's type-0 sums, so type 0 must be updated first (main does so).
void update(int a,int b,int c,int d,int e,long long f,int node,int type){
    if(d<a||b<c)return;
    vector<long long>&seg=hl[node].segtree[type];
    if(c<=a&&b<=d){
        hl[node].lazy[type][e]=f;
        if(type!=0&&f<0)seg[e]=f;
        else seg[e]=f*(b-a+1);
        return;
    }
    if(hl[node].lazy[type][e]!=mod){
        proc(a,b,c,d,e,node,type);
    }
    int mid=(a+b)/2;
    update(a,mid,c,d,e*2,f,node,type);
    update(mid+1,b,c,d,e*2+1,f,node,type);
    const vector<long long>*t=hl[node].segtree;
    int l=e*2,r=e*2+1;
    if(type==0){
        seg[e]=t[0][l]+t[0][r];
    }else if(type==1){
        seg[e]=max(t[1][l],t[0][l]+t[1][r]);
    }else if(type==2){
        seg[e]=max(t[2][r],t[0][r]+t[2][l]);
    }else{
        seg[e]=max(max(t[3][l],t[3][r]),t[2][l]+t[1][r]);
    }
}

// Returns {sum, left-max, right-max, max} over positions [c,d] of path
// `node`; [a,b] is the range covered by tree node e. A node outside the
// range contributes {0, -inf, -inf, -inf}.
vector<long long> query(int a,int b,int c,int d,int e,int node){
    vector<long long>ret(4);
    if(d<a||b<c){
        ret[0]=0;
        ret[1]=ret[2]=ret[3]=-inf;
        return ret;
    }
    if(c<=a&&b<=d){
        for(int i=0;i<4;i++)ret[i]=hl[node].segtree[i][e];
        return ret;
    }
    for(int i=0;i<4;i++){
        if(hl[node].lazy[i][e]!=mod)proc(a,b,c,d,e,node,i);
    }
    int mid=(a+b)/2;
    vector<long long> left=query(a,mid,c,d,e*2,node);
    vector<long long> right=query(mid+1,b,c,d,e*2+1,node);
    ret[0]=left[0]+right[0];
    ret[1]=max(left[1],left[0]+right[1]);
    ret[2]=max(right[2],right[0]+left[2]);
    ret[3]=max(max(left[3],right[3]),left[2]+right[1]);
    return ret;
}

// Bottom-up min segment tree over Euler-tour positions holding preorder
// indices; main takes rev[] of the minimum over [L[u],L[v]] as lca(u,v).
int lcaseg[1048576];

// Lowers leaf a (and all its ancestors) to at most b.
void lcaupd(int a,int b){
    a+=524288;
    while(a){
        lcaseg[a]=min(lcaseg[a],b);
        a/=2;
    }
}

// Minimum of lcaseg over [c,d]; [a,b] is the range covered by node e.
int lcamin(int a,int b,int c,int d,int e){
    if(d<a||b<c)return mod;
    if(c<=a&&b<=d)return lcaseg[e];
    return min(lcamin(a,(a+b)/2,c,d,e*2),lcamin((a+b)/2+1,b,c,d,e*2+1));
}

int main(){
    BEGIN_STACK_EXTEND(250*1024*1024)
    int a,b;
    scanf("%d%d",&a,&b);
    for(int i=0;i<a;i++){
        scanf("%d",c+i);
    }
    for(int i=0;i<a-1;i++){
        int p,q;
        scanf("%d%d",&p,&q);
        p--;q--;
        g[p].push_back(q);
        g[q].push_back(p);
    }
    dfs(0,-1);

    // Split the tree into heavy paths with a BFS from the root: the heavy
    // child extends its parent's path, every other child starts a new path.
    hl.push_back(wolf());
    hl[0].node.push_back(0);
    queue<pair<int,int> >Q;
    Q.push(make_pair(0,0));
    ww[0]=wn[0]=0;
    while(Q.size()){
        int at=Q.front().first;
        int seq=Q.front().second;
        Q.pop();
        for(int i=0;i<(int)g[at].size();i++){
            int to=g[at][i];
            if(to==par[at])continue;
            if(str[at]==to){
                ww[to]=seq;
                wn[to]=hl[seq].node.size();
                hl[seq].node.push_back(to);
                Q.push(make_pair(to,seq));
            }else{
                Q.push(make_pair(to,(int)hl.size()));
                ww[to]=hl.size();
                wn[to]=0;
                hl.push_back(wolf());
                hl.back().node.push_back(to);
                hl.back().par=seq;
                hl.back().at=wn[at];
            }
        }
    }

    // LCA table: each Euler-tour position stores the preorder index there.
    for(int i=0;i<1048576;i++)lcaseg[i]=mod;
    for(int i=0;i<cur2;i++){
        lcaupd(i,conv[eu[i]]);
    }

    // One segment tree per heavy path, filled with the initial weights.
    for(int i=0;i<(int)hl.size();i++){
        hl[i].n=1;
        while(hl[i].n<(int)hl[i].node.size())hl[i].n*=2;
        for(int j=0;j<4;j++){
            hl[i].segtree[j]=vector<long long>(hl[i].n*2);
            // No pending assignment anywhere yet (mod is the "none" sentinel).
            hl[i].lazy[j]=vector<long long>(hl[i].n*2,mod);
        }
        for(int j=0;j<(int)hl[i].node.size();j++){
            for(int k=0;k<4;k++){
                update(0,hl[i].n-1,j,j,1,c[hl[i].node[j]],i,k);
            }
        }
    }

    while(b--){
        int p,q,r,s;
        scanf("%d%d%d%d",&p,&q,&r,&s);
        q--;r--;
        if(L[q]>L[r])swap(q,r);
        int lca=rev[lcamin(0,524287,L[q],L[r],1)];
        int goal=ww[lca];
        // Path pieces (heavy path, (from, to)), ordered from q towards r.
        vector<pair<int,pair<int,int> > >ql;
        vector<pair<int,pair<int,int> > >qr;
        // q side: climb from q up to and including lca.
        if(goal==ww[q]){
            ql.push_back(make_pair(goal,make_pair(wn[q],wn[lca])));
        }else{
            ql.push_back(make_pair(ww[q],make_pair(wn[q],0)));
            int now=hl[ww[q]].par;
            int at=hl[ww[q]].at;
            while(now!=goal){
                ql.push_back(make_pair(now,make_pair(at,0)));
                at=hl[now].at;
                now=hl[now].par;
            }
            ql.push_back(make_pair(goal,make_pair(at,wn[lca])));
        }
        // r side: climb from r up to just below lca (lca is already in ql).
        if(goal==ww[r]){
            if(r!=lca)qr.push_back(make_pair(goal,make_pair(wn[r],wn[lca]+1)));
        }else{
            qr.push_back(make_pair(ww[r],make_pair(wn[r],0)));
            int now=hl[ww[r]].par;
            int at=hl[ww[r]].at;
            while(now!=goal){
                qr.push_back(make_pair(now,make_pair(at,0)));
                at=hl[now].at;
                now=hl[now].par;
            }
            if(at!=wn[lca])qr.push_back(make_pair(goal,make_pair(at,wn[lca]+1)));
        }
        // Append the r side reversed, so those pieces run from lca down to r.
        for(int i=(int)qr.size()-1;i>=0;i--){
            swap(qr[i].second.first,qr[i].second.second);
            ql.push_back(qr[i]);
        }

        if(p==1){
            // Assign s to every vertex on the q-r path.
            for(int i=0;i<(int)ql.size();i++){
                int lo=min(ql[i].second.first,ql[i].second.second);
                int hi=max(ql[i].second.first,ql[i].second.second);
                for(int j=0;j<4;j++){
                    update(0,hl[ql[i].first].n-1,lo,hi,1,s,ql[i].first,j);
                }
            }
        }else{
            // Best sum of a non-empty contiguous run along the q-r path.
            long long ret=-inf;
            long long run=-inf;  // best run ending at the last vertex processed
            for(int i=0;i<(int)ql.size();i++){
                int from=ql[i].second.first;
                int to=ql[i].second.second;
                vector<long long>t=query(0,hl[ql[i].first].n-1,min(from,to),max(from,to),1,ql[i].first);
                // Walking towards the top of the path: prefix and suffix swap roles.
                if(from>to)swap(t[1],t[2]);
                ret=max(ret,t[3]);
                ret=max(ret,t[1]+run);
                run+=t[0];
                run=max(run,t[2]);
            }
            printf("%lld\n",ret);
        }
    }
    END_STACK_EXTEND
}