题意:
给出一棵树,每个顶点上有个\(2 \times 2\)的矩阵,矩阵有两种操作:
- 顺时针旋转90°,花费是2
- 将一种矩阵替换为另一种矩阵,花费是10
树上有一种操作,将一条路经上的所有矩阵都变为给出的矩阵,并输出最小花费。
分析:
矩阵可以分为两类共6种,一类是两个1相邻的矩阵共4种;一类是两个1在对角线的矩阵共2种。
同一类矩阵可以通过旋转操作得到,否则只能用替换。
事先计算好每种矩阵转换到另外一种矩阵的最少花费,然后树链剖分再用线段树维护就好了。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 20000 + 10;
const int maxnode = maxn * 4;
void read(int& x) {
x = 0;
char c = ' ';
while(c < '0' || c > '9') c = getchar();
while('0' <= c && c <= '9') {
x = x * 10 + c - '0';
c = getchar();
}
}
int n, tot;
vector<int> G[maxn];
int son[maxn], sz[maxn], top[maxn], dep[maxn], fa[maxn];
int id[maxn], pos[maxn];
void dfs(int u) {
sz[u] = 1; son[u] = 0;
for(int v : G[u]) {
if(v == fa[u]) continue;
fa[v] = u;
dep[v] = dep[u] + 1;
dfs(v);
sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int tp) {
id[u] = ++tot;
pos[tot] = u;
top[u] = tp;
if(son[u]) dfs2(son[u], tp);
for(int v : G[u]) {
if(v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
int cost[6][6];
int cntv[maxnode][6], setv[maxnode];
int readMat() {
int a[4];
read(a[0]); read(a[1]); read(a[3]); read(a[2]);
for(int i = 0; i < 4; i++)
if(a[i] == 1 && a[(i+1)%4] == 1) return i;
if(a[0] == 1) return 4;
return 5;
}
int h[maxn];
void pushdown(int o, int L, int R) {
if(setv[o] != -1) {
int M = (L + R) / 2;
setv[o<<1] = setv[o<<1|1] = setv[o];
for(int i = 0; i < 6; i++)
cntv[o<<1][i] = cntv[o<<1|1][i] = 0;
cntv[o<<1][setv[o]] = M - L + 1;
cntv[o<<1|1][setv[o]] = R - M;
setv[o] = -1;
}
}
void pushup(int o) {
for(int i = 0; i < 6; i++)
cntv[o][i] = cntv[o<<1][i] + cntv[o<<1|1][i];
}
void build(int o, int L, int R) {
if(L == R) { cntv[o][h[pos[L]]] = 1; return; }
int M = (L + R) / 2;
build(o<<1, L, M);
build(o<<1|1, M+1, R);
pushup(o);
}
int q[6];
void update(int o, int L, int R, int qL, int qR, int v) {
if(qL <= L && R <= qR) {
setv[o] = v;
for(int i = 0; i < 6; i++) q[i] += cntv[o][i];
for(int i = 0; i < 6; i++) cntv[o][i] = 0;
cntv[o][v] = R - L + 1;
return;
}
pushdown(o, L, R);
int M = (L + R) / 2;
if(qL <= M) update(o<<1, L, M, qL, qR, v);
if(qR > M) update(o<<1|1, M+1, R, qL, qR, v);
pushup(o);
}
void UPDATE(int u, int v, int val) {
memset(q, 0, sizeof(q));
int t1 = top[u], t2 = top[v];
while(t1 != t2) {
if(dep[t1] < dep[t2]) { swap(u, v); swap(t1, t2); }
update(1, 1, n, id[t1], id[u], val);
u = fa[t1]; t1 = top[u];
}
if(dep[u] > dep[v]) swap(u, v);
update(1, 1, tot, id[u], id[v], val);
}
int main()
{
for(int i = 0; i < 6; i++)
for(int j = 0; j < 6; j++) {
if(i == j) { cost[i][j] = 0; continue; }
int a = ((i >> 2) & 1), b = ((j >> 2) & 1);
if(a ^ b) cost[i][j] = 10;
else if(!a) cost[i][j] = ((((j - i) % 4) + 4) % 4) * 2;
else cost[i][j] = 2;
}
int kase; scanf("%d", &kase);
while(kase--) {
read(n);
for(int i = 1; i <= n; i++) G[i].clear();
for(int i = 1; i < n; i++) {
int u, v; read(u); read(v);
G[u].push_back(v);
G[v].push_back(u);
}
sz[0] = fa[1] = 0;
dfs(1);
tot = 0;
dfs2(1, 1);
//build segment tree
memset(cntv, 0, sizeof(cntv));
memset(setv, -1, sizeof(setv));
for(int i = 1; i <= n; i++) h[i] = readMat();
build(1, 1, n);
//Queries
int _; scanf("%d", &_);
while(_--) {
int u, v, val; read(u); read(v);
val = readMat();
UPDATE(u, v, val);
int ans = 0;
for(int i = 0; i < 6; i++) ans += q[i] * cost[i][val];
printf("%d\n", ans);
}
}
return 0;
}