0%

EOJ-3681.中位数

这道题有点折磨。首先是思路不太简单,用 ChatGPT 的话来说,就是有着“三层‘逼迫’”——当然,对于大佬来说可能也是常规了。然后是 Python 的问题,一直超内存,最后是 Gemini 优化后终于过了。因此在这里记录一下。

题目

题目链接:https://acm.ecnu.edu.cn/problem/3681/

思路

先给出一个判断一个数与一个数组的中位数之间大小关系的方法:

假设一个数为 $X$,数组的中位数是 $M$,将数组中大于等于 $X$ 的数看作 $1$ 、小于 $X$ 的数看作 $-1$,则如果整个数组的和小于 $0$ ,则 $X \gt M$,而如果和大于等于 $0$,说明 $X \le M$。

因为这道题里中位数是排序后的第 $\lfloor \frac{n}{2} \rfloor + 1$ 个,所以当 $X == M$ 的时候,和大于或者等于 $0$ 。

对于这道题,直接求非常难,转换为判定性问题,即假定结果是 $X$,是否存在一条从 $1$ 到 $n$ 的路径中位数 $M >= X$ 。

怎样判断存在?这里要用 拓扑排序动态规划 了。首先用前面的方法假设一个数 $X$ 进行权值转换,定义 $dp[i]$ 为结点 $1$ 到 $i$ 的最大权值和(从 $1$ 到 $i$ 会有好多路径,每个路径都会有一个权值和,取最大的那个作为 $dp[i]$ ),若 $dp[n] >= 0$,说明存在一条 $1$ 到 $n$ 的路径中位数 $M \ge X$ ,否则说明 $M \lt X$ 。

那么,对于是否存在一条从 $1$ 到 $n$ 的路径中位数 $M >= X$ ,

  • 若存在,则小于 $X$ 的数也存在这样的路径。不用找,就这一条就行。为了找到答案,变大 $X$ 。
  • 若不存在,则大于 $X$ 的数也不存在这样的路径。因为比 $X$ 越大,路径上的 $-1$ 也就越多,路径和就越不容易大于等于 $0$ 。为了找到答案,减小 $X$ 。

动态规划的转移方式:$dp[i] = max(dp[i], dp[j] + w)$,其中 $w$ 是结点 $i$ 转换后的权值,结点 $j$ 有边指向 $i$ ,即每次更新所有 $j$ 的后继。

初始值:$dp[i] = 转换后的权值$ ,其余都是 $-inf$。

这里有几个问题:

  1. 拓扑排序的作用是什么?

    按拓扑序进行动态规划,保证每个点向后传递计算后继结点 $i$ 的 $dp[i]$ 的时候,该点值已经算完不会再被更新。

  2. 动态规划的转移为什么是取 最大

    虽然说题目要算的是 最大中位数 ,但我觉得这和动态规划里要取最大值没有太大关系。主要是因为我们要算 存在性 ,而 $dp[n]$ 越大越能保证存在,因此要取 $max$ 。

代码

C++

我的版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/*
转换为判定性问题:
假设中位数为 M,是否存在一条路径中位数 >= M。
若存在,小于 M 的也能存在。
如何判断是否存在一条路径中位数 >= M?
将 >= M 的点权置为 1,< M 的置为 -1
用动态规划的思想,dp[v] = max(dp[v], dp[u] + w[v]) (u 指向 v)
若 dp[n] >= 0,则存在。
*/

#include <bits/stdc++.h>
using namespace std;

const int NEG_INF = -0x3f3f3f3f;
int n, m;
vector<int> a; // 点权
vector<vector<int>> g; // 图
vector<int> topo; // 拓扑序列
vector<int> indeg;// 入度
vector<int> dp; // check 中的 dp 数组
vector<int> w; // check 中的权值数组

bool check(int m){
// 权值转换
for(int i = 1; i <= n; i++) w[i] = a[i] >= m ? 1 : -1;
// 初始化 dp
fill(dp.begin(), dp.end(), NEG_INF);
dp[1] = w[1];

for(int u : topo){
// 最好有这个 continue
if(dp[u] == NEG_INF) continue;
for(int v : g[u]){
// dp 转移
if(dp[u] + w[v] > dp[v]){
dp[v] = dp[u] + w[v];
}
}
}
return dp[n] >= 0;
}

