【HDOJ】4358 Boring counting

时间:2021-05-21 23:42:11

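基本思路是利用DFS序将树形结构转成线性结构,因为每个查询针对的是以某个结点为根的整棵子树(即从该结点一直到其下方所有叶子结点)。
从而将每个查询转换成DFS序上的一个区间[Beg[u], End[u]],表示该结点的子树。
离线做,将查询按照右边界升序排序。
利用树状数组做区间修改、单点查询(差分)。
树状数组在位置l处的值表示:以l为左端点、当前扫描位置为右端点的区间内,恰好出现K次的权值的数量;每个权值出现过的位置利用pos(代码中的pvc)进行维护。
假设某个权值现有的出现次数sz >= K, 那么需要对相应的左端点区间进行修改:若sz > K,先撤销上一次为该权值加上的区间,再给新的区间加1。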

 /* 4358 */
#include <iostream>
#include <sstream>
#include <string>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
#include <deque>
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <ctime>
#include <cstring>
#include <climits>
#include <cctype>
#include <cassert>
#include <functional>
#include <iterator>
#include <iomanip>
using namespace std;
//#pragma comment(linker,"/STACK:102400000,1024000")

/*
 * HDOJ 4358 "Boring counting".
 *
 * For every query node u, report how many distinct weights occur exactly K
 * times inside the subtree of u.
 *
 * Approach:
 *   1. Compress the weights to ranks 1..cnt.
 *   2. Flatten the tree with a DFS so that the subtree of u becomes the
 *      contiguous interval [Beg[u], End[u]] of the DFS order.
 *   3. Answer the queries offline, sorted by right endpoint, while sweeping
 *      the DFS order from left to right.  A Fenwick tree over a difference
 *      array stores, for every left endpoint l, the number of weights that
 *      occur exactly K times in [l, current position].
 */

#define sti set<int>
#define stpii set<pair<int, int> >
#define mpii map<int,int>
#define vi vector<int>
#define pii pair<int,int>
#define vpii vector<pair<int,int> >
#define rep(i, a, n) for (int i=a;i<n;++i)
#define per(i, a, n) for (int i=n-1;i>=a;--i)
#define clr clear
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define all(x) (x).begin(),(x).end()
#define SZ(x) ((int)(x).size())
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1

/* One directed half of an undirected edge, stored in a linked adjacency list. */
typedef struct {
    int v;      /* target node */
    int nxt;    /* index of the next edge leaving the same node, -1 at the end */
} edge_t;

/* (weight, node id) pair; sorted to compress the weights into ranks. */
typedef struct node_t {
    int w, id;

    friend bool operator< (const node_t& a, const node_t& b) {
        if (a.w == b.w)
            return a.id < b.id;
        return a.w < b.w;
    }
} node_t;

/* Offline query: the DFS-order interval [l, r] of a subtree, and its input index. */
typedef struct ques_t {
    int l, r, id;
} ques_t;

/* NOTE(review): sized for n <= 1e5 plus slack for 1-based indexing -- confirm
 * against the problem limits. */
const int maxn = 100000 + 5;
const int maxv = maxn;
const int maxe = maxv * 2;      /* addEdge stores two directed entries per edge */

int head[maxv], l;              /* adjacency list heads; l = number of edges stored */
edge_t E[maxe];
int Beg[maxn], End[maxn];       /* DFS-order interval covered by each subtree */
int val[maxn], W[maxn];         /* val[t] = weight rank at DFS time t; W[u] = weight of node u */
node_t nd[maxn];
vi pvc[maxn];                   /* per weight rank: DFS positions seen so far, pvc[x][0] = 0 sentinel */
int dfs_clock;
int n, K;
int a[maxn];                    /* Fenwick tree storage (1-based) */
ques_t Q[maxn];
int ans[maxn];

/* Orders queries by right endpoint, ties broken by left endpoint. */
bool compq (const ques_t& a, const ques_t& b) {
    if (a.r == b.r)
        return a.l < b.l;
    return a.r < b.r;
}

/* Resets the per-test-case state: empty graph, empty Fenwick tree, clock at 0. */
void init() {
    memset(head, -1, sizeof(head));
    memset(a, 0, sizeof(a));
    dfs_clock = l = 0;
}

/* Adds the undirected edge u - v as two directed adjacency-list entries. */
void addEdge(int u, int v) {
    E[l].v = v;
    E[l].nxt = head[u];
    head[u] = l++;

    E[l].v = u;
    E[l].nxt = head[v];
    head[v] = l++;
}

/*
 * Assigns DFS entry/exit times to the tree containing u (whose parent is fa,
 * 0 for the root) and records the weight seen at each DFS time in val[].
 *
 * Implemented with an explicit stack so that a very deep tree (e.g. a chain)
 * cannot overflow the call stack; children are visited in adjacency-list
 * order, exactly as a recursive traversal would.
 */
void dfs(int u, int fa) {
    vi parent(n + 1), nextEdge(n + 1), path;

    parent[u] = fa;
    nextEdge[u] = head[u];
    Beg[u] = ++dfs_clock;
    val[dfs_clock] = W[u];
    path.pb(u);

    while (!path.empty()) {
        const int cur = path.back();
        const int k = nextEdge[cur];

        if (k == -1) {
            /* every child has been handled: close the interval of cur */
            End[cur] = dfs_clock;
            path.pop_back();
            continue;
        }

        nextEdge[cur] = E[k].nxt;
        const int v = E[k].v;
        if (v == parent[cur])
            continue;

        parent[v] = cur;
        nextEdge[v] = head[v];
        Beg[v] = ++dfs_clock;
        val[dfs_clock] = W[v];
        path.pb(v);
    }
}

/* Lowest set bit of x. */
int lowest(int x) {
    return x & -x;
}

/* Prefix sum a[1..x] of the Fenwick tree. */
int sum(int x) {
    int ret = 0;

    while (x) {
        ret += a[x];
        x -= lowest(x);
    }

    return ret;
}

/* Adds delta at position x; positions greater than n are silently ignored. */
void update(int x, int delta) {
    while (x <= n) {
        a[x] += delta;
        x += lowest(x);
    }
}

/*
 * Reads the queries of the current test case and prints one answer per line.
 * Expects n, K, W[1..n], nd[1..n] and the edges to be set up by the caller.
 * Assumes K >= 1.
 */
void solve() {
    int q;
    int u;

    /* Compress the weights: equal weights receive the same rank in 1..cnt. */
    sort(nd + 1, nd + 1 + n);
    int cnt = 1;
    W[nd[1].id] = cnt;
    for (int i = 2; i <= n; ++i) {
        if (nd[i].w == nd[i-1].w) {
            W[nd[i].id] = cnt;
        } else {
            W[nd[i].id] = ++cnt;
        }
    }

    dfs(1, 0);

    /* Position 0 acts as a sentinel "occurrence before the first one". */
    for (int i = 0; i <= cnt; ++i) {
        pvc[i].clr();
        pvc[i].pb(0);
    }

    scanf("%d", &q);
    for (int i = 0; i < q; ++i) {
        scanf("%d", &u);
        Q[i].l = Beg[u];
        Q[i].r = End[u];
        Q[i].id = i;
    }

    sort(Q, Q + q, compq);
    int j = 0;

    for (int i = 1; i <= n; ++i) {
        vi& occ = pvc[val[i]];
        occ.pb(i);
        const int sz = SZ(occ) - 1;     /* occurrences of this weight in [1, i] */

        if (sz >= K) {
            if (sz > K) {
                /* Undo the +1 that the previous occurrence put on the
                 * left endpoints (occ[sz-K-1], occ[sz-K]]. */
                update(occ[sz-K-1] + 1, -1);
                update(occ[sz-K] + 1, 1);
            }
            /* Left endpoints in (occ[sz-K], occ[sz-K+1]] now contain this
             * weight exactly K times. */
            update(occ[sz-K] + 1, 1);
            update(occ[sz-K+1] + 1, -1);
        }

        /* Every query ending at i can be answered now. */
        while (j < q && Q[j].r == i) {
            ans[Q[j].id] = sum(Q[j].l);
            ++j;
        }
    }

    for (int i = 0; i < q; ++i)
        printf("%d\n", ans[i]);
}

/* Reads all test cases from stdin and prints the answers, one blank line between cases. */
int main() {
    /* Only C stdio is used below, so no iostream synchronisation tweak is needed. */
    #ifndef ONLINE_JUDGE
        freopen("data.in", "r", stdin);
        freopen("data.out", "w", stdout);
    #endif

    int t;
    int u, v;

    scanf("%d", &t);
    for (int tt = 1; tt <= t; ++tt) {
        init();
        scanf("%d %d", &n, &K);
        for (int i = 1; i <= n; ++i) {
            scanf("%d", &W[i]);
            nd[i].id = i;
            nd[i].w = W[i];
        }
        for (int i = 1; i < n; ++i) {
            scanf("%d %d", &u, &v);
            addEdge(u, v);
        }
        printf("Case #%d:\n", tt);
        solve();
        if (tt != t)
            putchar('\n');
    }

    #ifndef ONLINE_JUDGE
        printf("time = %d.\n", (int)clock());
    #endif

    return 0;
}

数据发生器。

 from copy import deepcopy
from random import randint, shuffle
import shutil
import string


def GenDataIn(path="data.in"):
    """Write a random input file for HDOJ 4358 to ``path``.

    The file contains 10 test cases.  Each case consists of:
      * a line ``n K`` with 100 <= n <= 200 and 1 <= K <= 5,
      * a line with n node weights in [1, 100],
      * n - 1 lines ``u v`` describing a random tree rooted at node 1
        (u is already part of the tree, v is newly attached),
      * a line with the query count q = n,
      * q lines holding a random permutation of the nodes 1..n.

    Works on both Python 2 and Python 3.
    """
    with open(path, "w") as fout:
        t = 10
        fout.write("%d\n" % (t))
        for _ in range(t):
            n = randint(100, 200)
            K = randint(1, 5)
            fout.write("%d %d\n" % (n, K))

            # ust: nodes already in the tree; vst: nodes still to be attached.
            # list(...) keeps .remove() working when range() is lazy (Python 3).
            ust = [1]
            vst = list(range(2, n + 1))

            weights = [randint(1, 100) for _ in range(n)]
            fout.write(" ".join(map(str, weights)) + "\n")

            for _ in range(1, n):
                u = ust[randint(0, len(ust) - 1)]
                v = vst[randint(0, len(vst) - 1)]
                ust.append(v)
                vst.remove(v)
                fout.write("%d %d\n" % (u, v))

            q = n
            fout.write("%d\n" % (q))
            queries = list(range(1, n + 1))
            shuffle(queries)
            fout.write("\n".join(map(str, queries)) + "\n")


def MovDataIn(srcFileName="data.in",
              desFileName=r"F:\eclipse_prj\workspace\hdoj\data.in"):
    """Copy the generated input file into the solution project's directory.

    The default destination is the original author's local workspace; the raw
    string keeps the backslashes literal instead of relying on unrecognised
    escape sequences.
    """
    shutil.copyfile(srcFileName, desFileName)


if __name__ == "__main__":
    GenDataIn()
    MovDataIn()