BZOJ3091 城市旅行

Description

给一颗以 $1$ 为根的有根树,维护以下操作

  1. 连接 $(u,v)$ 这条边
  2. 删除 $(u,v)$ 这条边
  3. 给 $u$ 到 $v$ 的链上每个点加上一个数
  4. 求在 $(u,v)$ 上任意选两个点它们之间的权值和的期望

$n, m \leq 50000, a_i \leq 10^6$

Solution

前三个操作就是 LCT 板子,考虑如何在 LCT 上维护 4 操作

为了方便,设这个路径是 $a_1, a_2, a_3, \cdots, a_{siz}$ ,其中 $siz$ 是长度

考虑每个点的贡献,易得我们要求的期望值 $=\frac{\sum\limits_{i=1}^{siz} i (siz - i + 1)a_i}{\frac{siz(siz+1)}{2}}$

显然这个分母很好搞,只需要考虑怎么在 LCT 上维护分子,或者说在平衡树上。

也就是说,如果知道左子和右子的答案如何更新出这个点的答案

设左子表示 $a_1, a_2, \cdots, a_p$, 该点的值是 $a_{p+1}$ ,右子表示 $a_{p+2}, \cdots, a_{siz}$

可以得到:左子的 $siz_0 = p$,右子的 $siz_1 = siz - p - 1$

改点要的答案减去左子的答案减去右子的答案便是

$\sum\limits_{i=1}^{siz}i(siz - i + 1)a_i - \sum\limits_{i=1}^{p}i(p-i+1)a_i-\sum\limits_{i=p+2}^{siz} (i-p-1)(siz - i + 1)a_i$

$=\sum\limits_{i=1}^{p} i(siz-p)a_i+a_{p+1}(p+1)(siz-p)+\sum\limits_{i=p+2}^{siz}(p+1)(siz-i+1)a_i$

根据上面得到的 $siz_0=p,siz_1=siz-p-1$ 简单化简一下可以得到

$=(siz_1+1)\sum\limits_{i=1}^{siz_0}i\cdot a_i+a_{siz_0+1}(siz_0+1)(siz_1+1)+(siz_0+1)\sum\limits_{i=p+2}^{siz}(siz - i +1)a_i$

到这里应该你已经知道怎么做了..

为了清楚,再令

$b_1, b_2, \cdots,b_{siz_b}$ 是左子的, $c_1, c_2, \cdots,c_{siz_c}$ 是右子的,$d$ 是这个点本身的值。那么可以化简成简单清楚对称的形式:

$=(siz_c+1)\sum\limits_{i=1}^{siz_b}i\cdot b_i+d(siz_b+1)(siz_c+1)+(siz_b+1)\sum\limits_{i=1}^{siz_c}(siz_c-i+1)c_i$

你只需要每个点再维护两个值:

$ls=\sum\limits_{i=1}^{siz}i\cdot a_i$ 和 $rs=\sum\limits_{i=1}^{siz}(siz - i +1)a_i$

就可以从左右两个儿子得到自己的值

这两个东西维护还是比较简单的..具体的话就是再维护一个 $s$ 为子树里所有数的和然后令 $b,c$ 是左右两个儿子,那么有

$ls = ls_b+d\cdot(siz_b+1)+ls_c+s_c (siz_b+1)$

$rs=rs_c+d\cdot(siz_c+1)+rs_b+s_b(siz_c+1)$

就这样维护

以上是如何用左右儿子的信息得到自己,再来考虑链加的问题

一条链加上一个数 $x$ ,那么会如何影响我们维护的值?

  • 对于 $s$:$s = s + siz\cdot x$
  • 对于 $ls$:$ls = ls + \sum\limits_{i=1}^{siz}i \cdot x = ls + \frac{siz(siz+1)}{2}\cdot x$
  • 对于 $rs$:和 ls 一样 $rs = rs+\frac{siz(siz+1)}{2}\cdot x$
  • 对于最后的答案 $S$:$S = S + \sum\limits_{i=1}^{siz} i \cdot (siz - i +1)\cdot x$ 通过简单计算可得 $S= S+\frac{siz(siz+1)(siz+2)}{6}\cdot x$
  • 对于自己的值:直接加上 $x$ (废话)

然后 LCT 板子套一套就做完了

注意事项:

  • 翻转的时候需要 swap(ls, rs)
  • 两个点之间是联通的时候才执行链加操作(坑死我了)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/**
* Author: AcFunction
* Date: 2019-02-17 11:17:08
* Email: 3486942970@qq.com
**/

#include <bits/stdc++.h>
#define ll long long

using namespace std;

const int N = 200200;
const ll INF = (ll)1e18;

