
一道点分治（“淀粉质”）的模板题。一开始写的是暴力做法：
// Brute-force centroid decomposition: for each centroid, every ordered pair
// of collected distances is added into ans[] (+1 for pairs through the
// centroid's whole component, -1 for pairs inside one child subtree), so
// ans[k] != 0 afterwards iff some tree path has length exactly k.
// Each centroid level costs O(size^2).
#include <algorithm>
#include <cctype>
#include <cstdio>

// Reads a (possibly negative) integer from stdin, skipping non-digit chars.
// Stops cleanly at EOF instead of looping forever.
template <typename T> void in(T &x) {
    x = 0;
    T sign = 1;
    int ch = std::getchar();  // int, not char, so EOF is representable
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    x *= sign;
}

// NOTE(review): these bounds were lost from the source. They assume
// n <= 10^4 and queried k <= 10^7 (the usual limits of this template
// problem) — TODO confirm against the actual limits.
const int N = 10010;
const int MAXK = 10000000;

int n, m;

// Adjacency list stored as a linked list of edges (two per tree edge).
struct edge {
    int v, w, nxt;
} e[N << 1];
int tot, head[N];

void add(int u, int v, int w) {
    e[++tot].v = v;
    e[tot].w = w;
    e[tot].nxt = head[u];
    head[u] = tot;
}

// Tsize: size of the component being searched; f[u]: largest piece left
// if u were removed; rt: best centroid so far (f[0] acts as +infinity).
int Tsize, f[N], rt;
// Renamed from `size`: a global named `size` is ambiguous with std::size
// (C++17) when `using namespace std` is in effect.
int sz[N];
bool vis[N];  // vertices already used as centroids

// Computes subtree sizes under u and updates rt with the centroid of the
// current component.
void get_rt(int u, int fa) {
    sz[u] = 1;
    f[u] = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (v == fa || vis[v]) continue;
        get_rt(v, u);
        sz[u] += sz[v];
        f[u] = std::max(f[u], sz[v]);
    }
    f[u] = std::max(f[u], Tsize - sz[u]);
    if (f[u] < f[rt]) rt = u;
}

int dis[N], cdis[N];
int ans[MAXK + 1];  // signed pair counts, indexed by path length
int cnt;

// Collects dis[] of every vertex reachable from u without crossing a
// centroid into cdis[1..cnt].
void get_dis(int u, int fa) {
    cdis[++cnt] = dis[u];
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (v == fa || vis[v]) continue;
        dis[v] = dis[u] + e[i].w;
        get_dis(v, u);
    }
}

// Adds w to ans[d1 + d2] for every ordered pair of distinct collected
// vertices. Sums above MAXK can never be queried and would overflow ans[].
void calc(int u, int w) {
    cnt = 0;
    get_dis(u, 0);
    for (int i = 1; i <= cnt; ++i) {
        for (int j = 1; j <= cnt; ++j) {
            if (i == j) continue;
            int s = cdis[i] + cdis[j];
            if (s <= MAXK) ans[s] += w;
        }
    }
}

// Processes centroid u, then recurses into each remaining component.
void solve(int u) {
    int total = Tsize;  // size of u's component; Tsize is reused below
    vis[u] = true;
    dis[u] = 0;
    calc(u, 1);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (vis[v]) continue;
        // Remove pairs whose path stays inside v's subtree (they do not
        // actually pass through u).
        dis[v] = e[i].w;
        calc(v, -1);
        // sz[] is relative to the DFS root used when u was found. If v was
        // u's parent there, its component is everything except u's subtree.
        Tsize = sz[v] < sz[u] ? sz[v] : total - sz[u];
        rt = 0;
        get_rt(v, 0);
        solve(rt);
    }
}

int main() {
    in(n);
    in(m);
    for (int i = 1; i < n; ++i) {
        int x, y, w;
        in(x);
        in(y);
        in(w);
        add(x, y, w);
        add(y, x, w);
    }
    Tsize = n;
    f[0] = n + 1;
    rt = 0;
    get_rt(1, 0);
    solve(rt);
    for (int i = 1; i <= m; ++i) {
        int k;
        in(k);
        bool found = k >= 0 && k <= MAXK && ans[k] != 0;
        std::printf(found ? "AYE\n" : "NAY\n");
    }
    return 0;
}
然后我改用二分查找进行优化：
// Centroid decomposition with binary search: for each centroid, distances
// of every vertex in its component are collected (tagged with the child
// subtree they belong to), sorted once, and each still-unanswered query k
// is checked by binary-searching, for every distance d, the first distance
// >= k - d that lies in a different subtree.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>

