1use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7use syn::parse_quote;
8
9use super::meta_graph::DfirGraph;
10use super::ops::{DelayType, FloType};
11use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
12use crate::diagnostic::{Diagnostic, Level};
13use crate::union_find::UnionFind;
14
15struct BarrierCrossers {
17 pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
19 pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
21}
22impl BarrierCrossers {
23 fn iter_node_pairs<'a>(
25 &'a self,
26 partitioned_graph: &'a DfirGraph,
27 ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
28 let edge_pairs_iter = self
29 .edge_barrier_crossers
30 .iter()
31 .map(|(edge_id, &delay_type)| {
32 let src_dst = partitioned_graph.edge(edge_id);
33 (src_dst, delay_type)
34 });
35 let singleton_pairs_iter = self
36 .singleton_barrier_crossers
37 .iter()
38 .map(|&src_dst| (src_dst, DelayType::Stratum));
39 edge_pairs_iter.chain(singleton_pairs_iter)
40 }
41
42 fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
44 if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
45 self.edge_barrier_crossers.insert(new_edge_id, delay_type);
46 }
47 }
48}
49
50fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
52 let edge_barrier_crossers = partitioned_graph
53 .edges()
54 .filter(|&(_, (_src, dst))| {
55 partitioned_graph.node_loop(dst).is_none()
57 })
58 .filter_map(|(edge_id, (_src, dst))| {
59 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
60 let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
61 let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
62 Some((edge_id, input_barrier))
63 })
64 .collect();
65 let singleton_barrier_crossers = partitioned_graph
66 .node_ids()
67 .flat_map(|dst| {
68 partitioned_graph
69 .node_singleton_references(dst)
70 .iter()
71 .flatten()
72 .map(move |&src_ref| (src_ref, dst))
73 })
74 .collect();
75 BarrierCrossers {
76 edge_barrier_crossers,
77 singleton_barrier_crossers,
78 }
79}
80
81fn find_subgraph_unionfind(
82 partitioned_graph: &DfirGraph,
83 barrier_crossers: &BarrierCrossers,
84) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
85 let mut node_color = partitioned_graph
90 .node_ids()
91 .filter_map(|node_id| {
92 let op_color = partitioned_graph.node_color(node_id)?;
93 Some((node_id, op_color))
94 })
95 .collect::<SparseSecondaryMap<_, _>>();
96
97 let mut subgraph_unionfind: UnionFind<GraphNodeId> =
98 UnionFind::with_capacity(partitioned_graph.nodes().len());
99
100 let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
103 let mut progress = true;
112 while progress {
113 progress = false;
114 for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
116 if subgraph_unionfind.same_set(src, dst) {
118 continue;
121 }
122
123 if barrier_crossers
125 .iter_node_pairs(partitioned_graph)
126 .any(|((x_src, x_dst), _)| {
127 (subgraph_unionfind.same_set(x_src, src)
128 && subgraph_unionfind.same_set(x_dst, dst))
129 || (subgraph_unionfind.same_set(x_src, dst)
130 && subgraph_unionfind.same_set(x_dst, src))
131 })
132 {
133 continue;
134 }
135
136 if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
138 continue;
139 }
140 if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
142 Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
143 }) {
144 continue;
145 }
146
147 if can_connect_colorize(&mut node_color, src, dst) {
148 subgraph_unionfind.union(src, dst);
151 assert!(handoff_edges.remove(&edge_id));
152 progress = true;
153 }
154 }
155 }
156
157 (subgraph_unionfind, handoff_edges)
158}
159
160fn make_subgraph_collect(
164 partitioned_graph: &DfirGraph,
165 mut subgraph_unionfind: UnionFind<GraphNodeId>,
166) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
167 let topo_sort = graph_algorithms::topo_sort(
171 partitioned_graph
172 .nodes()
173 .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
174 .map(|(node_id, _)| node_id),
175 |v| {
176 partitioned_graph
177 .node_predecessor_nodes(v)
178 .filter(|&pred_id| {
179 let pred = partitioned_graph.node(pred_id);
180 !matches!(pred, GraphNode::Handoff { .. })
181 })
182 },
183 )
184 .expect("Subgraphs are in-out trees.");
185
186 let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
187 for node_id in topo_sort {
188 let repr_node = subgraph_unionfind.find(node_id);
189 if !grouped_nodes.contains_key(repr_node) {
190 grouped_nodes.insert(repr_node, Default::default());
191 }
192 grouped_nodes[repr_node].push(node_id);
193 }
194 grouped_nodes
195}
196
197fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
201 let (subgraph_unionfind, handoff_edges) =
210 find_subgraph_unionfind(partitioned_graph, barrier_crossers);
211
212 for edge_id in handoff_edges {
214 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
215
216 let src_node = partitioned_graph.node(src_id);
218 let dst_node = partitioned_graph.node(dst_id);
219 if matches!(src_node, GraphNode::Handoff { .. })
220 || matches!(dst_node, GraphNode::Handoff { .. })
221 {
222 continue;
223 }
224
225 let hoff = GraphNode::Handoff {
226 src_span: src_node.span(),
227 dst_span: dst_node.span(),
228 };
229 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
230
231 barrier_crossers.replace_edge(edge_id, out_edge_id);
233 }
234
235 let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
239 for (_repr_node, member_nodes) in grouped_nodes {
240 partitioned_graph.insert_subgraph(member_nodes).unwrap();
241 }
242}
243
244fn can_connect_colorize(
250 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
251 src: GraphNodeId,
252 dst: GraphNodeId,
253) -> bool {
254 let can_connect = match (node_color.get(src), node_color.get(dst)) {
259 (None, None) => false,
262
263 (None, Some(Color::Pull | Color::Comp)) => {
265 node_color.insert(src, Color::Pull);
266 true
267 }
268 (None, Some(Color::Push | Color::Hoff)) => {
269 node_color.insert(src, Color::Push);
270 true
271 }
272
273 (Some(Color::Pull | Color::Hoff), None) => {
275 node_color.insert(dst, Color::Pull);
276 true
277 }
278 (Some(Color::Comp | Color::Push), None) => {
279 node_color.insert(dst, Color::Push);
280 true
281 }
282
283 (Some(Color::Pull), Some(Color::Pull)) => true,
285 (Some(Color::Pull), Some(Color::Comp)) => true,
286 (Some(Color::Pull), Some(Color::Push)) => true,
287
288 (Some(Color::Comp), Some(Color::Pull)) => false,
289 (Some(Color::Comp), Some(Color::Comp)) => false,
290 (Some(Color::Comp), Some(Color::Push)) => true,
291
292 (Some(Color::Push), Some(Color::Pull)) => false,
293 (Some(Color::Push), Some(Color::Comp)) => false,
294 (Some(Color::Push), Some(Color::Push)) => true,
295
296 (Some(Color::Hoff), Some(_)) => false,
298 (Some(_), Some(Color::Hoff)) => false,
299 };
300 can_connect
301}
302
303fn order_subgraphs(
309 partitioned_graph: &mut DfirGraph,
310 barrier_crossers: &BarrierCrossers,
311) -> Result<(), Diagnostic> {
312 let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
314
315 let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
317
318 for (node_id, node) in partitioned_graph.nodes() {
320 if !matches!(node, GraphNode::Handoff { .. }) {
321 continue;
322 }
323 assert_eq!(1, partitioned_graph.node_successors(node_id).len());
324 let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();
325
326 let succ_edge_delaytype = barrier_crossers
327 .edge_barrier_crossers
328 .get(succ_edge)
329 .copied();
330 if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
332 tick_edges.push((succ_edge, delay_type));
333 continue;
334 }
335
336 assert_eq!(1, partitioned_graph.node_predecessors(node_id).len());
337 let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();
338
339 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
340 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
341
342 sg_preds.entry(succ_sg).or_default().push(pred_sg);
343 }
344 for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
346 assert_ne!(pred, succ, "TODO(mingwei)");
347 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
348 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
349 assert_ne!(pred_sg, succ_sg);
350 sg_preds.entry(succ_sg).or_default().push(pred_sg);
351 }
352
353 let topo_sort_order = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
355 sg_preds.get(&v).into_iter().flatten().copied()
356 });
357 let topo_sort_order = match topo_sort_order {
358 Ok(order) => order,
359 Err(cycle) => {
360 let span = cycle
361 .first()
362 .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
363 .map(|n| partitioned_graph.node(n).span())
364 .unwrap_or_else(Span::call_site);
365 return Err(Diagnostic::spanned(
366 span,
367 Level::Error,
368 "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
369 ));
370 }
371 };
372
373 let sg_position: BTreeMap<GraphSubgraphId, usize> = topo_sort_order
375 .iter()
376 .enumerate()
377 .map(|(i, &sg_id)| (sg_id, i))
378 .collect();
379
380 for (edge_id, delay_type) in tick_edges {
384 let (hoff, dst) = partitioned_graph.edge(edge_id);
385 if partitioned_graph.node_loop(dst).is_some() {
387 continue;
388 }
389
390 assert_eq!(1, partitioned_graph.node_predecessors(hoff).len());
391 let src = partitioned_graph
392 .node_predecessor_nodes(hoff)
393 .next()
394 .unwrap();
395
396 let src_sg = partitioned_graph.node_subgraph(src).unwrap();
397 let dst_sg = partitioned_graph.node_subgraph(dst).unwrap();
398 let src_pos = sg_position[&src_sg];
399 let dst_pos = sg_position[&dst_sg];
400 let dst_span = partitioned_graph.node(dst).span();
401
402 if src_pos <= dst_pos {
404 let (new_node_id, new_edge_id) = partitioned_graph.insert_intermediate_node(
407 edge_id,
408 GraphNode::Operator(parse_quote! { identity() }),
410 );
411 let hoff_node = GraphNode::Handoff {
413 src_span: dst_span,
414 dst_span,
415 };
416 let (hoff_node_id, _hoff_edge_id) =
417 partitioned_graph.insert_intermediate_node(new_edge_id, hoff_node);
418 partitioned_graph
422 .insert_subgraph(vec![new_node_id])
423 .unwrap();
424
425 partitioned_graph.set_handoff_delay_type(hoff_node_id, delay_type);
427 } else {
428 partitioned_graph.set_handoff_delay_type(hoff, delay_type);
431 }
432 }
433 Ok(())
434}
435
436pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
440 let mut barrier_crossers = find_barrier_crossers(&flat_graph);
442 let mut partitioned_graph = flat_graph;
443
444 make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
446
447 order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
449
450 Ok(partitioned_graph)
451}