int n, m;
ll a[N];
struct node {

int rev;
ll d, s, ls, rs, s1, add, siz;
node *ch[2], *prt;

int isr() {
return (!prt) || ( prt->ch[0] != (this) && prt->ch[1] != (this) );
}

int dir() {
return prt->ch[1] == (this);
}
void setc(node *p, int k) {
(this)->ch[k] = p;
if(p) p->prt = (this);
}

void setr() {
rev ^= 1;
swap(ls, rs);
swap(ch[0], ch[1]);
}

void seta(ll x) {
d += x, add += x; s += siz * x;
ls += siz * (siz + 1) / 2 * x;
rs += siz * (siz + 1) / 2 * x;
s1 += siz * (siz + 1) * (siz + 2) / 6 * x;
}

void upd() {
siz = 1, s = d;
if(ch[0]) siz += ch[0]->siz, s += ch[0]->s;
if(ch[1]) siz += ch[1]->siz, s += ch[1]->s;
if(ch[0] && ch[1]) {
ls = ch[0]->ls + d * (ch[0]->siz + 1) + ch[1]->ls + ch[1]->s * (ch[0]->siz + 1);
rs = ch[1]->rs + d * (ch[1]->siz + 1) + ch[0]->rs + ch[0]->s * (ch[1]->siz + 1);
s1 = ch[0]->s1 + ch[1]->s1;
s1 += ch[0]->ls * (ch[1]->siz + 1);
s1 += ch[1]->rs * (ch[0]->siz + 1);
s1 += d * (ch[0]->siz + 1) * (ch[1]->siz + 1);
} else if(ch[0]) {
ls = ch[0]->ls + d * (ch[0]->siz + 1);
rs = d + ch[0]->rs + ch[0]->s;
s1 = ch[0]->s1 + ch[0]->ls + d * (ch[0]->siz + 1);
} else if(ch[1]) {
ls = d + ch[1]->ls + ch[1]->s;
rs = d * (ch[1]->siz + 1) + ch[1]->rs;
s1 = ch[1]->s1 + ch[1]->rs + d * (ch[1]->siz + 1);
} else {
ls = rs = s1 = d;
}
}

void push() {
if(rev) {
if(ch[0]) ch[0]->setr();
if(ch[1]) ch[1]->setr();
rev = 0;
}
if(add) {
if(ch[0]) ch[0]->seta(add);
if(ch[1]) ch[1]->seta(add);
add = 0;
}
}

} pool[N * 2], *P[N], *cur = pool;

node *New(ll d) {
node *p = cur++;
p->d = d, p->ls = p->rs = d;
p->s = p->s1 = d;
p->prt = p->ch[0] = p->ch[1] = 0;
p->siz = 1;
return p;
}

void rotate(node *p) {
node *prt = p->prt; int k = p->dir();
if(!prt->isr()) prt->prt->setc(p, prt->dir());
else p->prt = prt->prt; prt->setc(p->ch[!k], k);
p->setc(prt, !k); prt->upd(); p->upd();
}

node *sta[N]; int top;
void splay(node *p) {
node *q = p;
while(1) {
sta[++top] = q;
if(q->isr()) break ;
q = q->prt;
}
while(top)
(sta[top--])->push();
while(!p->isr()) {
if(p->prt->isr()) rotate(p);
else if(p->dir() == p->prt->dir()) rotate(p->prt), rotate(p);
else rotate(p), rotate(p);
} p->upd();
}


node *access(node *p) {
node *q = 0;
for(; p; p = p->prt) {
splay(p);
p->ch[1] = q;
(q = p)->upd();
} return q;
}

inline void mkroot(node *p) { access(p); splay(p); p->setr(); p->push(); }
inline void split (node *p, node *q) { mkroot(p); access(q); splay(p); }
inline void link (node *p, node *q) { mkroot(p); mkroot(q); q->prt = p; }
inline void cut (node *p, node *q) { split(p, q); p->ch[1] = q->prt = 0; }
inline node *find(node *p) { access(p); splay(p); while(p->ch[0]) p = p->ch[0]; return p; }

inline ll gcd(ll a, ll b) {
return !b ? a : gcd(b, a % b);
}

int main() {
scanf("%d %d", &n, &m);
for(int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
P[i] = New(a[i]);
}
for(int i = 1; i < n; i++) {
int u, v; scanf("%d %d", &u, &v);
link(P[u], P[v]);
}
for(int i = 1; i <= m; i++) {
int op, u, v; ll d;
scanf("%d %d %d", &op, &u, &v);
if(op == 1) if(find(P[u]) == find(P[v])) cut(P[u], P[v]);
if(op == 2) if(find(P[u]) != find(P[v])) link(P[u], P[v]);
if(op == 3) {
scanf("%lld", &d);
if(find(P[u]) != find(P[v])) continue ; // important!
split(P[u], P[v]), P[u]->seta(d);
}
if(op == 4) {
if(find(P[u]) != find(P[v])) {
printf("-1\n");
continue ;
}
split(P[u], P[v]);
ll ans = P[u]->s1;
ll t = P[u]->siz * (P[u]->siz + 1) / 2;
ll g = gcd(ans, t);
printf("%lld/%lld\n", ans / g, t / g);
}
}
return 0;
}