for(int i = 1; i <= n; i++) cin >> a[i]; for(int i = 0; i < m; i++){ int u, v; cin >> u >> v; g[u].emplace_back(v); indeg[v] += 1; }
// 拓扑排序 deque<int> q; for(int i = 1; i <= n; i++){ if(indeg[i] == 0) q.emplace_back(i); } while(!q.empty()){ int u = q.front(); q.pop_front(); topo.emplace_back(u); for(int v: g[u]){ indeg[v]--; if(indeg[v] == 0) q.emplace_back(v); } }
int left = *min_element(a.begin(), a.end()); int right = *max_element(a.begin(), a.end()); while (left <= right){ int mid = left + ((right - left) >> 1); if (check(mid)) left = mid + 1; else right = mid - 1; } cout << right; return0; }
优化版
优化点(主要是前两个):
输入。数量级大了,提速明显。
1 2
ios::sync_with_stdio(false); cin.tie(nullptr);`
原因:
sync_with_stdio(false) 砍掉了 C++ 与 C 语言输入输出的 兼容同步,让 cin 跑得飞快。
int n, m; cin >> n >> m; vector<int> a(n + 1); vector<vector<int>> g(n + 1); vector<int> indeg(n + 1);
for (int i = 1; i <= n; i++) cin >> a[i]; for (int i = 0; i < m; i++) { int u, v; cin >> u >> v; g[u].emplace_back(v); indeg[v] += 1; }
// 拓扑排序 deque<int> q; vector<int> topo; topo.reserve(n + 1); for (int i = 1; i <= n; i++) { if (indeg[i] == 0) q.emplace_back(i); } while (!q.empty()) { int u = q.front(); q.pop_front(); topo.emplace_back(u); for (int v : g[u]) { indeg[v]--; if (indeg[v] == 0) q.emplace_back(v); } }
vector<int> dp(n + 1); vector<int> w(n + 1); auto check = [&](int m) { // 权值转换 for (int i = 1; i <= n; i++) w[i] = a[i] >= m ? 1 : -1; // 初始化 dp fill(dp.begin(), dp.end(), NEG_INF); dp[1] = w[1];
for (int u : topo) { // 最好有这个 continue if (dp[u] == NEG_INF) continue; for (int v : g[u]) { // dp 转移 if (dp[u] + w[v] > dp[v]) { dp[v] = dp[u] + w[v]; } } } return dp[n] >= 0; };
int left = 1; int right = vals.size() - 1; while (left <= right) { int mid = left + ((right - left) >> 1); if (check(vals[mid])) left = mid + 1; else right = mid - 1; }
""" 在一个路径中,将大于等于中位数的记为 +1,小于中位数的记为 -1,那么路径上的总和一定是大于等于 0 的 给定一个 x,将大于等于 x 的记为 +1,小于 x 的记为 -1, 那么路径上的总和 大于 0,x 小于等于中位数 等于 0,x 就是中位数 小于 0,x 大于中位数 题目问的是,在所有 1 到 n 的路径中,中位数最大的是多少? 转换问题为,假设最大的中位数是 x,是否存在一条 1 到 n 的路径的中位数至少是 x? 用上面的方法计算最大的每条路径总和,如果这个总和 >= 0,说明 x <= 中位数, 即存在一条路径的中位数 >= x,为了找到答案要的「最大」的中位数,要增大 x。 """
from collections import deque
NEG_INF = -10 ** 18
defcheck(x: int) -> bool: dp = [NEG_INF] * (n + 1) ver = [1if v >= x else -1for v in vertices] dp[1] = 1if vertices[1] >= x else -1 for u in topo_list: if dp[u] == NEG_INF: continue for v in g[u]: dp[v] = max(dp[v], dp[u] + ver[v]) return dp[n] >= 0
n, m = list(map(int, input().split())) # n 节点数,m 边数量 vertices = [NEG_INF] + list(map(int, input().split())) g = [[] for _ inrange(n + 1)] in_deg = [0] * (n + 1) in_deg[0] = NEG_INF for _ inrange(m): x, y = list(map(int, input().split())) g[x].append(y) in_deg[y] += 1
# 拓扑排序 ################ topo_list = [] q = deque([i for i, d inenumerate(in_deg) if d == 0]) while q: x = q.popleft() topo_list.append(x) for y in g[x]: in_deg[y] -= 1 if in_deg[y] == 0: q.append(y)
left, right = 0, max(vertices) while left <= right: mid = (left + right) // 2 if check(mid): left = mid + 1 else: right = mid - 1
import sys from collections import deque from array import array NEG_INF = -10 ** 9
defget_next_int(): global pos while pos < n_data and data[pos] <= 32: pos += 1 if pos >= n_data: returnNone res = 0 while pos < n_data and data[pos] > 32: res = res * 10 + (data[pos] - 48) pos += 1 return res
data = sys.stdin.buffer.read() n_data = len(data) pos = 0 n, m = get_next_int(), get_next_int() # n 节点数,m 边数量 a = array('i', [0]) * (n + 1) for i inrange(1, n + 1): a[i] = get_next_int()
g = [array('i') for _ inrange(n + 1)] in_deg = array('i', [0]) * (n + 1) for _ inrange(m): x, y = get_next_int(), get_next_int() g[x].append(y) in_deg[y] += 1
# 拓扑排序 ################ topo_list = [] q = deque([i for i inrange(1, n + 1) if in_deg[i] == 0]) while q: x = q.popleft() topo_list.append(x) for y in g[x]: in_deg[y] -= 1 if in_deg[y] == 0: q.append(y) dp = array('i', [NEG_INF]) * (n + 1) dp_reset = array('i', [NEG_INF] ) * (n + 1) # 用于快速重置 defcheck(x: int) -> bool: global dp dp[:] = dp_reset dp[1] = 1if a[1] >= x else -1 for u in topo_list: if dp[u] == NEG_INF: continue for v in g[u]: w = 1if a[v] >= x else -1 dp[v] = max(dp[v], dp[u] + w) return dp[n] >= 0
left, right = 0, max(a) while left <= right: mid = (left + right) // 2 if check(mid): left = mid + 1 else: right = mid - 1