// Reads a (possibly negative) integer from stdin, skipping non-digit chars.
// Stops cleanly at EOF instead of looping forever.
template <typename T> void in(T &x) {
    x = 0;
    T sign = 1;
    int ch = std::getchar();  // int, not char, so EOF is representable
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    x *= sign;
}

// NOTE(review): this bound was lost from the source; it assumes n <= 10^4
// (the usual limit of this template problem) — TODO confirm.
const int N = 10010;

int n, m;

// Adjacency list stored as a linked list of edges (two per tree edge).
struct edge {
    int v, w, nxt;
} e[N << 1];
int tot, head[N];

void add(int u, int v, int w) {
    e[++tot].v = v;
    e[tot].w = w;
    e[tot].nxt = head[u];
    head[u] = tot;
}

// Tsize: size of the component being searched; f[u]: largest piece left
// if u were removed; rt: best centroid so far (f[0] acts as +infinity).
int Tsize, f[N], rt;
// Renamed from `size`: a global named `size` is ambiguous with std::size
// (C++17) when `using namespace std` is in effect.
int sz[N];
bool vis[N];  // vertices already used as centroids

// Computes subtree sizes under u and updates rt with the centroid of the
// current component.
void get_rt(int u, int fa) {
    sz[u] = 1;
    f[u] = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (v == fa || vis[v]) continue;
        get_rt(v, u);
        sz[u] += sz[v];
        f[u] = std::max(f[u], sz[v]);
    }
    f[u] = std::max(f[u], Tsize - sz[u]);
    if (f[u] < f[rt]) rt = u;
}

int dis[N];
int cnt;
int belong;  // child of the centroid whose subtree is being collected

// A collected vertex: distance to the centroid and the child subtree
// (0 for the centroid itself) it came from. Ordered by distance only.
struct cur {
    int dis, bl;
    bool operator<(const cur &o) const { return dis < o.dis; }
} c[N];

// Appends every vertex reachable from u without crossing a centroid to
// c[1..cnt], tagged with the current `belong`.
void get_dis(int u, int fa) {
    c[++cnt] = cur{dis[u], belong};
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (v == fa || vis[v]) continue;
        dis[v] = dis[u] + e[i].w;
        get_dis(v, u);
    }
}

// Returns the index of the first c[1..cnt] with dis >= x, or 0 if none.
// c[1..cnt] must already be sorted by dis.
int binary(int x) {
    int idx = static_cast<int>(
        std::lower_bound(c + 1, c + cnt + 1, cur{x, 0}) - c);
    return idx <= cnt ? idx : 0;
}

std::vector<int> k;      // k[1..m]: queried path lengths
std::vector<char> test;  // test[i] != 0 once query i has been answered
int remaining;           // number of queries not yet answered

// Answers every still-open query whose path passes through centroid u.
void calc(int u) {
    cnt = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (vis[v]) continue;
        dis[v] = e[i].w;
        belong = v;
        get_dis(v, u);
    }
    c[++cnt] = cur{0, 0};  // the centroid itself also counts as an endpoint
    std::sort(c + 1, c + cnt + 1);
    for (int i = 1; i <= m; ++i) {
        if (test[i]) continue;
        for (int j = 1; j <= cnt; ++j) {
            int r = binary(k[i] - c[j].dis);
            // Checking only the first match is enough: if it shares j's
            // subtree, the symmetric check from the partner's side finds
            // the pair whenever a cross-subtree pair exists.
            if (r && c[j].bl != c[r].bl && c[j].dis + c[r].dis == k[i]) {
                test[i] = 1;
                if (--remaining == 0) return;
                break;
            }
        }
    }
}

// Processes centroid u, then recurses into each remaining component;
// stops as soon as every query has been answered.
void solve(int u) {
    if (remaining == 0) return;
    int total = Tsize;  // size of u's component; Tsize is reused below
    vis[u] = true;
    calc(u);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (vis[v]) continue;
        // sz[] is relative to the DFS root used when u was found. If v was
        // u's parent there, its component is everything except u's subtree.
        Tsize = sz[v] < sz[u] ? sz[v] : total - sz[u];
        rt = 0;
        get_rt(v, 0);
        solve(rt);
    }
}

int main() {
    in(n);
    in(m);
    for (int i = 1; i < n; ++i) {
        int x, y, w;
        in(x);
        in(y);
        in(w);
        add(x, y, w);
        add(y, x, w);
    }
    k.assign(m + 1, 0);
    test.assign(m + 1, 0);
    for (int i = 1; i <= m; ++i) in(k[i]);
    remaining = m;
    Tsize = n;
    f[0] = n + 1;
    rt = 0;
    get_rt(1, 0);
    solve(rt);
    for (int i = 1; i <= m; ++i) {
        std::printf(test[i] ? "AYE\n" : "NAY\n");
    }
    return 0;
}
虽然还可以更快，但我就不再继续写了。