bzoj4381 [POI2015]Odwiedziny
给定一棵带点权的树,每次询问在 \(u\) 到 \(v\) 的路径上,每次走 \(k\) 步,如果最后不足 \(k\) 步就走到了 \(v\) ,则会一步走到 \(v\) ,求每次行走的经过的点的点权和
\(n\leq5\times10^4,\ a_i\leq10^4\)
根号分治
考虑根号分治,如果 \(k>\sqrt n\) ,每次暴力枚举走到的节点,反之,预处理所有点每次走 \(i(i\leq\sqrt n)\) 步直到超过根的深度所经过的点权和(与 \(i\) 到根的路径不同),询问时计算贡献。
对于 \(k>\sqrt n\) 的询问,需特判 \(u=v\) 与 \(v=lca\) 的情况,并且需要快速查询一个点的 \(k\) 级祖先,用倍增即可,时间复杂度 \(O(\sqrt n\log n)\)
对于 \(k\leq\sqrt n\) 的询问,同样要特判 \(u=v\) 与 \(v=lca\) 的情况,还需注意 \(lca\to v\) 的路径的贡献。预处理时间复杂度 \(O(n\sqrt n\log n)\) ,单次查询 \(O(\log n)\)
综上,时间复杂度 \(O(n\sqrt n\log n)\) ,貌似把倍增换成长链剖分可以做到 \(O(n\sqrt n)\) ?
代码(开 C++11)
#include <bits/stdc++.h>
using namespace std;const int maxn = 5e4 + 10;
int n, bsz, a[maxn], dep[maxn], fa[16][maxn], sum[225][maxn];vector <int> e[maxn];int findlca(int u, int v) {if (dep[u] < dep[v]) swap(u, v);for (int i = 15; ~i; i--) {if (dep[u] - (1 << i) >= dep[v]) {u = fa[i][u];}}if (u == v) return u;for (int i = 15; ~i; i--) {if (fa[i][u] != fa[i][v]) {u = fa[i][u], v = fa[i][v];}}return fa[0][u];
}int findanc(int u, int k) {for (int i = 0; i < 16; i++) {if (k >> i & 1) u = fa[i][u];}return u;
}void dfs1(int u, int f) {fa[0][u] = f;dep[u] = dep[f] + 1;for (int i = 1; i < 16; i++) {fa[i][u] = fa[i - 1][fa[i - 1][u]];}for (int v : e[u]) {if (v != f) dfs1(v, u);}
}void dfs2(int u, int f) {for (int i = 1; i <= bsz; i++) {sum[i][u] = sum[i][findanc(u, i)] + a[u];}for (int v : e[u]) {if (v != f) dfs2(v, u);}
}int query1(int u, int v, int k) {if (u == v) return a[u];int lca = findlca(u, v), res = 0;int delta = dep[u] - dep[lca];int anc = findanc(u, delta - delta % k);res = sum[k][u] - sum[k][anc] + a[anc], u = anc;if (v == lca) return res;int s = dep[u] + dep[v] - dep[lca] - dep[lca];if (s > k && s % k) {res += a[v], v = findanc(v, s % k);}s = dep[u] + dep[v] - dep[lca] - dep[lca];int tmp = s > k ? findanc(v, s - k) : v;res += sum[k][v] - sum[k][tmp] + a[tmp];return res;
}int query2(int u, int v, int k) {if (u == v) return a[u];int lca = findlca(u, v), res = a[u] + a[v];while (1) {int anc = findanc(u, k);if (dep[anc] < dep[lca] || (v == lca && dep[anc] == dep[v])) {break;}res += a[anc], u = anc;}v = findanc(v, (dep[u] + dep[v] - dep[lca] - dep[lca]) % k);while (1) {int anc = findanc(v, k);if (dep[anc] <= dep[lca]) {break;}res += a[anc], v = anc;}return res;
}int main() {scanf("%d", &n);bsz = sqrt(n);for (int i = 1; i <= n; i++) {scanf("%d", a + i);}for (int i = 1; i < n; i++) {int u, v;scanf("%d %d", &u, &v);e[u].push_back(v), e[v].push_back(u);}dfs1(1, 0), dfs2(1, 0);static int step[maxn];for (int i = 1; i <= n; i++) {scanf("%d", step + i);}for (int i = 2; i <= n; i++) {int u = step[i - 1], v = step[i], k;scanf("%d", &k);printf("%d\n", k <= bsz ? query1(u, v, k) : query2(u, v, k));}return 0;
}