int main(){
// 输入和初始化
cin >> n >> m;
a.resize(n + 1);
indeg.assign(n + 1, 0);
g.resize(n + 1);
dp.resize(n + 1);
w.resize(n + 1);

for(int i = 1; i <= n; i++) cin >> a[i];
for(int i = 0; i < m; i++){
int u, v; cin >> u >> v;
g[u].emplace_back(v);
indeg[v] += 1;
}

// 拓扑排序
deque<int> q;
for(int i = 1; i <= n; i++){
if(indeg[i] == 0) q.emplace_back(i);
}
while(!q.empty()){
int u = q.front();
q.pop_front();
topo.emplace_back(u);
for(int v: g[u]){
indeg[v]--;
if(indeg[v] == 0)
q.emplace_back(v);
}
}

int left = *min_element(a.begin(), a.end());
int right = *max_element(a.begin(), a.end());
while (left <= right){
int mid = left + ((right - left) >> 1);
if (check(mid)) left = mid + 1;
else right = mid - 1;
}

cout << right;
return 0;
}

优化版

优化点(主要是前两个):

  1. 输入。数量级大了,提速明显。

    1
    2
    ios::sync_with_stdio(false);
    cin.tie(nullptr);`

    原因:

    • sync_with_stdio(false) 砍掉了 C++ 与 C 语言输入输出的 兼容同步,让 cin 跑得飞快。
    • cin.tie(nullptr) 解除了读入与输出的自动刷新绑定,避免了每读一个数就强行刷一次屏幕的性能浪费。
    • 两者配合使用能让 C++ 获得超越 scanf/printf 的极致读写速度,是处理大数据量的标配
  1. 上面的版本是在权值的值域上二分,即二分的范围是权值的 $[\text{最大值}, \text最小值]$ ,优化成在去重、排序后的权值数组上索引二分,减少 check 次数。

  2. topo.reserve(n + 1); ,提前“预定” $n+1$ 个空间,但不先真开空间。

  3. 用 Lambda 函数,这样各种变量,都在 main 函数里,会好一点点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <bits/stdc++.h>
using namespace std;
const int NEG_INF = -0x3f3f3f3f;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);

int n, m; cin >> n >> m;
vector<int> a(n + 1);
vector<vector<int>> g(n + 1);
vector<int> indeg(n + 1);

for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 0; i < m; i++) {
int u, v;
cin >> u >> v;
g[u].emplace_back(v);
indeg[v] += 1;
}

// 拓扑排序
deque<int> q;
vector<int> topo;
topo.reserve(n + 1);
for (int i = 1; i <= n; i++) {
if (indeg[i] == 0) q.emplace_back(i);
}
while (!q.empty()) {
int u = q.front();
q.pop_front();
topo.emplace_back(u);
for (int v : g[u]) {
indeg[v]--;
if (indeg[v] == 0)
q.emplace_back(v);
}
}

vector<int> dp(n + 1);
vector<int> w(n + 1);
auto check = [&](int m) {
// 权值转换
for (int i = 1; i <= n; i++) w[i] = a[i] >= m ? 1 : -1;
// 初始化 dp
fill(dp.begin(), dp.end(), NEG_INF);
dp[1] = w[1];

for (int u : topo) {
// 最好有这个 continue
if (dp[u] == NEG_INF) continue;
for (int v : g[u]) {
// dp 转移
if (dp[u] + w[v] > dp[v]) {
dp[v] = dp[u] + w[v];
}
}
}
return dp[n] >= 0;
};

// 排序去重
vector<int> vals = a;
// +1 下标从 1 开始
sort(vals.begin() + 1, vals.end());
vals.erase(unique(vals.begin() + 1, vals.end()), vals.end());

int left = 1;
int right = vals.size() - 1;
while (left <= right) {
int mid = left + ((right - left) >> 1);
if (check(vals[mid])) left = mid + 1;
else right = mid - 1;
}

cout << (right > 0 ? vals[right] : -1);
return 0;
}

Python

超时但逻辑正确版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
在一个路径中,将大于等于中位数的记为 +1,小于中位数的记为 -1,那么路径上的总和一定是大于等于 0 的

给定一个 x,将大于等于 x 的记为 +1,小于 x 的记为 -1,
那么路径上的总和
大于 0,x 小于等于中位数
等于 0,x 就是中位数
小于 0,x 大于中位数

题目问的是,在所有 1 到 n 的路径中,中位数最大的是多少?
转换问题为,假设最大的中位数是 x,是否存在一条 1 到 n 的路径的中位数至少是 x?
用上面的方法计算最大的每条路径总和,如果这个总和 >= 0,说明 x <= 中位数,
即存在一条路径的中位数 >= x,为了找到答案要的「最大」的中位数,要增大 x。
"""

from collections import deque

NEG_INF = -10 ** 18


def check(x: int) -> bool:
dp = [NEG_INF] * (n + 1)
ver = [1 if v >= x else -1 for v in vertices]
dp[1] = 1 if vertices[1] >= x else -1
for u in topo_list:
if dp[u] == NEG_INF:
continue
for v in g[u]:
dp[v] = max(dp[v], dp[u] + ver[v])
return dp[n] >= 0


n, m = list(map(int, input().split())) # n 节点数,m 边数量
vertices = [NEG_INF] + list(map(int, input().split()))
g = [[] for _ in range(n + 1)]
in_deg = [0] * (n + 1)
in_deg[0] = NEG_INF
for _ in range(m):
x, y = list(map(int, input().split()))
g[x].append(y)
in_deg[y] += 1

# 拓扑排序 ################
topo_list = []
q = deque([i for i, d in enumerate(in_deg) if d == 0])
while q:
x = q.popleft()
topo_list.append(x)
for y in g[x]:
in_deg[y] -= 1
if in_deg[y] == 0:
q.append(y)

left, right = 0, max(vertices)
while left <= right:
mid = (left + right) // 2
if check(mid):
left = mid + 1
else:
right = mid - 1

print(right)

优化可通过版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys
from collections import deque
from array import array
NEG_INF = -10 ** 9


def get_next_int():
global pos
while pos < n_data and data[pos] <= 32:
pos += 1
if pos >= n_data:
return None
res = 0
while pos < n_data and data[pos] > 32:
res = res * 10 + (data[pos] - 48)
pos += 1
return res

data = sys.stdin.buffer.read()
n_data = len(data)
pos = 0
n, m = get_next_int(), get_next_int() # n 节点数,m 边数量
a = array('i', [0]) * (n + 1)
for i in range(1, n + 1):
a[i] = get_next_int()

g = [array('i') for _ in range(n + 1)]
in_deg = array('i', [0]) * (n + 1)
for _ in range(m):
x, y = get_next_int(), get_next_int()
g[x].append(y)
in_deg[y] += 1

# 拓扑排序 ################
topo_list = []
q = deque([i for i in range(1, n + 1) if in_deg[i] == 0])
while q:
x = q.popleft()
topo_list.append(x)
for y in g[x]:
in_deg[y] -= 1
if in_deg[y] == 0:
q.append(y)
dp = array('i', [NEG_INF]) * (n + 1)
dp_reset = array('i', [NEG_INF] ) * (n + 1) # 用于快速重置
def check(x: int) -> bool:
global dp
dp[:] = dp_reset
dp[1] = 1 if a[1] >= x else -1
for u in topo_list:
if dp[u] == NEG_INF:
continue
for v in g[u]:
w = 1 if a[v] >= x else -1
dp[v] = max(dp[v], dp[u] + w)
return dp[n] >= 0

left, right = 0, max(a)
while left <= right:
mid = (left + right) // 2
if check(mid):
left = mid + 1
else:
right = mid - 1

print(right)

优化点(具体的说明和其他 Python 的优化方法见这篇 ):

  1. 直接在原始 buffer 读取数据,避免 split() 产生临时的内存消耗;
  2. 采用 array.array

做到这两点,就能擦线过了,但还有可能超时,所以更极致的优化的策略有:

  1. 手写 max
  2. 将变量都装进 solve 函数中,利用访问局部变量来提速;
  3. 离散化二分,同 C++ 的二分优化策略;
  4. 变量本地化:在 check 函数里先做变量的引用;
  5. 数据结构优化,采用 链式前向星CSR 代替邻接表存储图(我还不会)。

这些都是理论上的优化策略,但有些可能作用有限,且组合在一起甚至可能使得已经过的代码因为一些玄学的原因反而会超时或超内存。

所以我觉得学会或者了解这些优化策略已经非常够了,再死磕意义不大了。

最后贴一下表现最优秀的 PyPy 代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import sys
from array import array

# CSR 构建逻辑已修正
def solve():
# 1. 快速字节流读取
try:
raw_data = sys.stdin.buffer.read()
except Exception:
return

if not raw_data:
return

# 迭代器解析整数
def get_ints():
curr = 0
n_data = len(raw_data)
while curr < n_data:
# 跳过空白字符
while curr < n_data and raw_data[curr] <= 32:
curr += 1
if curr >= n_data: break

# 解析数字 (包括负号,虽然题目说 Ai>=0 但通用性更好)
start = curr
if raw_data[curr] == 45: # '-'
curr += 1

while curr < n_data and 48 <= raw_data[curr] <= 57:
curr += 1

yield int(raw_data[start:curr])

gen = get_ints()

try:
n = next(gen)
m = next(gen)
except StopIteration:
return

# 2. 存储点权
a = array('i', [0] * (n + 1))
for i in range(1, n + 1):
a[i] = next(gen)

# 3. CSR 图构建 (修正版)
# 使用 edges_raw 暂存边信息
edges_raw = array('i', [0] * (m * 2))
indeg = array('i', [0] * (n + 1))

for i in range(0, 2 * m, 2):
u = next(gen)
v = next(gen)
edges_raw[i] = u
edges_raw[i+1] = v
indeg[v] += 1

# 清理不再需要的 raw_data
del raw_data

# 计算 Head 数组 (CSR 索引)
head = array('i', [0] * (n + 2))

# Step A: 统计每个点的出度
for i in range(0, 2 * m, 2):
u = edges_raw[i]
head[u] += 1

# Step B: 原地转换为起始索引 (Prefix Sum)
# head[i] 将存储节点 i 的第一条边在 edges 数组中的位置
current_offset = 0
for i in range(1, n + 2):
degree = head[i]
head[i] = current_offset
current_offset += degree

# 此时 head[n+1] 应该等于 m

# Step C: 填充 edges 数组
edges = array('i', [0] * m)
# 我们需要一个临时数组来跟踪填充进度,或者直接复用 head
# 为了节省内存,我们拷贝一份 head 作为 "当前写入位置" 指针
cur_head = array('i', head)

for i in range(0, 2 * m, 2):
u = edges_raw[i]
v = edges_raw[i+1]
pos = cur_head[u]
edges[pos] = v
cur_head[u] += 1

# 清理临时数组
del edges_raw
del cur_head

# 4. 拓扑排序
topo = array('i', [0] * n)
queue = array('i', [x for x in range(1, n + 1) if indeg[x] == 0])

# 使用指针模拟队列,避免 pop(0) 的开销
q_read_ptr = 0
t_ptr = 0

while q_read_ptr < len(queue):
u = queue[q_read_ptr]
q_read_ptr += 1

topo[t_ptr] = u
t_ptr += 1

# 遍历 u 的所有邻居
start_edge = head[u]
end_edge = head[u+1]

for i in range(start_edge, end_edge):
v = edges[i]
indeg[v] -= 1
if indeg[v] == 0:
queue.append(v)

del indeg
del queue

# 5. 二分查找与 DP
# 离散化权值
vals = sorted(list(set(a[1:])))
num_vals = len(vals)

dp = array('i', [0] * (n + 1))
weights = array('b', [0] * (n + 1)) # 使用 1 字节整型
INF = 10**9

def check(x):
# 快速重置 DP 数组为 -INF
# array 的切片赋值比循环快
# 但在 Pypy 下循环赋值通常优化得很好,为了省内存不创建新对象:
for i in range(n + 1):
dp[i] = -INF

# 预计算权重
for i in range(1, n + 1):
weights[i] = 1 if a[i] >= x else -1

# 起点 Check
dp[1] = weights[1]

# 局部变量缓存,减少属性查找开销
_head = head
_edges = edges
_dp = dp
_w = weights
_topo = topo

for u in _topo:
curr_val = _dp[u]
if curr_val <= -INF:
continue

# CSR 遍历
start = _head[u]
end = _head[u+1]

for i in range(start, end):
v = _edges[i]
# 转移方程:dp[v] = max(dp[v], dp[u] + w[v])
next_val = curr_val + _w[v]
if next_val > _dp[v]:
_dp[v] = next_val

return _dp[n] >= 0

ans = -1
low = 0
high = num_vals - 1

while low <= high:
mid = (low + high) // 2
val = vals[mid]
if check(val):
ans = val
low = mid + 1
else:
high = mid - 1

sys.stdout.write(str(ans) + '\n')

if __name__ == '__main__':
solve()

其他

关联一下 EOJ 2026. Telephone Lines 这道题。为什么关联在那边的最后有说。