Skip to main content

dfir_lang/graph/
flat_to_partitioned.rs

//! Subgraph partitioning algorithm.
2
3use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7use syn::parse_quote;
8
9use super::meta_graph::DfirGraph;
10use super::ops::{DelayType, FloType};
11use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
12use crate::diagnostic::{Diagnostic, Level};
13use crate::union_find::UnionFind;
14
15/// Helper struct for tracking barrier crossers, see [`find_barrier_crossers`].
16struct BarrierCrossers {
17    /// Edge barrier crossers, including what type.
18    pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
19    /// Singleton reference barrier crossers, considered to be [`DelayType::Stratum`].
20    pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
21}
22impl BarrierCrossers {
23    /// Iterate pairs of nodes that are across a barrier. Excludes `DelayType::NextIteration` pairs.
24    fn iter_node_pairs<'a>(
25        &'a self,
26        partitioned_graph: &'a DfirGraph,
27    ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
28        let edge_pairs_iter = self
29            .edge_barrier_crossers
30            .iter()
31            .map(|(edge_id, &delay_type)| {
32                let src_dst = partitioned_graph.edge(edge_id);
33                (src_dst, delay_type)
34            });
35        let singleton_pairs_iter = self
36            .singleton_barrier_crossers
37            .iter()
38            .map(|&src_dst| (src_dst, DelayType::Stratum));
39        edge_pairs_iter.chain(singleton_pairs_iter)
40    }
41
42    /// Insert/replace edge.
43    fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
44        if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
45            self.edge_barrier_crossers.insert(new_edge_id, delay_type);
46        }
47    }
48}
49
50/// Find all the barrier crossers.
51fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
52    let edge_barrier_crossers = partitioned_graph
53        .edges()
54        .filter(|&(_, (_src, dst))| {
55            // Ignore barriers within `loop {` blocks.
56            partitioned_graph.node_loop(dst).is_none()
57        })
58        .filter_map(|(edge_id, (_src, dst))| {
59            let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
60            let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
61            let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
62            Some((edge_id, input_barrier))
63        })
64        .collect();
65    let singleton_barrier_crossers = partitioned_graph
66        .node_ids()
67        .flat_map(|dst| {
68            partitioned_graph
69                .node_singleton_references(dst)
70                .iter()
71                .flatten()
72                .map(move |&src_ref| (src_ref, dst))
73        })
74        .collect();
75    BarrierCrossers {
76        edge_barrier_crossers,
77        singleton_barrier_crossers,
78    }
79}
80
81fn find_subgraph_unionfind(
82    partitioned_graph: &DfirGraph,
83    barrier_crossers: &BarrierCrossers,
84) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
85    // Modality (color) of nodes, push or pull.
86    // TODO(mingwei)? This does NOT consider `DelayType` barriers (which generally imply `Pull`),
87    // which makes it inconsistant with the final output in `as_code()`. But this doesn't create
88    // any bugs since we exclude `DelayType` edges from joining subgraphs anyway.
89    let mut node_color = partitioned_graph
90        .node_ids()
91        .filter_map(|node_id| {
92            let op_color = partitioned_graph.node_color(node_id)?;
93            Some((node_id, op_color))
94        })
95        .collect::<SparseSecondaryMap<_, _>>();
96
97    let mut subgraph_unionfind: UnionFind<GraphNodeId> =
98        UnionFind::with_capacity(partitioned_graph.nodes().len());
99
100    // Will contain all edges which are handoffs. Starts out with all edges and
101    // we remove from this set as we combine nodes into subgraphs.
102    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
103    // Would sort edges here for priority (for now, no sort/priority).
104
105    // Each edge gets looked at in order. However we may not know if a linear
106    // chain of operators is PUSH vs PULL until we look at the ends. A fancier
107    // algorithm would know to handle linear chains from the outside inward.
108    // But instead we just run through the edges in a loop until no more
109    // progress is made. Could have some sort of O(N^2) pathological worst
110    // case.
111    let mut progress = true;
112    while progress {
113        progress = false;
114        // TODO(mingwei): Could this iterate `handoff_edges` instead? (Modulo ownership). Then no case (1) below.
115        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
116            // Ignore (1) already added edges as well as (2) new self-cycles. (Unless reference edge).
117            if subgraph_unionfind.same_set(src, dst) {
118                // Note that the _edge_ `edge_id` might not be in the subgraph even when both `src` and `dst` are. This prevents case 2.
119                // Handoffs will be inserted later for this self-loop.
120                continue;
121            }
122
123            // Do not connect stratum crossers (next edges).
124            if barrier_crossers
125                .iter_node_pairs(partitioned_graph)
126                .any(|((x_src, x_dst), _)| {
127                    (subgraph_unionfind.same_set(x_src, src)
128                        && subgraph_unionfind.same_set(x_dst, dst))
129                        || (subgraph_unionfind.same_set(x_src, dst)
130                            && subgraph_unionfind.same_set(x_dst, src))
131                })
132            {
133                continue;
134            }
135
136            // Do not connect across loop contexts.
137            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
138                continue;
139            }
140            // Do not connect `next_iteration()`.
141            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
142                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
143            }) {
144                continue;
145            }
146
147            if can_connect_colorize(&mut node_color, src, dst) {
148                // At this point we have selected this edge and its src & dst to be
149                // within a single subgraph.
150                subgraph_unionfind.union(src, dst);
151                assert!(handoff_edges.remove(&edge_id));
152                progress = true;
153            }
154        }
155    }
156
157    (subgraph_unionfind, handoff_edges)
158}
159
160/// Builds the datastructures for checking which subgraph each node belongs to
161/// after handoffs have already been inserted to partition subgraphs.
162/// This list of nodes in each subgraph are returned in topological sort order.
163fn make_subgraph_collect(
164    partitioned_graph: &DfirGraph,
165    mut subgraph_unionfind: UnionFind<GraphNodeId>,
166) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
167    // We want the nodes of each subgraph to be listed in topo-sort order.
168    // We could do this on each subgraph, or we could do it all at once on the
169    // whole node graph by ignoring handoffs, which is what we do here:
170    let topo_sort = graph_algorithms::topo_sort(
171        partitioned_graph
172            .nodes()
173            .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
174            .map(|(node_id, _)| node_id),
175        |v| {
176            partitioned_graph
177                .node_predecessor_nodes(v)
178                .filter(|&pred_id| {
179                    let pred = partitioned_graph.node(pred_id);
180                    !matches!(pred, GraphNode::Handoff { .. })
181                })
182        },
183    )
184    .expect("Subgraphs are in-out trees.");
185
186    let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
187    for node_id in topo_sort {
188        let repr_node = subgraph_unionfind.find(node_id);
189        if !grouped_nodes.contains_key(repr_node) {
190            grouped_nodes.insert(repr_node, Default::default());
191        }
192        grouped_nodes[repr_node].push(node_id);
193    }
194    grouped_nodes
195}
196
197/// Find subgraph and insert handoffs.
198/// Modifies barrier_crossers so that the edge OUT of an inserted handoff has
199/// the DelayType data.
200fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
201    // Algorithm:
202    // 1. Each node begins as its own subgraph.
203    // 2. Collect edges. (Future optimization: sort so edges which should not be split across a handoff come first).
204    // 3. For each edge, try to join `(to, from)` into the same subgraph.
205
206    // TODO(mingwei):
207    // self.partitioned_graph.assert_valid();
208
209    let (subgraph_unionfind, handoff_edges) =
210        find_subgraph_unionfind(partitioned_graph, barrier_crossers);
211
212    // Insert handoffs between subgraphs (or on subgraph self-loop edges)
213    for edge_id in handoff_edges {
214        let (src_id, dst_id) = partitioned_graph.edge(edge_id);
215
216        // Already has a handoff, no need to insert one.
217        let src_node = partitioned_graph.node(src_id);
218        let dst_node = partitioned_graph.node(dst_id);
219        if matches!(src_node, GraphNode::Handoff { .. })
220            || matches!(dst_node, GraphNode::Handoff { .. })
221        {
222            continue;
223        }
224
225        let hoff = GraphNode::Handoff {
226            src_span: src_node.span(),
227            dst_span: dst_node.span(),
228        };
229        let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
230
231        // Update barrier_crossers for inserted node.
232        barrier_crossers.replace_edge(edge_id, out_edge_id);
233    }
234
235    // Determine node's subgraph and subgraph's nodes.
236    // This list of nodes in each subgraph are to be in topological sort order.
237    // Eventually returned directly in the [`DfirGraph`].
238    let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
239    for (_repr_node, member_nodes) in grouped_nodes {
240        partitioned_graph.insert_subgraph(member_nodes).unwrap();
241    }
242}
243
244/// Set `src` or `dst` color if `None` based on the other (if possible):
245/// `None` indicates an op could be pull or push i.e. unary-in & unary-out.
246/// So in that case we color `src` or `dst` based on its newfound neighbor (the other one).
247///
248/// Returns if `src` and `dst` can be in the same subgraph.
249fn can_connect_colorize(
250    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
251    src: GraphNodeId,
252    dst: GraphNodeId,
253) -> bool {
254    // Pull -> Pull
255    // Push -> Push
256    // Pull -> [Computation] -> Push
257    // Push -> [Handoff] -> Pull
258    let can_connect = match (node_color.get(src), node_color.get(dst)) {
259        // Linear chain, can't connect because it may cause future conflicts.
260        // But if it doesn't in the _future_ we can connect it (once either/both ends are determined).
261        (None, None) => false,
262
263        // Infer left side.
264        (None, Some(Color::Pull | Color::Comp)) => {
265            node_color.insert(src, Color::Pull);
266            true
267        }
268        (None, Some(Color::Push | Color::Hoff)) => {
269            node_color.insert(src, Color::Push);
270            true
271        }
272
273        // Infer right side.
274        (Some(Color::Pull | Color::Hoff), None) => {
275            node_color.insert(dst, Color::Pull);
276            true
277        }
278        (Some(Color::Comp | Color::Push), None) => {
279            node_color.insert(dst, Color::Push);
280            true
281        }
282
283        // Both sides already specified.
284        (Some(Color::Pull), Some(Color::Pull)) => true,
285        (Some(Color::Pull), Some(Color::Comp)) => true,
286        (Some(Color::Pull), Some(Color::Push)) => true,
287
288        (Some(Color::Comp), Some(Color::Pull)) => false,
289        (Some(Color::Comp), Some(Color::Comp)) => false,
290        (Some(Color::Comp), Some(Color::Push)) => true,
291
292        (Some(Color::Push), Some(Color::Pull)) => false,
293        (Some(Color::Push), Some(Color::Comp)) => false,
294        (Some(Color::Push), Some(Color::Push)) => true,
295
296        // Handoffs are not part of subgraphs.
297        (Some(Color::Hoff), Some(_)) => false,
298        (Some(_), Some(Color::Hoff)) => false,
299    };
300    can_connect
301}
302
/// Topologically sorts subgraphs and injects intermediate identity subgraphs for `defer_tick`
/// edges that go forward in the topological order (so they become back-edges).
///
/// Tick-boundary handoffs (delay type `Tick` / `TickLazy`) are marked on the graph via
/// `set_handoff_delay_type`. Either the original handoff is marked, or, when a buffer
/// subgraph is injected, the new handoff in front of the destination is marked.
///
/// Returns an error if there is an intra-tick cycle (i.e. the subgraph DAG has a cycle when
/// tick-boundary edges are excluded).
fn order_subgraphs(
    partitioned_graph: &mut DfirGraph,
    barrier_crossers: &BarrierCrossers,
) -> Result<(), Diagnostic> {
    // Build a subgraph-level directed graph, excluding tick-boundary edges.
    // Maps each subgraph to the subgraphs that must run before it (duplicates are possible).
    let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();

    // Tick-boundary edges, collected as (edge out of the handoff, its delay type).
    // They are handled after the topological sort below.
    let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();

    // Iterate handoffs between subgraphs.
    for (node_id, node) in partitioned_graph.nodes() {
        if !matches!(node, GraphNode::Handoff { .. }) {
            continue;
        }
        assert_eq!(1, partitioned_graph.node_successors(node_id).len());
        let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();

        // `make_subgraphs` moved each barrier onto the edge *out of* its handoff, so the
        // successor edge is the one to look up.
        let succ_edge_delaytype = barrier_crossers
            .edge_barrier_crossers
            .get(succ_edge)
            .copied();
        // Tick edges are excluded from the topo sort — they are cross-tick by design.
        if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
            tick_edges.push((succ_edge, delay_type));
            continue;
        }

        assert_eq!(1, partitioned_graph.node_predecessors(node_id).len());
        let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();

        let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();

        sg_preds.entry(succ_sg).or_default().push(pred_sg);
    }
    // Include singleton reference edges: the referenced node's subgraph must run before the
    // referencing node's subgraph.
    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
        assert_ne!(pred, succ, "TODO(mingwei)");
        let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
        assert_ne!(pred_sg, succ_sg);
        sg_preds.entry(succ_sg).or_default().push(pred_sg);
    }

    // Topological sort — rejects intra-tick cycles.
    let topo_sort_order = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
        sg_preds.get(&v).into_iter().flatten().copied()
    });
    let topo_sort_order = match topo_sort_order {
        Ok(order) => order,
        Err(cycle) => {
            // Point the diagnostic at the first node of the first subgraph reported in the
            // cycle, falling back to the call-site span if none is available.
            let span = cycle
                .first()
                .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
                .map(|n| partitioned_graph.node(n).span())
                .unwrap_or_else(Span::call_site);
            return Err(Diagnostic::spanned(
                span,
                Level::Error,
                "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
            ));
        }
    };

    // Build a position map for the topo sort order.
    let sg_position: BTreeMap<GraphSubgraphId, usize> = topo_sort_order
        .iter()
        .enumerate()
        .map(|(i, &sg_id)| (sg_id, i))
        .collect();

    // Process tick-boundary edges: inject intermediate identity subgraphs where needed.
    // TODO(cleanup): The intermediate identity subgraph injection is a workaround. In the future,
    // handoff buffers should be sufficient without needing an extra subgraph.
    for (edge_id, delay_type) in tick_edges {
        // `edge_id` is the edge out of the handoff, so its source is the handoff itself.
        let (hoff, dst) = partitioned_graph.edge(edge_id);
        // Ignore barriers within `loop {` blocks.
        if partitioned_graph.node_loop(dst).is_some() {
            continue;
        }

        assert_eq!(1, partitioned_graph.node_predecessors(hoff).len());
        let src = partitioned_graph
            .node_predecessor_nodes(hoff)
            .next()
            .unwrap();

        let src_sg = partitioned_graph.node_subgraph(src).unwrap();
        let dst_sg = partitioned_graph.node_subgraph(dst).unwrap();
        let src_pos = sg_position[&src_sg];
        let dst_pos = sg_position[&dst_sg];
        let dst_span = partitioned_graph.node(dst).span();

        // If tick edge goes forward in topo order, need to inject a buffer subgraph.
        // `<=` (not `<`) so that a tick edge from a subgraph back into itself (equal
        // positions) is buffered as well.
        if src_pos <= dst_pos {
            // Before: A (src) -> H -> B (dst)
            // Then add intermediate identity:
            let (new_node_id, new_edge_id) = partitioned_graph.insert_intermediate_node(
                edge_id,
                // TODO(mingwei): Proper span w/ `parse_quote_spanned!`?
                GraphNode::Operator(parse_quote! { identity() }),
            );
            // Intermediate: A (src) -> H -> ID -> B (dst)
            let hoff_node = GraphNode::Handoff {
                src_span: dst_span,
                dst_span,
            };
            let (hoff_node_id, _hoff_edge_id) =
                partitioned_graph.insert_intermediate_node(new_edge_id, hoff_node);
            // After: A (src) -> H -> ID -> H' -> B (dst)

            // Create subgraph for the intermediate identity.
            partitioned_graph
                .insert_subgraph(vec![new_node_id])
                .unwrap();

            // Mark H' as a tick-boundary back-edge.
            partitioned_graph.set_handoff_delay_type(hoff_node_id, delay_type);
        } else {
            // Already a back-edge (src after dst in topo order).
            // Mark the original handoff H as a tick-boundary back-edge.
            partitioned_graph.set_handoff_delay_type(hoff, delay_type);
        }
    }
    Ok(())
}
435
436/// Main method for this module. Partitions a flat [`DfirGraph`] into one with subgraphs.
437///
438/// Returns an error if an intra-tick cycle exists in the graph.
439pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
440    // Pre-find barrier crossers (input edges with a `DelayType`).
441    let mut barrier_crossers = find_barrier_crossers(&flat_graph);
442    let mut partitioned_graph = flat_graph;
443
444    // Partition into subgraphs.
445    make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
446
447    // Topologically order subgraphs and inject intermediate subgraphs for defer_tick edges.
448    order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
449
450    Ok(partitioned_graph)
451}