0%

EOJ 3462. 最小OR路径

题目链接:https://acm.ecnu.edu.cn/problem/3462/

也是一道执着地优化 Python 的一道题。思路并不太难,试填法加连通性判断就行,连通性判断可以 BFS 、并查集等好多方法了。

主要记录一些优化 Python 的点,因为同样的思路,C++ 就能轻松过而 Python 就是不行。

输入

生成器不如 ptr 指针读。

生成器法代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def get_num():
data = sys.stdin.buffer.read()
num = 0
found = False
for c in data:
if 48 <= c <= 57:
num = num * 10 + c - 48
found = True
else:
yield num
num = 0
found = False
if found:
yield num
data = get_num()
n = next(data)
m = next(data)
...

这个方法,不如

1
2
3
4
data = sys.stdin.buffer.read().split()
n, m = int(data[0]), int(data[1])
ptr = 2
...

split() 会造成内存的消耗。如果没有超内存就好,超了的话,就用下面的方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def get_next_int():
global pos
while pos < n_data and data[pos] <= 32:
pos += 1
if pos >= n_data:
return None
res = 0
while pos < n_data and data[pos] > 32:
res = res * 10 + (data[pos] - 48)
pos += 1
return res

data = sys.stdin.buffer.read()
n_data = len(data)
pos = 0
n, m = get_next_int(), get_next_int()
...

时间戳法记 vis

用时间戳的方式来代替布尔类型的 vis 数组,避免了反复重开赋值的开销。

具体看后面的代码。

位运算判断逻辑

正向判断:

假设 $ans$ 第 $i$ 位是 $0$,要求路径上的权值的第 $i$ 位也是 $0$ ,如果能找到这样的路径从 $a$ 到 $b$ ,那么第 $i$ 位就是 $0$ ,否则填 $1$ 。并且,在试填低位的时候,之前的高位为 $0$ 的限制也必须保留。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
ans = 0
for i in range(62, -1, -1):
mask = ans >> i
# ...
while q:
u = q.popleft()
# if u == b...
for u, w in edges[u]:
if not vis[v] and ((w >> i) | mask) == mask:
# ...
if not flag: # 不连通,填 1
ans |= 1 << i

反向判断:

假设 $mask$ 第 $i$ 位是 $1$,专找 w & mask == 0 的路径,如果能找到这样的路径从 $a$ 到 $b$ ,那么 $ans$ 的第 $i$ 位就是 $0$ ,否则填 $1$ 。但为了在后面低位的时候也能限制之前的高位,所以试填的过程中先反着填,最后再取反。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
ans = 0
for i in range(mx_bit_length + 1, -1, -1):
# ...
mask = ans | (1 << i)
flag = False
while q:
u = q.popleft()
# if u == b:...
for v, w in edges[u]:
if vis[v] != time_stamp and (w & mask) == 0:
# ...
if flag:
ans |= 1 << i
# ...

print((ans ^ ((1 << (mx_bit_length + 2)) - 1)) if ans != -1 else -1)

核心的一个是 ((w >> i) | mask) == mask ,一个是 (w & mask) == 0 ,后者开销小。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sys
from collections import deque


def main():
data = sys.stdin.buffer.read().split()
n, m = int(data[0]), int(data[1])
ptr = 2

mx_bit_length = 0
edges = [[] for _ in range(n + 1)]
for _ in range(m):
u, v, w = int(data[ptr]), int(data[ptr + 1]), int(data[ptr + 2])
ptr += 3

if u == v:
continue
edges[u].append((v, w))
edges[v].append((u, w))
if mx_bit_length < w.bit_length():
mx_bit_length = w.bit_length()

a, b = int(data[-2]), int(data[-1])

vis = [0] * (n + 1)
time_stamp = 0
q = deque()
ans = 0
for i in range(mx_bit_length + 1, -1, -1):
time_stamp += 1
vis[a] = time_stamp
q.clear()
q.append(a)
mask = ans | (1 << i)
flag = False
while q:
u = q.popleft()
if u == b:
flag = True
break
for v, w in edges[u]:
if vis[v] != time_stamp and (w & mask) == 0:
q.append(v)
vis[v] = time_stamp
if flag:
ans |= 1 << i
elif i == mx_bit_length + 1:
ans = -1
break

print((ans ^ ((1 << (mx_bit_length + 2)) - 1)) if ans != -1 else -1)


if __name__ == '__main__':
main()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <bits/stdc++.h>
using namespace std;

struct edge {
int to;
long long wt;

bool operator<(const edge &o) const {
if (o.to == to) return to < o.to;
return wt < o.wt;
}

bool operator==(const edge &o) const {
return to == o.to;
}
};

int main() {
ios::sync_with_stdio(false);
cin.tie(0);

int n, m;
cin >> n >> m;
vector<vector<edge>> edges(n + 1);
while (m--) {
int u, v;
cin >> u >> v;
long long w;
cin >> w;
if (u == v) continue;
edges[u].push_back({v, w});
edges[v].push_back({u, w});
}

// 没有变快
for (int i = 1; i < edges.size(); i++) {
auto &e = edges[i];
sort(e.begin(), e.end());
e.erase(unique(e.begin(), e.end()), e.end());
}

int a, b;
cin >> a >> b;

vector<bool> vis(n + 1);

long long ans = 0;
for (int i = 62; i >= 0; i--) {
long long mask = ans >> i;
bool flag = false;
fill(vis.begin(), vis.end(), false);
vis[a] = true;
deque<int> q;
q.push_back(a);
while (!q.empty()) {
int u = q.front();
if (u == b) {
flag = true;
break;
}
q.pop_front();
for (edge &e : edges[u]) {
int v = e.to;
long long w = e.wt;
if (!vis[v] && ((w >> i) | mask) == mask) {
vis[v] = true;
q.push_back(v);
}
}
}
if (!flag) {
if (i == 62) {
ans = -1;
break;
}
ans |= 1LL << i;
}
}
cout << ans;
return 0;
}