Triples(长链剖分+树形DP)


题意

给定一棵$n$个点($n\leq 10^5$)树,在树上选3个不同的点,要求两两距离相等,求方案数。

动态规划思路

考虑动态规划,设$f[x][i]$表示在点$x$的子树中到$x$距离为$i$的节点个数,$g[x][i]$表示$x$的子树中到lca距离为$d$且lca到$x$距离为$i$的点对个数。

$g[x][i]$状态的设置非常巧妙,也较难理解,具体可见下图:

$g[x][i]$可以理解为在$x$节点上再接长度为$i$的路径就可以凑成合法三元组的对数。

考虑转移:($u$为$x$的儿子节点)

$$g[x][i] += g[u][i + 1] + f[x][i] \times f[u][i - 1]$$

$$f[x][i] += f[u][i - 1] $$

$f$数组比较好理解,$g$数组包含两个部分:第一部分由儿子节点直接传上来,因为儿子节点需要接$i+ 1$长度的情况下,父亲节点就只要接$i$长度;第二部分考虑$x$节点作为点对的lca计算点对数。

对于答案的统计:

  1. 如果一个点不再需要往上接:$ans+=g[x][0]$

  2. 在两个点合并时:$ans +=g[x][i] \times f[u][i - 1] + g[u][i] \times f[x][i - 1]$

以上就是动态规划的思路,直接做的话时间复杂度为$O(n^2)$。

长链剖分优化

考虑$x$从第一个遍历的儿子那里转移时:

$g[x][i]= g[u][i + 1]$

$f[x][i] = f[i][i - 1]$

可以发现,转移时只有下标偏移。那么我们可不可以直接继承儿子的数组,再移动一下下标,这样就可以实现$O(1)$的转移。对于其它节点再考虑暴力转移。

那么就可以用启发式合并的方式实现快速转移,剩下的问题就是要选好第一个访问的节点,使每个点唯一一次的$O(1)$转移尽可能多地继承信息。很容易想到,对于一个节点,其数组的长度为它的深度,那么每次选择深度最大的那个节点优先访问,长脸剖分的思想就应运而生了。

这么做的时间复杂度为多少呢?我们将一个节点深度最大的儿子节点作为重儿子,其它儿子节点都会成为另一条长链的首端,可以发现只有长链的首端需要暴力向上转移,每条长链转移的次数为长链的长度,所以暴力转移次数为所有长链长度之和,时间复杂度为$O(n)$。

代码实现

因为涉及到坐标的偏移,所以比较好写的方法是采用指针的写法。开一个长度较大的数组,用指针表示每个节点的f,g指向数组的位置。

/*************************************************************************
> File Name: 1.cpp
> Author: Knowledge_llz
> Mail: 925538513@qq.com
> Blog: https://www.cnblogs.com/Knowledge-Pig/ 
************************************************************************/

#include<bits/stdc++.h>
#define LL long long
#define endl '\n'
using namespace std;
const int maxx = 2e5 + 10;
int n, dep[maxx], son[maxx];
LL tmp[maxx << 2], *f[maxx], *g[maxx], *pos, ans = 0;
vector<int> eg[maxx];
void dfs(int id, int fa){
    for(auto u : eg[id]){
        if(u == fa) continue;
        dfs(u, id);
        if(dep[u] >= dep[son[id]]) son[id] = u;
    }
    dep[id] = dep[son[id]] + 1;
}
void solve(int id, int fa){
	if(son[id]){
		f[son[id]] = f[id] + 1;
		g[son[id]] = g[id] - 1;
		solve(son[id], id);	
		ans += g[id][0];
	}
	f[id][0] = 1;
	
	for(auto u : eg[id]){
		if(u == fa || u == son[id]) continue;
		f[u] = pos;
		pos += (dep[u] << 1);
		g[u] = pos;
		pos += (dep[u] << 1);
		solve(u, id);
		for(int i = 1; i <= dep[u]; ++i){
			ans += g[id][i] * f[u][i - 1];
			if(i <= dep[u]) ans += g[u][i] * f[id][i - 1];
		}
		for(int i = 0; i <= dep[u]; ++i){
			if(i < dep[u]) g[id][i] += g[u][i + 1];
			if(i){
				g[id][i] += f[id][i] * f[u][i - 1];
				f[id][i] += f[u][i - 1];
			}
		}
	}

}
int main(){
	ios::sync_with_stdio(false); cin.tie(0);
#ifndef ONLINE_JUDGE
	freopen("input.txt", "r", stdin);
	freopen("output.txt","w", stdout);
#endif
	int T; cin >> T;
	while(T--){
		cin >> n;
		for(int i = 1; i <= n; ++i){
            eg[i].clear();
            ans = dep[i] = son[i] = 0;
        }
		for(int i = 0; i <= (n + 1 << 2); ++i) tmp[i] = 0;
		for(int i = 1, u, v; i < n; ++i){
			cin >> u >> v;
			eg[u].push_back(v);
			eg[v].push_back(u);
		}
		dfs(1, 0);
		pos = tmp;
		f[1] = pos;
		pos = pos + (dep[1] << 1);
		g[1] = pos;
		pos = pos + (dep[1] << 1);
		solve(1, 0);
		cout << ans << endl;
	}
	return 0;
}

