题意
给定一棵$n$个点($n\leq 10^5$)树,在树上选3个不同的点,要求两两距离相等,求方案数。
动态规划思路
考虑动态规划,设$f[x][i]$表示在点$x$的子树中到$x$距离为$i$的节点个数,$g[x][i]$表示$x$的子树中到lca距离为$d$且lca到$x$距离为$i$的点对个数。
$g[x][i]$状态的设置非常巧妙,也较难理解,具体可见下图:
$g[x][i]$可以理解为在$x$节点上再接长度为$i$的路径就可以凑成合法三元组的对数。
考虑转移:($u$为$x$的儿子节点)
$$g[x][i] += g[u][i + 1] + f[x][i] \times f[u][i - 1]$$
$$f[x][i] += f[u][i - 1] $$
$f$数组比较好理解,$g$数组包含两个部分:第一部分由儿子节点直接传上来,因为儿子节点需要接$i+ 1$长度的情况下,父亲节点就只要接$i$长度;第二部分考虑$x$节点作为点对的lca计算点对数。
对于答案的统计:
如果一个点不再需要往上接:$ans+=g[x][0]$
在两个点合并时:$ans +=g[x][i] \times f[u][i - 1] + g[u][i] \times f[x][i - 1]$
以上就是动态规划的思路,直接做的话时间复杂度为$O(n^2)$。
长链剖分优化
考虑$x$从第一个遍历的儿子那里转移时:
$g[x][i]= g[u][i + 1]$
$f[x][i] = f[i][i - 1]$
可以发现,转移时只有下标偏移。那么我们可不可以直接继承儿子的数组,再移动一下下标,这样就可以实现$O(1)$的转移。对于其它节点再考虑暴力转移。
那么就可以用启发式合并的方式实现快速转移,剩下的问题就是要选好第一个访问的节点,使每个点唯一一次的$O(1)$转移尽可能多地继承信息。很容易想到,对于一个节点,其数组的长度为它的深度,那么每次选择深度最大的那个节点优先访问,长脸剖分的思想就应运而生了。
这么做的时间复杂度为多少呢?我们将一个节点深度最大的儿子节点作为重儿子,其它儿子节点都会成为另一条长链的首端,可以发现只有长链的首端需要暴力向上转移,每条长链转移的次数为长链的长度,所以暴力转移次数为所有长链长度之和,时间复杂度为$O(n)$。
代码实现
因为涉及到坐标的偏移,所以比较好写的方法是采用指针的写法。开一个长度较大的数组,用指针表示每个节点的f,g指向数组的位置。
/*************************************************************************
> File Name: 1.cpp
> Author: Knowledge_llz
> Mail: 925538513@qq.com
> Blog: https://www.cnblogs.com/Knowledge-Pig/
************************************************************************/
#include<bits/stdc++.h>
#define LL long long
#define endl '\n'
using namespace std;
const int maxx = 2e5 + 10;
int n, dep[maxx], son[maxx];
LL tmp[maxx << 2], *f[maxx], *g[maxx], *pos, ans = 0;
vector<int> eg[maxx];
void dfs(int id, int fa){
for(auto u : eg[id]){
if(u == fa) continue;
dfs(u, id);
if(dep[u] >= dep[son[id]]) son[id] = u;
}
dep[id] = dep[son[id]] + 1;
}
void solve(int id, int fa){
if(son[id]){
f[son[id]] = f[id] + 1;
g[son[id]] = g[id] - 1;
solve(son[id], id);
ans += g[id][0];
}
f[id][0] = 1;
for(auto u : eg[id]){
if(u == fa || u == son[id]) continue;
f[u] = pos;
pos += (dep[u] << 1);
g[u] = pos;
pos += (dep[u] << 1);
solve(u, id);
for(int i = 1; i <= dep[u]; ++i){
ans += g[id][i] * f[u][i - 1];
if(i <= dep[u]) ans += g[u][i] * f[id][i - 1];
}
for(int i = 0; i <= dep[u]; ++i){
if(i < dep[u]) g[id][i] += g[u][i + 1];
if(i){
g[id][i] += f[id][i] * f[u][i - 1];
f[id][i] += f[u][i - 1];
}
}
}
}
int main(){
ios::sync_with_stdio(false); cin.tie(0);
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt","w", stdout);
#endif
int T; cin >> T;
while(T--){
cin >> n;
for(int i = 1; i <= n; ++i){
eg[i].clear();
ans = dep[i] = son[i] = 0;
}
for(int i = 0; i <= (n + 1 << 2); ++i) tmp[i] = 0;
for(int i = 1, u, v; i < n; ++i){
cin >> u >> v;
eg[u].push_back(v);
eg[v].push_back(u);
}
dfs(1, 0);
pos = tmp;
f[1] = pos;
pos = pos + (dep[1] << 1);
g[1] = pos;
pos = pos + (dep[1] << 1);
solve(1, 0);
cout << ans << endl;
}
return 0;
}
上面指针是在网上找到的一份写法比较精简的代码。因为我一开始不熟悉指针,自己开了一份手动控制下标变换的代码,但是细节过多,实现较麻烦:
/*************************************************************************
> File Name: 1.cpp
> Author: Knowledge_llz
> Mail: 925538513@qq.com
> Blog: https://www.cnblogs.com/Knowledge-Pig/
************************************************************************/
#include<bits/stdc++.h>
#define LL long long
#define endl '\n'
using namespace std;
const int maxx = 2e5 + 10;
int n, dep[maxx], son[maxx];
vector<int> eg[maxx];
void dfs(int id, int fa){
for(auto u : eg[id]){
if(u == fa) continue;
dfs(u, id);
if(dep[u] >= dep[son[id]]) son[id] = u;
}
if(son[id]) dep[id] = dep[son[id]] + 1;
}
int ID[maxx], be[maxx];
LL ans = 0;
vector<LL> f[maxx], g[maxx];
void solve(int x, int fa){
if(son[x]){
solve(son[x], x);
ID[x] = ID[son[x]];
f[ID[x]].push_back(1);
g[ID[x]].push_back(0);
g[ID[x]].push_back(0);
++be[ID[x]];
ans += g[ID[x]][be[ID[x]]];
}
else{
ID[x] = x;
f[x].push_back(1);
g[x].push_back(0);
be[x] = 0;
}
for(auto u : eg[x]){
if(u == fa || u == son[x]) continue;
solve(u, x);
int x_sz = f[ID[x]].size() - 1, u_sz = f[ID[u]].size() - 1;
for(int i = 1; i <= dep[u] + 1; ++i){
if(be[ID[u]] + i < g[ID[u]].size()) ans += g[ID[u]][be[ID[u]] + i] * f[ID[x]][x_sz - (i - 1)];
if(u_sz - (i - 1) < f[ID[u]].size()) ans += g[ID[x]][be[ID[x]] + i] * f[ID[u]][u_sz - (i - 1)];
}
for(int i = 0; i <= dep[u] + 1; ++i){
if(be[ID[u]] + i + 1 < g[ID[u]].size()) g[ID[x]][be[ID[x]] + i] += g[ID[u]][be[ID[u]] + i + 1];
if(i > 0 && u_sz - i + 1 < f[ID[u]].size()){
g[ID[x]][be[ID[x]] + i] += f[ID[x]][x_sz - i] * f[ID[u]][u_sz - i + 1];
}
}
for(int i = u_sz, j = x_sz - 1; i >= 0; --i, --j) f[ID[x]][j] += f[ID[u]][i];
}
}
int main(){
ios::sync_with_stdio(false); cin.tie(0);
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt","w", stdout);
#endif
int T; cin >> T;
while(T--){
cin >> n;
for(int i = 1; i <= n; ++i){
eg[i].clear();
f[i].clear();
g[i].clear();
ans = be[i] = dep[i] = son[i] = 0;
}
for(int i = 1; i < n; ++i){
int u, v;
cin >> u >> v;
eg[u].push_back(v);
eg[v].push_back(u);
}
dfs(1, 0);
solve(1, 0);
cout << ans << endl;
}
return 0;
}