数据结构与算法:树上差分
树上差分
前言
树上问题比cf那些逆天思维题善良多了()
一、树上点差分
树上点差分解决的问题就是单次O(1)处理对节点进行修改,最后整合出最终结果。公式为:
若想在节点a到节点b的路径上的所有点统一增加v,可以给节点a增加v,节点b增加v,两点的lca减去v,lca的父节点减去v即可。所有操作都做完后,去树上dfs,每个节点最终的点权就是其所有孩子点权的累加和。
1.Max Flow P
#include <bits/stdc++.h>
using namespace std;
/* /\_/\
* (= ._.)
* / > \>
*/
#define endl '\n'
#define dbg(x) cout<<#x<<" "<<x<<endl
#define vdbg(a) cout<<#a<<endl; for(auto x:a) cout<<x<<" ";cout<<endl
#define INF 1e9
#define INFLL 1e18
#define YES cout<<"YES"<<endl;return ;
#define Yes cout<<"Yes"<<endl;return ;
#define NO cout<<"NO"<<endl;return ;
#define No cout<<"No"<<endl;return ;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
const int MAXN=5e4+5;
const int LIMIT=20;
int n,k;
vector<vector<int>>g(MAXN);
vector<int>val(MAXN);
int power;
vector<int>deep(MAXN);
vector<vector<int>>stjump(MAXN,vector<int>(LIMIT));
int log2()
{
int ans=0;
while((1<<ans)<=(n>>1))
{
ans++;
}
return ans;
}
void build(int u,int fa)
{
deep[u]=deep[fa]+1;
stjump[u][0]=fa;
for(int p=1;(1<<p)<=deep[u];p++)
{
stjump[u][p]=stjump[stjump[u][p-1]][p-1];
}
for(auto v:g[u])
{
if(v!=fa)
{
build(v,u);
}
}
}
int lca(int a,int b)
{
if(deep[a]<deep[b])
{
swap(a,b);
}
for(int p=power;p>=0;p--)
{
if(deep[stjump[a][p]]>=deep[b])
{
a=stjump[a][p];
}
}
if(a==b)
{
return a;
}
for(int p=power;p>=0;p--)
{
if(stjump[a][p]!=stjump[b][p])
{
a=stjump[a][p];
b=stjump[b][p];
}
}
return stjump[a][0];
}
void dfs(int u,int fa)
{
for(auto v:g[u])
{
if(v!=fa)
{
dfs(v,u);
val[u]+=val[v];
}
}
}
void solve()
{
cin>>n>>k;
for(int i=0,u,v;i<n-1;i++)
{
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
power=log2();
build(1,0);
int a,b;
while(k--)
{
cin>>a>>b;
val[a]++;
val[b]++;
int l=lca(a,b);
val[l]--;
val[stjump[l][0]]--;
}
dfs(1,0);
int ans=0;
for(int i=1;i<=n;i++)
{
ans=max(ans,val[i]);
}
cout<<ans<<endl;
}
void init()
{
}
signed main()
{
ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
int t=1;
//cin>>t;
init();
while(t--)
{
solve();
}
return 0;
}
这个题其实就是每次两个点路径上的所有点都增加1,然后统计最后的最大点权,那么就是个树上点差分的板子题了。所以就是每次根据st表求出两点的lca和lca的父节点,然后按公式操作,最后dfs一遍整合答案即可。
2.松鼠的新家
#include <bits/stdc++.h>
using namespace std;
/* /\_/\
* (= ._.)
* / > \>
*/
#define endl '\n'
#define dbg(x) cout<<#x<<" "<<x<<endl
#define vdbg(a) cout<<#a<<endl; for(auto x:a) cout<<x<<" ";cout<<endl
#define INF 1e9
#define INFLL 1e18
#define YES cout<<"YES"<<endl;return ;
#define Yes cout<<"Yes"<<endl;return ;
#define NO cout<<"NO"<<endl;return ;
#define No cout<<"No"<<endl;return ;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
const int MAXN=3e5+5;
const int LIMIT=20;
vector<int>a(MAXN);
vector<vector<int>>g(MAXN);
int n;
int power;
int log2()
{
int ans=0;
while((1<<ans)<=(n>>1))
{
ans++;
}
return ans;
}
vector<int>deep(MAXN);
vector<vector<int>>stjump(MAXN,vector<int>(LIMIT));
void build(int u,int fa)
{
deep[u]=deep[fa]+1;
stjump[u][0]=fa;
for(int p=1;(1<<p)<=deep[u];p++)
{
stjump[u][p]=stjump[stjump[u][p-1]][p-1];
}
for(auto v:g[u])
{
if(v!=fa)
{
build(v,u);
}
}
}
int lca(int a,int b)
{
if(deep[a]<deep[b])
{
swap(a,b);
}
for(int p=power;p>=0;p--)
{
if(deep[stjump[a][p]]>=deep[b])
{
a=stjump[a][p];
}
}
if(a==b)
{
return a;
}
for(int p=power;p>=0;p--)
{
if(stjump[a][p]!=stjump[b][p])
{
a=stjump[a][p];
b=stjump[b][p];
}
}
return stjump[a][0];
}
vector<int>val(MAXN);
void dfs(int u,int fa)
{
for(auto v:g[u])
{
if(v!=fa)
{
dfs(v,u);
val[u]+=val[v];
}
}
}
void solve()
{
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
for(int i=0,u,v;i<n-1;i++)
{
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
power=log2();
build(1,0);
for(int i=2;i<=n;i++)
{
val[a[i-1]]++;
val[a[i]]++;
int anc=lca(a[i-1],a[i]);
val[anc]--;
val[stjump[anc][0]]--;
}
dfs(1,0);
for(int i=2;i<=n;i++)
{
val[a[i]]--;
}
for(int i=1;i<=n;i++)
{
cout<<val[i]<<endl;
}
}
void init()
{
}
signed main()
{
ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
int t=1;
//cin>>t;
init();
while(t--)
{
solve();
}
return 0;
}
这个题其实也还好,就是树上点差分的模板题。唯一需要注意的是,在统计完差分信息后,道中的每个点以及终点都被多算了一次,所以要统一减去一点。
3.最小化旅行的价格总和
class Solution {
public:
const int MAXN=50+5;
const int LIMIT=6;
vector<int>price;
vector<array<int,2>>dp;
vector<int>cnts;
vector<vector<int>>g;
vector<int>deep;
vector<vector<int>>stjump;
int power;
int log2(int n)
{
int ans=0;
while((1<<ans)<=(n>>1))
{
ans++;
}
return ans;
}
void build(int n)
{
power=log2(n);
price.resize(MAXN);
dp.resize(MAXN);
cnts.resize(MAXN);
g.resize(MAXN);
deep.resize(MAXN);
stjump.resize(MAXN,vector<int>(LIMIT));
}
int minimumTotalPrice(int n, vector<vector<int>>& e, vector<int>& p, vector<vector<int>>& t) {
build(n);
for(int i=0;i<n;i++)
{
price[i+1]=p[i];
}
for(auto edge:e)
{
g[edge[0]+1].push_back(edge[1]+1);
g[edge[1]+1].push_back(edge[0]+1);
}
//首先对每个节点的访问次数进行差分
//之后对于每个节点,考虑让其减半和不让其减半两种情况的代价
dfs1(1,0);
for(auto trip:t)
{
int u=trip[0]+1;
int v=trip[1]+1;
int anc=lca(u,v);
int fa=stjump[anc][0];
cnts[u]++;
cnts[v]++;
cnts[anc]--;
cnts[fa]--;
}
dfs2(1,0);
DP(1,0);
return min(dp[1][0],dp[1][1]);
}
void dfs1(int u,int fa)
{
deep[u]=deep[fa]+1;
stjump[u][0]=fa;
for(int p=1;(1<<p)<=deep[u];p++)
{
stjump[u][p]=stjump[stjump[u][p-1]][p-1];
}
for(auto v:g[u])
{
if(v!=fa)
{
dfs1(v,u);
}
}
}
int lca(int a,int b)
{
if(deep[a]<deep[b])
{
swap(a,b);
}
for(int p=power;p>=0;p--)
{
if(deep[stjump[a][p]]>=deep[b])
{
a=stjump[a][p];
}
}
if(a==b)
{
return a;
}
for(int p=power;p>=0;p--)
{
if(stjump[a][p]!=stjump[b][p])
{
a=stjump[a][p];
b=stjump[b][p];
}
}
return stjump[a][0];
}
void dfs2(int u,int fa)
{
for(auto v:g[u])
{
if(v!=fa)
{
dfs2(v,u);
cnts[u]+=cnts[v];
}
}
}
void DP(int u,int fa)
{
//当前节点用和不用的价格
dp[u][0]=price[u]*cnts[u];
dp[u][1]=(price[u]/2)*cnts[u];
for(auto v:g[u])
{
if(v!=fa)
{
DP(v,u);
//当前节点不用 -> 孩子节点可以用
dp[u][0]+=min(dp[v][0],dp[v][1]);
dp[u][1]+=dp[v][0];
}
}
}
};
这个题就需要一点思考了。
考虑用树上点差分统计每个节点来到的次数,之后对于每个节点,考虑其减半和不减半两种情况,那么这就是一个树型dp问题了。那么就是对于当前来到的节点,如果减半,那么孩子节点就都不能减半,否则就可以减半也可以不减半。
二、树上边差分
树上边差分的公式为:
若想在节点a到节点b的路径上的所有边统一增加v,可以给节点a的点权增加v,节点b的点权增加v,再给两点lca的点权减小2*v,最后在dfs的过程中,在点差分更新的基础上,让父亲到孩子的边加上孩子的点权。
1.Network
poj该淘汰了……
#include<iostream>
#include<stdio.h>
#include<vector>
using namespace std;
/* /\_/\
* (= ._.)
* / > \>
*/
#define dbg(x) cout<<#x<<" "<<x<<endl
#define vdbg(a) cout<<#a<<endl; for(auto x:a) cout<<x<<" ";cout<<endl
#define INF 1e9
#define INFLL 1e18
#define YES cout<<"YES"<<endl;return ;
#define Yes cout<<"Yes"<<endl;return ;
#define NO cout<<"NO"<<endl;return ;
#define No cout<<"No"<<endl;return ;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
const int MAXN=1e5+5;
const int LIMIT=20;
int n,m;
vector<vector<int> >g(MAXN);
int power;
int deep[MAXN];
int stjump[MAXN][LIMIT];
int val[MAXN];
int log2()
{
int ans=0;
while((1<<ans)<=(n>>1))
{
ans++;
}
return ans;
}
void dfs1(int u,int fa)
{
deep[u]=deep[fa]+1;
stjump[u][0]=fa;
for(int p=1;(1<<p)<=deep[u];p++)
{
stjump[u][p]=stjump[stjump[u][p-1]][p-1];
}
for(int i=0,v;i<g[u].size();i++)
{
v=g[u][i];
if(v!=fa)
{
dfs1(v,u);
}
}
}
int lca(int a,int b)
{
if(deep[a]<deep[b])
{
swap(a,b);
}
for(int p=power;p>=0;p--)
{
if(deep[stjump[a][p]]>=deep[b])
{
a=stjump[a][p];
}
}
if(a==b)
{
return a;
}
for(int p=power;p>=0;p--)
{
if(stjump[a][p]!=stjump[b][p])
{
a=stjump[a][p];
b=stjump[b][p];
}
}
return stjump[a][0];
}
ll ans;
void dfs2(int u,int fa)
{
for(int i=0,v;i<g[u].size();i++)
{
v=g[u][i];
if(v!=fa)
{
dfs2(v,u);
val[u]+=val[v];
if(val[v]==0)
{
ans+=m;
}
else if(val[v]==1)
{
ans++;
}
}
}
}
void solve()
{
scanf("%d %d",&n,&m);
for(int i=0,u,v;i<n-1;i++)
{
scanf("%d %d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
//考虑对于每条新边(u,v),将u->v路径上的所有边权加1,表示位于一个环上
//最后,对于边权为0的边,删除这条边直接就可以导致不连通,那么就可以和任意新边组合产生m个贡献
//对于边权为1的边,因为处在一个环上,那么删除这条边就必须再删除那条产生环的新边,产生1个贡献
//剩余边不会产生贡献
power=log2();
dfs1(1,0);
for(int i=0,u,v;i<m;i++)
{
scanf("%d %d",&u,&v);
val[u]++;
val[v]++;
int anc=lca(u,v);
val[anc]-=2;
}
dfs2(1,0);
printf("%d",ans);
}
void init()
{
}
int main()
{
int t=1;
//cin>>t;
init();
while(t--)
{
solve();
}
return 0;
}
这个题还是需要一点思考的。
考虑对于每条新边,将两点间路径的边权都加1,表示可以参与构成一个环。之后,对于边权为0的点,其不参与构成任何一个环,所以删除这条边可以直接导致不连通,那么这条边和任意一个新边组合都可以产生贡献。而对于边权为1的点,因为处于一个环上,所以删除这条边必然需要再删除导致其成环的那条新边才能使得不连通,所以会产生一点贡献。而对于其他边权大于等于2的边,因为参与构成多个环,所以怎么删都不行,所以不会产生贡献。
2.运输计划
逆天卡常……
#include <bits/stdc++.h>
using namespace std;
/* /\_/\
* (= ._.)
* / > \>
*/
#define dbg(x) cout<<#x<<" "<<x<<endl
#define vdbg(a) cout<<#a<<endl; for(auto x:a) cout<<x<<" ";cout<<endl
#define INF 1e9
#define INFLL 1e18
#define YES cout<<"YES"<<endl;return ;
#define Yes cout<<"Yes"<<endl;return ;
#define NO cout<<"NO"<<endl;return ;
#define No cout<<"No"<<endl;return ;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
const int MAXN=3e5+5;
const int LIMIT=20;
int n,m;
int headEdge[MAXN];
int nextEdge[MAXN<<1];
int toEdge[MAXN<<1];
int weight[MAXN<<1];
int tcnt=1;
void addEdge(int u,int v,int w)
{
nextEdge[tcnt]=headEdge[u];
toEdge[tcnt]=v;
weight[tcnt]=w;
headEdge[u]=tcnt++;
}
int father[MAXN];
void build()
{
for(int i=0;i<MAXN;i++)
{
father[i]=i;
}
}
int find(int i)
{
if(i!=father[i])
{
father[i]=find(father[i]);
}
return father[i];
}
int headQuery[MAXN];
int nextQuery[MAXN<<1];
int toQuery[MAXN<<1];
int pos[MAXN<<1];
int qcnt=1;
void addQuery(int u,int v,int i)
{
nextQuery[qcnt]=headQuery[u];
toQuery[qcnt]=v;
pos[qcnt]=i;
headQuery[u]=qcnt++;
}
int vis[MAXN];
int lca[MAXN];
int dis[MAXN];
int cost[MAXN];
int maxCost;
int quesu[MAXN];
int quesv[MAXN];
void tarjan(int u,int fa,int c)
{
vis[u]=true;
dis[u]=dis[fa]+c;
for(int ei=headEdge[u],v,w;ei>0;ei=nextEdge[ei])
{
v=toEdge[ei];
w=weight[ei];
if(v!=fa)
{
tarjan(v,u,w);
}
}
for(int ei=headQuery[u],v,i;ei>0;ei=nextQuery[ei])
{
v=toQuery[ei];
i=pos[ei];
if(vis[v])
{
lca[i]=find(v);
cost[i]=dis[u]+dis[v]-2*dis[lca[i]];
maxCost=max(maxCost,cost[i]);
}
}
father[u]=fa;
}
int least;
int beyond;
int val[MAXN];
bool dfs(int u,int fa,int c)
{
for(int ei=headEdge[u],v,w;ei>0;ei=nextEdge[ei])
{
v=toEdge[ei];
w=weight[ei];
if(v!=fa)
{
if(dfs(v,u,w))
{
return true;
}
}
}
for(int ei=headEdge[u],v,w;ei>0;ei=nextEdge[ei])
{
v=toEdge[ei];
w=weight[ei];
if(v!=fa)
{
val[u]+=val[v];
}
}
return val[u]==beyond&&c>=least;
}
bool check(int limit)
{
least=maxCost-limit;
memset(val,0,sizeof(val));
beyond=0;
for(int i=1;i<=m;i++)
{
if(cost[i]>limit)
{
val[quesu[i]]++;
val[quesv[i]]++;
val[lca[i]]-=2;
beyond++;
}
}
return beyond==0||dfs(1,0,0);
}
void solve()
{
scanf("%d %d",&n,&m);
for(int i=0,u,v,w;i<n-1;i++)
{
scanf("%d %d %d",&u,&v,&w);
addEdge(u,v,w);
addEdge(v,u,w);
}
//因为要求最大值最小,且若某个时间可以达到,那么大于这个时间的都可以达到,所以可以考虑二分答案
//每次check时,只考虑大于mid的任务,之后进行边差分
//那么若可以达成,就需要存在一条所有超标任务都经过的边,且该边的边权大于等于任务最大值-mid
for(int i=1,u,v;i<=m;i++)
{
scanf("%d %d",&u,&v);
quesu[i]=u;
quesv[i]=v;
addQuery(u,v,i);
addQuery(v,u,i);
}
build();
tarjan(1,0,0);
int l=0;
int r=maxCost;
int mid;
int ans=0;
while(l<=r)
{
mid=l+r>>1;
if(check(mid))
{
ans=mid;
r=mid-1;
}
else
{
l=mid+1;
}
}
printf("%d",ans);
}
void init()
{
}
signed main()
{
int t=1;
//cin>>t;
init();
while(t--)
{
solve();
}
return 0;
}
首先上来的这个思路就不好想,因为要让完成的最大时间最小,且若一个时间可以完成,那么大于这个时间肯定也可以完成,所以可以考虑去二分答案。
所以在每次check的时候,没必要关注本来时间就小于mid的任务,只需要关注时间大于等于mid的任务,去看是否能让其全小于mid。那么每次对于这些任务进行边差分后,只有当存在一条边被所有超限的任务经过,且边权大于等于最大时间和mid的差值,才可以达成。
总结
感觉树上差分和正常差分差不多,主要用于统计一些信息,方便后续整合答案。
END
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐

所有评论(0)