上面指针是在网上找到的一份写法比较精简的代码。因为我一开始不熟悉指针,自己开了一份手动控制下标变换的代码,但是细节过多,实现较麻烦:

/*************************************************************************
> File Name: 1.cpp
> Author: Knowledge_llz
> Mail: 925538513@qq.com
> Blog: https://www.cnblogs.com/Knowledge-Pig/ 
************************************************************************/

#include<bits/stdc++.h>
#define LL long long
#define endl '\n'
using namespace std;
const int maxx = 2e5 + 10;
int n, dep[maxx], son[maxx];
vector<int> eg[maxx];
void dfs(int id, int fa){
    for(auto u : eg[id]){
        if(u == fa) continue;
        dfs(u, id);
        if(dep[u] >= dep[son[id]]) son[id] = u;
    }  
    if(son[id]) dep[id] = dep[son[id]] + 1;
}
int ID[maxx], be[maxx];
LL ans = 0;
vector<LL> f[maxx], g[maxx];
void solve(int x, int fa){
    if(son[x]){
        solve(son[x], x);
        ID[x] = ID[son[x]];
        f[ID[x]].push_back(1);
        g[ID[x]].push_back(0);
        g[ID[x]].push_back(0);
        ++be[ID[x]];
        ans += g[ID[x]][be[ID[x]]];
    }
    else{
        ID[x] = x;
        f[x].push_back(1);
        g[x].push_back(0);
        be[x] = 0;
    }
    for(auto u : eg[x]){
        if(u == fa || u == son[x]) continue;
        solve(u, x);
        int x_sz = f[ID[x]].size() - 1, u_sz = f[ID[u]].size() - 1;
        for(int i = 1; i <= dep[u] + 1; ++i){
            if(be[ID[u]] + i < g[ID[u]].size()) ans += g[ID[u]][be[ID[u]] + i] * f[ID[x]][x_sz - (i - 1)]; 
            if(u_sz - (i - 1) < f[ID[u]].size()) ans += g[ID[x]][be[ID[x]] + i] * f[ID[u]][u_sz - (i - 1)];
        }
        for(int i = 0; i <= dep[u] + 1; ++i){
            if(be[ID[u]] + i + 1 < g[ID[u]].size()) g[ID[x]][be[ID[x]] + i] += g[ID[u]][be[ID[u]] + i + 1];
            if(i > 0 && u_sz - i + 1 < f[ID[u]].size()){
                g[ID[x]][be[ID[x]] + i] += f[ID[x]][x_sz - i] * f[ID[u]][u_sz - i + 1];
            }
        }
        for(int i = u_sz, j = x_sz - 1; i >= 0; --i, --j) f[ID[x]][j] += f[ID[u]][i];
    }

}
int main(){
    ios::sync_with_stdio(false); cin.tie(0);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt","w", stdout);
#endif
    int T; cin >> T;
    while(T--){
        cin >> n;
        for(int i = 1; i <= n; ++i){
            eg[i].clear();
            f[i].clear();
            g[i].clear();
            ans = be[i] = dep[i] = son[i] = 0;
        }
        for(int i = 1; i < n; ++i){
            int u, v;
            cin >> u >> v;
            eg[u].push_back(v);
            eg[v].push_back(u);
        }
        dfs(1, 0);
        solve(1, 0);
        cout << ans << endl;
    }
    return 0;
}

文章作者: Knowledge
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Knowledge !
  目录