1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, HandoffKind, MODULE_BOUNDARY_NODE_STR, OperatorInstance,
24 PortIndexValue, SINGLETON_SLOT_NODE_STR, Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A reference from an operator to a handoff node, as recorded after resolution.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResolvedHandoffRef {
    /// The handoff node this reference points at. `None` means the reference has not been
    /// resolved; code generation expects every reference to be resolved and panics otherwise.
    pub node_id: Option<GraphNodeId>,
    /// Whether the reference is mutable. Code generation emits a mutable access
    /// (`as_mut()` / `&mut`) when set and a shared access (`as_ref()` / `&`) otherwise.
    pub is_mut: bool,
    /// Optional access group. [`DfirGraph::node_handoff_reference_groups`] buckets the
    /// references to one target by this value.
    pub access_group: Option<u32>,
}
40
/// A DFIR graph: nodes, the edges between them (with port labels), plus per-node side tables
/// for operator instances, loops, subgraphs, variable names and handoff references.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph, keyed by its [`GraphNodeId`].
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// The [`OperatorInstance`] of each operator node, once it has been inserted.
    /// Not serialized (`#[serde(skip)]`).
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional string tag per node, set through [`Self::set_operator_tag`].
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// The directed multigraph connecting the nodes.
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// For each edge, the `(source port, destination port)` pair.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// The loop each node was inserted into, if any.
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// The nodes belonging to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// Presumably the parent loop of each nested loop (not exercised in this part of the
    /// file) -- TODO confirm.
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Presumably the loops that have no parent loop -- TODO confirm.
    root_loops: Vec<GraphLoopId>,
    /// Presumably the loops nested directly inside each loop -- TODO confirm.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// The subgraph each node has been assigned to, if any.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// The nodes making up each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// Subgraph ordering stored by [`Self::set_subgraph_toposort`]; code generation emits
    /// the subgraphs in this order.
    subgraph_toposort: Vec<GraphSubgraphId>,

    /// Resolved handoff references made by each node.
    node_handoff_references: SparseSecondaryMap<GraphNodeId, Vec<ResolvedHandoffRef>>,
    /// The variable name a node was declared under, if any.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Delay type attached to a handoff node. Code generation treats handoffs with an entry
    /// here as back edges and asserts the value is `Tick` or `TickLazy`.
    handoff_delay_type: SparseSecondaryMap<GraphNodeId, DelayType>,
}
95
96impl DfirGraph {
98 pub fn new() -> Self {
100 Default::default()
101 }
102}
103
104impl DfirGraph {
106 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
108 self.nodes.get(node_id).expect("Node not found.")
109 }
110
    /// Returns the [`OperatorInstance`] recorded for `node_id`, or `None` if no instance has
    /// been inserted for that node.
    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
        self.operator_instances.get(node_id)
    }
118
    /// Returns the variable name recorded for `node_id`, or `None` if the node has none.
    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
        self.node_varnames.get(node_id)
    }
123
    /// Returns the subgraph `node_id` has been assigned to, or `None` if it is in no subgraph.
    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
        self.node_subgraph.get(node_id).copied()
    }
128
    /// Returns the in-degree of `node_id` as reported by the underlying graph.
    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
        self.graph.degree_in(node_id)
    }
133
    /// Returns the out-degree of `node_id` as reported by the underlying graph.
    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
        self.graph.degree_out(node_id)
    }
138
    /// Iterates over the outgoing edges of `src` as `(edge id, successor node id)` pairs.
    pub fn node_successors(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successors(src)
    }
151
    /// Iterates over the incoming edges of `dst` as `(edge id, predecessor node id)` pairs.
    pub fn node_predecessors(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessors(dst)
    }
164
    /// Iterates over the ids of the edges leaving `src`.
    pub fn node_successor_edges(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphEdgeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successor_edges(src)
    }
177
    /// Iterates over the ids of the edges entering `dst`.
    pub fn node_predecessor_edges(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphEdgeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessor_edges(dst)
    }
190
    /// Iterates over the ids of the nodes that `src` has an outgoing edge to.
    pub fn node_successor_nodes(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphNodeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successor_vertices(src)
    }
203
    /// Iterates over the ids of the nodes that have an edge into `dst`.
    pub fn node_predecessor_nodes(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphNodeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessor_vertices(dst)
    }
216
    /// Iterates over the ids of all nodes in the graph.
    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
        self.nodes.keys()
    }
221
    /// Iterates over all nodes in the graph as `(node id, node)` pairs.
    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
        self.nodes.iter()
    }
226
227 pub fn insert_node(
229 &mut self,
230 node: GraphNode,
231 varname_opt: Option<Ident>,
232 loop_opt: Option<GraphLoopId>,
233 ) -> GraphNodeId {
234 let node_id = self.nodes.insert(node);
235 if let Some(varname) = varname_opt {
236 self.node_varnames.insert(node_id, Varname(varname));
237 }
238 if let Some(loop_id) = loop_opt {
239 self.node_loops.insert(node_id, loop_id);
240 self.loop_nodes[loop_id].push(node_id);
241 }
242 node_id
243 }
244
    /// Records `op_inst` as the operator instance of `node_id`.
    ///
    /// # Panics
    /// Panics if `node_id` is not an operator node of this graph, or if the node already has
    /// an operator instance.
    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
        // Only operator nodes may carry an operator instance.
        assert!(matches!(
            self.nodes.get(node_id),
            Some(GraphNode::Operator(_))
        ));
        let old_inst = self.operator_instances.insert(node_id, op_inst);
        // An instance must be set at most once per node.
        assert!(old_inst.is_none());
    }
254
    /// Builds and stores an [`OperatorInstance`] for every operator node that does not have
    /// one yet, and rewrites `handoff` / `singleton` / `optional` operator nodes into
    /// [`GraphNode::Handoff`] nodes.
    ///
    /// Problems are reported through `diagnostics` rather than returned:
    /// * a handoff-like operator given arguments or generic arguments,
    /// * an operator name with no registered constraints ("Unknown operator"),
    /// * a wrong number of persistence lifetime or generic type arguments.
    ///
    /// Operators with an unknown name are skipped; argument-count errors do not prevent the
    /// instance from being recorded.
    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
        // Results are buffered because `self` is borrowed immutably while iterating.
        let mut op_insts = Vec::new();
        // Operator nodes that must be replaced by handoff nodes after the loop.
        let mut handoff_nodes: Vec<(GraphNodeId, HandoffKind, Span)> = Vec::new();

        for (node_id, node) in self.nodes() {
            let GraphNode::Operator(operator) = node else {
                continue;
            };
            // Skip operators that already have an instance.
            if self.node_op_inst(node_id).is_some() {
                continue;
            };

            // These three operator names denote handoffs rather than real operators.
            let handoff_kind = match &*operator.name_string() {
                "handoff" => Some(HandoffKind::Vec),
                "singleton" => Some(HandoffKind::Singleton),
                "optional" => Some(HandoffKind::Optional),
                _ => None,
            };
            if let Some(kind) = handoff_kind {
                if !operator.args.is_empty() {
                    diagnostics.push(Diagnostic::spanned(
                        operator.path.span(),
                        Level::Error,
                        format!("`{}` takes no arguments.", operator.name_string()),
                    ));
                }
                if operator.type_arguments().is_some() {
                    diagnostics.push(Diagnostic::spanned(
                        operator.path.span(),
                        Level::Error,
                        format!("`{}` takes no generic arguments.", operator.name_string()),
                    ));
                }
                // The node is still converted even if errors were reported above.
                handoff_nodes.push((node_id, kind, operator.path.span()));
                continue;
            }

            let Some(op_constraints) = find_op_op_constraints(operator) else {
                diagnostics.push(Diagnostic::spanned(
                    operator.path.span(),
                    Level::Error,
                    format!("Unknown operator `{}`", operator.name_string()),
                ));
                continue;
            };

            let (input_ports, output_ports) = {
                // Input ports are the destination-side ports of the incoming edges.
                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_predecessors(node_id)
                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
                    .collect();
                // Sort by port (ties broken by neighbor id) for a deterministic order.
                input_edges.sort();
                let input_ports: Vec<PortIndexValue> = input_edges
                    .into_iter()
                    .map(|(port, _pred)| port)
                    .cloned()
                    .collect();

                // Output ports are the source-side ports of the outgoing edges.
                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_successors(node_id)
                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
                    .collect();
                output_edges.sort();
                let output_ports: Vec<PortIndexValue> = output_edges
                    .into_iter()
                    .map(|(port, _succ)| port)
                    .cloned()
                    .collect();

                (input_ports, output_ports)
            };

            let generics = get_operator_generics(diagnostics, operator);
            // Validate the number of generic arguments against the operator's constraints.
            {
                // Fallback span when a more specific argument span is unavailable.
                let generics_span = generics
                    .generic_args
                    .as_ref()
                    .map(Spanned::span)
                    .unwrap_or_else(|| operator.path.span());

                if !op_constraints
                    .persistence_args
                    .contains(&generics.persistence_args.len())
                {
                    diagnostics.push(Diagnostic::spanned(
                        generics.persistence_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.persistence_args.human_string(),
                            generics.persistence_args.len()
                        ),
                    ));
                }
                if !op_constraints.type_args.contains(&generics.type_args.len()) {
                    diagnostics.push(Diagnostic::spanned(
                        generics.type_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} generic type arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.type_args.human_string(),
                            generics.type_args.len()
                        ),
                    ));
                }
            }

            op_insts.push((
                node_id,
                OperatorInstance {
                    op_constraints,
                    input_ports,
                    output_ports,
                    singletons_referenced: operator.singletons_referenced.clone(),
                    generics,
                    arguments_pre: operator.args.clone(),
                    arguments_raw: operator.args_raw.clone(),
                },
            ));
        }

        for (node_id, op_inst) in op_insts {
            self.insert_node_op_inst(node_id, op_inst);
        }

        for (node_id, kind, span) in handoff_nodes {
            // Replace the operator node in place; the operator path span serves as both the
            // source and destination span of the handoff.
            self.nodes[node_id] = GraphNode::Handoff {
                kind,
                src_span: span,
                dst_span: span,
            };
        }
    }
406
407 pub fn insert_intermediate_node(
419 &mut self,
420 edge_id: GraphEdgeId,
421 new_node: GraphNode,
422 ) -> (GraphNodeId, GraphEdgeId) {
423 let span = Some(new_node.span());
424
425 let op_inst_opt = 'oc: {
427 let GraphNode::Operator(operator) = &new_node else {
428 break 'oc None;
429 };
430 let Some(op_constraints) = find_op_op_constraints(operator) else {
431 break 'oc None;
432 };
433 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
434
435 let mut dummy_diagnostics = Diagnostics::new();
436 let generics = get_operator_generics(&mut dummy_diagnostics, operator);
437 assert!(dummy_diagnostics.is_empty());
438
439 Some(OperatorInstance {
440 op_constraints,
441 input_ports: vec![input_port],
442 output_ports: vec![output_port],
443 singletons_referenced: operator.singletons_referenced.clone(),
444 generics,
445 arguments_pre: operator.args.clone(),
446 arguments_raw: operator.args_raw.clone(),
447 })
448 };
449
450 let node_id = self.nodes.insert(new_node);
452 if let Some(op_inst) = op_inst_opt {
454 self.operator_instances.insert(node_id, op_inst);
455 }
456 let (e0, e1) = self
458 .graph
459 .insert_intermediate_vertex(node_id, edge_id)
460 .unwrap();
461
462 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
464 self.ports
465 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
466 self.ports
467 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
468
469 (node_id, e1)
470 }
471
472 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
475 assert_eq!(
476 1,
477 self.node_degree_in(node_id),
478 "Removed intermediate node must have one predecessor"
479 );
480 assert_eq!(
481 1,
482 self.node_degree_out(node_id),
483 "Removed intermediate node must have one successor"
484 );
485 assert!(
486 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
487 "Should not remove intermediate node after subgraph partitioning"
488 );
489
490 assert!(self.nodes.remove(node_id).is_some());
491 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
492 self.graph.remove_intermediate_vertex(node_id).unwrap();
493 self.operator_instances.remove(node_id);
494 self.node_varnames.remove(node_id);
495
496 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
497 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
498 self.ports.insert(new_edge_id, (src_port, dst_port));
499 }
500
501 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
507 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
508 return Some(Color::Hoff);
509 }
510
511 if let GraphNode::Operator(op) = self.node(node_id)
513 && (op.name_string() == "resolve_futures_blocking"
514 || op.name_string() == "resolve_futures_blocking_ordered")
515 {
516 return Some(Color::Push);
517 }
518
519 let inn_degree = self.node_predecessor_nodes(node_id).len();
521 let out_degree = self.node_successor_nodes(node_id).len();
523
524 match (inn_degree, out_degree) {
525 (0, 0) => None, (0, 1) => Some(Color::Pull),
527 (1, 0) => Some(Color::Push),
528 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
530 (0 | 1, _many) => Some(Color::Push),
531 (_many, _to_many) => Some(Color::Comp),
532 }
533 }
534
    /// Sets the tag of `node_id` to `tag`, replacing any previous tag.
    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
        self.operator_tag.insert(node_id, tag);
    }
539}
540
541impl DfirGraph {
543 pub fn set_node_handoff_references(
546 &mut self,
547 node_id: GraphNodeId,
548 singletons_referenced: Vec<ResolvedHandoffRef>,
549 ) -> Option<Vec<ResolvedHandoffRef>> {
550 self.node_handoff_references
551 .insert(node_id, singletons_referenced)
552 }
553
554 pub fn node_handoff_references(&self, node_id: GraphNodeId) -> &[ResolvedHandoffRef] {
557 self.node_handoff_references
558 .get(node_id)
559 .map(std::ops::Deref::deref)
560 .unwrap_or_default()
561 }
562
563 pub fn node_handoff_reference_groups(&self) -> NodeHandoffReferenceGroups<'_> {
565 let mut handoff_references = NodeHandoffReferenceGroups::new();
566 for node_id in self.node_ids() {
567 if let GraphNode::Operator(operator) = self.node(node_id) {
568 let resolved = self.node_handoff_references(node_id);
569 for (resolved_ref, ref_token) in
570 resolved.iter().zip(operator.singletons_referenced.iter())
571 {
572 if let Some(target_nid) = resolved_ref.node_id {
573 handoff_references
574 .entry(target_nid)
575 .or_default()
576 .entry(resolved_ref.access_group)
577 .or_default()
578 .push((node_id, resolved_ref, ref_token.span()));
579 }
580 }
581 }
582 }
583 handoff_references
584 }
585}
586
/// Handoff references grouped by target, as built by
/// [`DfirGraph::node_handoff_reference_groups`].
///
/// The outer key is the referenced (target) node, the inner key is the reference's access
/// group, and each entry is `(referencing node, resolved reference, span of the reference)`.
pub type NodeHandoffReferenceGroups<'a> =
    BTreeMap<GraphNodeId, BTreeMap<Option<u32>, Vec<(GraphNodeId, &'a ResolvedHandoffRef, Span)>>>;
591
592impl DfirGraph {
594 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
602 let mod_bound_nodes = self
603 .nodes()
604 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
605 .map(|(nid, _node)| nid)
606 .collect::<Vec<_>>();
607
608 for mod_bound_node in mod_bound_nodes {
609 self.remove_module_boundary(mod_bound_node)?;
610 }
611
612 Ok(())
613 }
614
615 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
619 assert!(
620 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
621 "Should not remove intermediate node after subgraph partitioning"
622 );
623
624 let mut mod_pred_ports = BTreeMap::new();
625 let mut mod_succ_ports = BTreeMap::new();
626
627 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
628 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
629 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
630 }
631
632 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
633 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
634 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
635 }
636
637 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
638 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
639 {
640 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
642 panic!();
643 };
644
645 if *input {
646 return Err(Diagnostic {
647 span: *import_expr,
648 level: Level::Error,
649 message: format!(
650 "The ports into the module did not match. input: {:?}, expected: {:?}",
651 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
652 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
653 ),
654 });
655 } else {
656 return Err(Diagnostic {
657 span: *import_expr,
658 level: Level::Error,
659 message: format!(
660 "The ports out of the module did not match. output: {:?}, expected: {:?}",
661 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
662 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
663 ),
664 });
665 }
666 }
667
668 for (port, (pred_edge, pred_port)) in mod_pred_ports {
669 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
670
671 let (src, _) = self.edge(pred_edge);
672 let (_, dst) = self.edge(succ_edge);
673 self.remove_edge(pred_edge);
674 self.remove_edge(succ_edge);
675
676 let new_edge_id = self.graph.insert_edge(src, dst);
677 self.ports.insert(new_edge_id, (pred_port, succ_port));
678 }
679
680 self.graph.remove_vertex(mod_bound_node);
681 self.nodes.remove(mod_bound_node);
682
683 Ok(())
684 }
685}
686
687impl DfirGraph {
689 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
691 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
692 (src, dst)
693 }
694
695 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
697 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
698 (src_port, dst_port)
699 }
700
701 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
703 self.graph.edge_ids()
704 }
705
706 pub fn edges(
708 &self,
709 ) -> impl '_
710 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
711 + FusedIterator
712 + Clone
713 + Debug {
714 self.graph.edges()
715 }
716
717 pub fn insert_edge(
719 &mut self,
720 src: GraphNodeId,
721 src_port: PortIndexValue,
722 dst: GraphNodeId,
723 dst_port: PortIndexValue,
724 ) -> GraphEdgeId {
725 let edge_id = self.graph.insert_edge(src, dst);
726 self.ports.insert(edge_id, (src_port, dst_port));
727 edge_id
728 }
729
730 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
732 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
733 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
734 }
735}
736
737impl DfirGraph {
739 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
741 self.subgraph_nodes
742 .get(subgraph_id)
743 .expect("Subgraph not found.")
744 }
745
746 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
748 self.subgraph_nodes.keys()
749 }
750
751 pub fn subgraph_toposort(&self) -> &[GraphSubgraphId] {
753 &self.subgraph_toposort
754 }
755
756 pub fn set_subgraph_toposort(&mut self, order: Vec<GraphSubgraphId>) {
758 self.subgraph_toposort = order;
759 }
760
761 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
763 self.subgraph_nodes.iter()
764 }
765
766 pub fn insert_subgraph(
768 &mut self,
769 node_ids: Vec<GraphNodeId>,
770 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
771 for &node_id in node_ids.iter() {
773 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
774 return Err((node_id, old_sg_id));
775 }
776 }
777 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
778 for &node_id in node_ids.iter() {
779 self.node_subgraph.insert(node_id, sg_id);
780 }
781 node_ids
782 });
783
784 Ok(subgraph_id)
785 }
786
787 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
789 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
790 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
791 true
792 } else {
793 false
794 }
795 }
796
797 pub fn handoff_delay_type(&self, node_id: GraphNodeId) -> Option<DelayType> {
799 self.handoff_delay_type.get(node_id).copied()
800 }
801
802 pub fn set_handoff_delay_type(&mut self, node_id: GraphNodeId, delay_type: DelayType) {
804 self.handoff_delay_type.insert(node_id, delay_type);
805 }
806
807 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
809 subgraph_nodes
810 .iter()
811 .position(|&node_id| {
812 self.node_color(node_id)
813 .is_some_and(|color| Color::Pull != color)
814 })
815 .unwrap_or(subgraph_nodes.len())
816 }
817}
818
819impl DfirGraph {
821 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
823 let name = match &self.nodes[node_id] {
824 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
825 GraphNode::Handoff {
826 kind: HandoffKind::Vec,
827 ..
828 } => format!(
829 "hoff_{:?}_{}",
830 node_id.data(),
831 if is_pred { "recv" } else { "send" }
832 ),
833 GraphNode::Handoff {
834 kind: HandoffKind::Singleton | HandoffKind::Optional,
835 ..
836 } => format!(
837 "singleton_{:?}_{}",
838 node_id.data(),
839 if is_pred { "recv" } else { "send" }
840 ),
841 GraphNode::ModuleBoundary { .. } => panic!(),
842 };
843 let span = match (is_pred, &self.nodes[node_id]) {
844 (_, GraphNode::Operator(operator)) => operator.span(),
845 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
846 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
847 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
848 };
849 Ident::new(&name, span)
850 }
851
    /// Builds the identifier `hoff_<id>_buf` for handoff `hoff_id`, with the given `span`.
    fn hoff_buf_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
        Ident::new(&format!("hoff_{:?}_buf", hoff_id.data()), span)
    }
856
    /// Builds the identifier `hoff_<id>_back` for handoff `hoff_id`, with the given `span`.
    /// Code generation swaps this identifier with the `_buf` one for back-edge handoffs.
    fn hoff_back_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
        Ident::new(&format!("hoff_{:?}_back", hoff_id.data()), span)
    }
861
862 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<TokenStream> {
871 self.node_handoff_references(node_id)
872 .iter()
873 .map(|resolved_ref| {
874 let ref_node_id = resolved_ref
876 .node_id
877 .expect("Expected singleton to be resolved but was not, this is a bug.");
878 let is_mut = resolved_ref.is_mut;
879 match self.node(ref_node_id) {
880 GraphNode::Handoff {
881 kind: HandoffKind::Singleton,
882 ..
883 } => {
884 let buf_ident = self.hoff_buf_ident(ref_node_id, span);
885 if is_mut {
886 quote_spanned! {span=> #buf_ident.as_mut().unwrap() }
887 } else {
888 quote_spanned! {span=> #buf_ident.as_ref().unwrap() }
889 }
890 }
891 GraphNode::Handoff {
892 kind: HandoffKind::Optional | HandoffKind::Vec,
893 ..
894 } => {
895 let buf_ident = self.hoff_buf_ident(ref_node_id, span);
896 if is_mut {
897 quote_spanned! {span=> &mut #buf_ident }
898 } else {
899 quote_spanned! {span=> &#buf_ident }
900 }
901 }
902 _ => {
903 unreachable!("Only handoff nodes should be reachable as handoff references")
904 }
905 }
906 })
907 .collect::<Vec<_>>()
908 }
909
910 fn helper_collect_subgraph_handoffs(
913 &self,
914 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
915 let mut subgraph_handoffs: SecondaryMap<
917 GraphSubgraphId,
918 (Vec<GraphNodeId>, Vec<GraphNodeId>),
919 > = self
920 .subgraph_nodes
921 .keys()
922 .map(|k| (k, Default::default()))
923 .collect();
924
925 for (hoff_id, hoff) in self.nodes() {
927 if !matches!(hoff, GraphNode::Handoff { .. }) {
928 continue;
929 }
930 for (_edge, succ_id) in self.node_successors(hoff_id) {
932 let succ_sg = self
933 .node_subgraph(succ_id)
934 .expect("bug: successor not in subgraph, may be a doubled/adjacent handoff");
935 subgraph_handoffs[succ_sg].0.push(hoff_id);
936 }
937 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
939 let pred_sg = self
940 .node_subgraph(pred_id)
941 .expect("bug: predecessor not in subgraph, may be a doubled/adjacent handoff");
942 subgraph_handoffs[pred_sg].1.push(hoff_id);
943 }
944 }
945
946 subgraph_handoffs
947 }
948
949 pub fn as_code(
964 &self,
965 root: &TokenStream,
966 include_type_guards: bool,
967 prefix: TokenStream,
968 diagnostics: &mut Diagnostics,
969 ) -> Result<TokenStream, Diagnostics> {
970 self.as_code_with_options(root, include_type_guards, true, prefix, diagnostics)
971 }
972
973 pub fn as_code_with_options(
982 &self,
983 root: &TokenStream,
984 include_type_guards: bool,
985 include_meta: bool,
986 prefix: TokenStream,
987 diagnostics: &mut Diagnostics,
988 ) -> Result<TokenStream, Diagnostics> {
989 let df = Ident::new(GRAPH, Span::call_site());
990 let context = Ident::new(CONTEXT, Span::call_site());
991 let bump_ident = Ident::new("__dfir_bump", Span::call_site());
993
994 let handoff_nodes = self
996 .nodes
997 .iter()
998 .filter_map(|(node_id, node)| match node {
999 &GraphNode::Handoff {
1000 kind,
1001 src_span,
1002 dst_span,
1003 } => Some((node_id, kind, (src_span, dst_span))),
1004 GraphNode::Operator(_) => None,
1005 GraphNode::ModuleBoundary { .. } => panic!(),
1006 })
1007 .collect::<Vec<_>>();
1008
1009 let back_edge_hoffs_and_lazyness = handoff_nodes
1013 .iter()
1014 .map(|&(node_id, _, _)| node_id)
1015 .filter_map(|node_id| {
1016 if let Some(delay_type) = self.handoff_delay_type(node_id) {
1017 assert!(
1018 matches!(delay_type, DelayType::Tick | DelayType::TickLazy),
1019 "Handoff `DelayType` must be either `Tick` or `TickLazy` (or unset)."
1020 );
1021 Some((node_id, matches!(delay_type, DelayType::TickLazy)))
1022 } else {
1023 None
1024 }
1025 })
1026 .collect::<SparseSecondaryMap<_, _>>();
1027
1028 let back_buffer_idents_laziness = handoff_nodes
1030 .iter()
1031 .filter_map(|&(hoff_id, _kind, (src_span, dst_span))| {
1032 back_edge_hoffs_and_lazyness.get(hoff_id).map(|&is_lazy| {
1033 let span = src_span.join(dst_span).unwrap_or(src_span);
1034 let back_ident = self.hoff_back_ident(hoff_id, span);
1035 let buf_ident = self.hoff_buf_ident(hoff_id, span);
1036 (back_ident, buf_ident, is_lazy)
1037 })
1038 })
1039 .collect::<Vec<_>>();
1040
1041 let back_edge_swap_code = handoff_nodes
1045 .iter()
1046 .filter(|&&(node_id, _kind, _)| back_edge_hoffs_and_lazyness.contains_key(node_id))
1047 .map(|&(hoff_id, _kind, _)| {
1048 let span = self.nodes[hoff_id].span();
1049 let buf_ident = self.hoff_buf_ident(hoff_id, span);
1050 let back_ident = self.hoff_back_ident(hoff_id, span);
1051 quote_spanned! {span=>
1052 ::std::mem::swap(&mut #buf_ident, &mut #back_ident);
1053 }
1054 })
1055 .collect::<Vec<_>>();
1056
1057 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
1059
1060 let all_subgraphs: Vec<_> = self
1062 .subgraph_toposort()
1063 .iter()
1064 .map(|&sg_id| (sg_id, self.subgraph(sg_id)))
1065 .collect();
1066
1067 let mut op_prologue_code = Vec::new();
1071 let mut op_tick_end_code = Vec::new();
1072 let mut subgraph_blocks = Vec::new();
1073 {
1074 for &(subgraph_id, subgraph_nodes) in all_subgraphs.iter() {
1075 let sg_metrics_ffi = subgraph_id.data().as_ffi();
1076 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
1077
1078 let recv_port_idents: Vec<Ident> = recv_hoffs
1080 .iter()
1081 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
1082 .collect();
1083 let send_port_idents: Vec<Ident> = send_hoffs
1084 .iter()
1085 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
1086 .collect();
1087
1088 let recv_buf_idents: Vec<Ident> = recv_hoffs
1090 .iter()
1091 .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1092 .collect();
1093 let send_buf_idents: Vec<Ident> = send_hoffs
1094 .iter()
1095 .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1096 .collect();
1097
1098 let recv_kinds = recv_hoffs
1100 .iter()
1101 .map(|&hoff_id| {
1102 let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1103 panic!()
1104 };
1105 *kind
1106 })
1107 .collect::<Vec<_>>();
1108 let send_kinds = send_hoffs
1109 .iter()
1110 .map(|&hoff_id| {
1111 let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1112 panic!()
1113 };
1114 *kind
1115 })
1116 .collect::<Vec<_>>();
1117
1118 let recv_port_code: Vec<TokenStream> = recv_port_idents
1122 .iter()
1123 .zip(recv_buf_idents.iter())
1124 .zip(recv_kinds.iter())
1125 .zip(recv_hoffs.iter())
1126 .map(|(((port_ident, buf_ident), &kind), &hoff_id)| {
1127 let hoff_ffi = hoff_id.data().as_ffi();
1128 let work_done = Ident::new("__dfir_work_done", Span::call_site());
1132 let metrics = Ident::new("__dfir_metrics", Span::call_site());
1133
1134 let (len_expr, drain_expr) = match kind {
1136 HandoffKind::Singleton | HandoffKind::Optional => (
1137 quote! { if #buf_ident.is_some() { 1usize } else { 0usize } },
1138 quote! { #root::dfir_pipes::pull::iter(#buf_ident.take().into_iter()) },
1139 ),
1140 HandoffKind::Vec => {
1141 let drain_ident = if back_edge_hoffs_and_lazyness.contains_key(hoff_id) {
1145 &self.hoff_back_ident(hoff_id, buf_ident.span())
1146 } else {
1147 buf_ident
1148 };
1149 (
1150 quote! { #drain_ident.len() },
1151 quote! { #root::dfir_pipes::pull::iter(#drain_ident.drain(..)) },
1152 )
1153 }
1154 };
1155
1156 quote_spanned! {port_ident.span()=>
1157 {
1158 let hoff_len = #len_expr;
1159 if hoff_len > 0 {
1160 #work_done = true;
1161 }
1162 let hoff_metrics = &#metrics.handoffs[
1163 #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1164 ];
1165 hoff_metrics.total_items_count.update(|x| x + hoff_len);
1166 hoff_metrics.curr_items_count.set(hoff_len);
1167 }
1168 let #port_ident = #drain_expr;
1169 }
1170 })
1171 .collect();
1172
1173 let send_port_code: Vec<TokenStream> = send_port_idents
1175 .iter()
1176 .zip(send_buf_idents.iter())
1177 .zip(send_kinds.iter())
1178 .map(|((port_ident, buf_ident), &kind)| {
1179 match kind {
1180 HandoffKind::Singleton => {
1181 quote_spanned! {port_ident.span()=>
1183 let #port_ident = #root::dfir_pipes::push::for_each(|__item| {
1184 if #buf_ident.replace(__item).is_some() {
1185 panic!("singleton() received more than one item");
1186 }
1187 });
1188 }
1189 }
1190 HandoffKind::Optional => {
1191 quote_spanned! {port_ident.span()=>
1193 let #port_ident = #root::dfir_pipes::push::for_each(|__item| {
1194 if #buf_ident.replace(__item).is_some() {
1195 panic!("optional() received more than one item");
1196 }
1197 });
1198 }
1199 }
1200 HandoffKind::Vec => {
1201 quote_spanned! {port_ident.span()=>
1202 let #port_ident = #root::dfir_pipes::push::for_each(|item| { #buf_ident.push(item); });
1204 }
1205 }
1206 }
1207 })
1208 .collect();
1209
1210 let loop_id = self.node_loop(subgraph_nodes[0]);
1212
1213 let mut subgraph_op_iter_code = Vec::new();
1214 let mut subgraph_op_iter_after_code = Vec::new();
1215 {
1216 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
1217
1218 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
1219 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
1220
1221 for (idx, &node_id) in nodes_iter.enumerate() {
1222 let node = &self.nodes[node_id];
1223 assert!(
1224 matches!(node, GraphNode::Operator(_)),
1225 "Handoffs are not part of subgraphs."
1226 );
1227 let op_inst = &self.operator_instances[node_id];
1228
1229 let op_span = node.span();
1230 let op_name = op_inst.op_constraints.name;
1231 let root = change_spans(root.clone(), op_span);
1233 let op_constraints = OPERATORS
1234 .iter()
1235 .find(|op| op_name == op.name)
1236 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
1237
1238 let ident = self.node_as_ident(node_id, false);
1239
1240 {
1241 let mut input_edges = self
1244 .graph
1245 .predecessor_edges(node_id)
1246 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
1247 .collect::<Vec<_>>();
1248 input_edges.sort();
1250
1251 let inputs = input_edges
1252 .iter()
1253 .map(|&(_port, edge_id)| {
1254 let (pred, _) = self.edge(edge_id);
1255 self.node_as_ident(pred, true)
1256 })
1257 .collect::<Vec<_>>();
1258
1259 let mut output_edges = self
1261 .graph
1262 .successor_edges(node_id)
1263 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
1264 .collect::<Vec<_>>();
1265 output_edges.sort();
1267
1268 let outputs = output_edges
1269 .iter()
1270 .map(|&(_port, edge_id)| {
1271 let (_, succ) = self.edge(edge_id);
1272 self.node_as_ident(succ, false)
1273 })
1274 .collect::<Vec<_>>();
1275
1276 let is_pull = idx < pull_to_push_idx;
1277
1278 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1287 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1288
1289 let singletons_resolved =
1290 self.helper_resolve_singletons(node_id, op_span);
1291
1292 let arguments = &process_singletons::postprocess_singletons(
1293 op_inst.arguments_raw.clone(),
1294 singletons_resolved,
1295 );
1296
1297 let source_tag = 'a: {
1298 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1299 break 'a tag;
1300 }
1301
1302 if proc_macro::is_available() {
1303 let op_span = op_span.unwrap();
1304 break 'a format!(
1305 "loc_{}_{}_{}_{}_{}",
1306 crate::pretty_span::make_source_path_relative(
1307 &op_span.file()
1308 )
1309 .display()
1310 .to_string()
1311 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1312 op_span.start().line(),
1313 op_span.start().column(),
1314 op_span.end().line(),
1315 op_span.end().column(),
1316 );
1317 }
1318
1319 format!(
1320 "loc_nopath_{}_{}_{}_{}",
1321 op_span.start().line,
1322 op_span.start().column,
1323 op_span.end().line,
1324 op_span.end().column
1325 )
1326 };
1327
1328 let work_fn = format_ident!(
1329 "{}__{}__{}",
1330 ident,
1331 op_name,
1332 source_tag,
1333 span = op_span
1334 );
1335 let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1336
1337 let context_args = WriteContextArgs {
1338 root: &root,
1339 df_ident: df_local,
1340 context,
1341 subgraph_id,
1342 node_id,
1343 loop_id,
1344 op_span,
1345 op_tag: self.operator_tag.get(node_id).cloned(),
1346 work_fn: &work_fn,
1347 work_fn_async: &work_fn_async,
1348 ident: &ident,
1349 is_pull,
1350 inputs: &inputs,
1351 outputs: &outputs,
1352 op_name,
1353 op_inst,
1354 arguments,
1355 };
1356
1357 let write_result =
1358 (op_constraints.write_fn)(&context_args, diagnostics);
1359 let OperatorWriteOutput {
1360 write_prologue,
1361 write_iterator,
1362 write_iterator_after,
1363 write_tick_end,
1364 } = write_result.unwrap_or_else(|()| {
1365 assert!(
1366 diagnostics.has_error(),
1367 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1368 op_name,
1369 );
1370 OperatorWriteOutput {
1371 write_iterator: null_write_iterator_fn(&context_args),
1372 ..Default::default()
1373 }
1374 });
1375
1376 op_prologue_code.push(syn::parse_quote! {
1377 #[allow(non_snake_case)]
1378 #[inline(always)]
1379 fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1380 thunk()
1381 }
1382
1383 #[allow(non_snake_case)]
1384 #[inline(always)]
1385 async fn #work_fn_async<T>(
1386 thunk: impl ::std::future::Future<Output = T>,
1387 ) -> T {
1388 thunk.await
1389 }
1390 });
1391 op_prologue_code.push(write_prologue);
1392 op_tick_end_code.push(write_tick_end);
1393 subgraph_op_iter_code.push(write_iterator);
1394
1395 if include_type_guards {
1396 let type_guard = if is_pull {
1397 quote_spanned! {op_span=>
1398 let #ident = {
1399 #[allow(non_snake_case)]
1400 #[inline(always)]
1401 pub fn #work_fn<Item, Input>(input: Input)
1402 -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1403 where
1404 Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1405 {
1406 #root::pin_project_lite::pin_project! {
1407 #[repr(transparent)]
1408 struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1409 #[pin]
1410 inner: Input
1411 }
1412 }
1413
1414 impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1415 where
1416 Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1417 {
1418 type Ctx<'ctx> = Input::Ctx<'ctx>;
1419
1420 type Item = Item;
1421 type Meta = Input::Meta;
1422 type CanPend = Input::CanPend;
1423 type CanEnd = Input::CanEnd;
1424
1425 #[inline(always)]
1426 fn pull(
1427 self: ::std::pin::Pin<&mut Self>,
1428 ctx: &mut Self::Ctx<'_>,
1429 ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1430 #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1431 }
1432
1433 #[inline(always)]
1434 fn size_hint(&self) -> (usize, Option<usize>) {
1435 #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1436 }
1437 }
1438
1439 Pull {
1440 inner: input
1441 }
1442 }
1443 #work_fn::<_, _>( #ident )
1444 };
1445 }
1446 } else {
1447 quote_spanned! {op_span=>
1448 let #ident = {
1449 #[allow(non_snake_case)]
1450 #[inline(always)]
1451 pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1452 where
1453 Psh: #root::dfir_pipes::push::Push<Item, ()>
1454 {
1455 #root::pin_project_lite::pin_project! {
1456 #[repr(transparent)]
1457 struct PushGuard<Psh> {
1458 #[pin]
1459 inner: Psh,
1460 }
1461 }
1462
1463 impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1464 where
1465 Psh: #root::dfir_pipes::push::Push<Item, ()>,
1466 {
1467 type Ctx<'ctx> = Psh::Ctx<'ctx>;
1468
1469 type CanPend = Psh::CanPend;
1470
1471 #[inline(always)]
1472 fn poll_ready(
1473 self: ::std::pin::Pin<&mut Self>,
1474 ctx: &mut Self::Ctx<'_>,
1475 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1476 #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1477 }
1478
1479 #[inline(always)]
1480 fn start_send(
1481 self: ::std::pin::Pin<&mut Self>,
1482 item: Item,
1483 meta: (),
1484 ) {
1485 #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1486 }
1487
1488 #[inline(always)]
1489 fn poll_finalize(
1490 self: ::std::pin::Pin<&mut Self>,
1491 ctx: &mut Self::Ctx<'_>,
1492 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1493 #root::dfir_pipes::push::Push::poll_finalize(self.project().inner, ctx)
1494 }
1495
1496 #[inline(always)]
1497 fn size_hint(
1498 self: ::std::pin::Pin<&mut Self>,
1499 hint: (usize, Option<usize>),
1500 ) {
1501 #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1502 }
1503 }
1504
1505 PushGuard {
1506 inner: psh
1507 }
1508 }
1509 #work_fn( #ident )
1510 };
1511 }
1512 };
1513 subgraph_op_iter_code.push(type_guard);
1514 }
1515 subgraph_op_iter_after_code.push(write_iterator_after);
1516 }
1517 }
1518
1519 {
1520 let pull_ident = if 0 < pull_to_push_idx {
1522 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1523 } else {
1524 recv_port_idents[0].clone()
1526 };
1527
1528 #[rustfmt::skip]
1529 let push_ident = if let Some(&node_id) =
1530 subgraph_nodes.get(pull_to_push_idx)
1531 {
1532 self.node_as_ident(node_id, false)
1533 } else if 1 == send_port_idents.len() {
1534 send_port_idents[0].clone()
1536 } else {
1537 diagnostics.push(Diagnostic::spanned(
1538 pull_ident.span(),
1539 Level::Error,
1540 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1541 ));
1542 continue;
1543 };
1544
1545 let pivot_span = pull_ident
1547 .span()
1548 .join(push_ident.span())
1549 .unwrap_or_else(|| push_ident.span());
1550 let pivot_fn_ident = Ident::new(
1551 &format!("pivot_run_sg_{:?}", subgraph_id.data()),
1552 pivot_span,
1553 );
1554 let root = change_spans(root.clone(), pivot_span);
1555 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1556 #[inline(always)]
1557 fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1558 -> impl ::std::future::Future<Output = ()>
1559 where
1560 Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1561 Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1562 {
1563 #root::dfir_pipes::pull::Pull::send_push(pull, push)
1564 }
1565 (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1566 });
1567 }
1568 };
1569
1570 let sg_fut_ident = subgraph_id.as_ident(Span::call_site());
1574
1575 let send_metrics_code = send_hoffs
1577 .iter()
1578 .zip(send_buf_idents.iter())
1579 .zip(send_kinds.iter())
1580 .map(|((&hoff_id, buf_ident), &kind)| {
1581 let hoff_ffi = hoff_id.data().as_ffi();
1582 let len_expr = match kind {
1583 HandoffKind::Singleton | HandoffKind::Optional => {
1584 quote! { if #buf_ident.is_some() { 1 } else { 0 } }
1585 }
1586 HandoffKind::Vec => {
1587 quote! { #buf_ident.len() }
1588 }
1589 };
1590 quote! {
1591 __dfir_metrics.handoffs[
1592 #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1593 ].curr_items_count.set(#len_expr);
1594 }
1595 })
1596 .collect::<Vec<_>>();
1597
1598 let send_hoff_make_code = send_buf_idents.iter()
1600 .zip(send_kinds.iter())
1601 .zip(send_hoffs.iter())
1602 .map(|((buf_ident, &kind), &hoff_id)| {
1603 let span = buf_ident.span();
1604 if back_edge_hoffs_and_lazyness.contains_key(hoff_id) {
1605 quote_spanned! {span=>
1608 #buf_ident.clear();
1609 }
1610 } else {
1611 match kind {
1612 HandoffKind::Vec => quote_spanned! {span=>
1613 let mut #buf_ident = #root::bumpalo::collections::Vec::new_in(&#bump_ident);
1614 },
1615 HandoffKind::Singleton | HandoffKind::Optional => quote_spanned! {span=>
1616 let mut #buf_ident = ::std::option::Option::None;
1617 },
1618 }
1619 }
1620 });
1621 let recv_hoff_drop_code = recv_buf_idents
1625 .iter()
1626 .zip(recv_hoffs.iter())
1627 .filter(|&(_, &hoff_id)| !back_edge_hoffs_and_lazyness.contains_key(hoff_id))
1628 .map(|(buf_ident, _)| {
1629 let span = buf_ident.span();
1630 quote_spanned! {span=>
1631 let _ = #buf_ident;
1632 }
1633 });
1634
1635 subgraph_blocks.push(quote! {
1636 #( #send_hoff_make_code )*
1638
1639 let #sg_fut_ident = async {
1640 let #context = &#df;
1641 #( #recv_port_code )*
1642 #( #send_port_code )*
1643 #( #subgraph_op_iter_code )*
1644 #( #subgraph_op_iter_after_code )*
1645 };
1646 {
1647 let sg_metrics = &__dfir_metrics.subgraphs[
1649 #root::slotmap::KeyData::from_ffi(#sg_metrics_ffi).into()
1650 ];
1651 #root::scheduled::metrics::InstrumentSubgraph::new(
1652 #sg_fut_ident, sg_metrics
1653 ).await;
1654 sg_metrics.total_run_count.update(|x| x + 1);
1655
1656 #( #send_metrics_code )*
1658
1659 #( #recv_hoff_drop_code )*
1661 }
1662 });
1663
1664 }
1667 }
1668
1669 if diagnostics.has_error() {
1670 return Err(std::mem::take(diagnostics));
1671 }
1672 let _ = diagnostics; let (meta_graph_arg, diagnostics_arg) = if include_meta {
1675 let meta_graph_json = serde_json::to_string(&self).unwrap();
1676 let meta_graph_json = Literal::string(&meta_graph_json);
1677
1678 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1679 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1680 let diagnostics_json = Literal::string(&diagnostics_json);
1681
1682 (
1683 quote! { Some(#meta_graph_json) },
1684 quote! { Some(#diagnostics_json) },
1685 )
1686 } else {
1687 (quote! { None }, quote! { None })
1688 };
1689
1690 let metrics_init_code = {
1692 let handoff_inits = handoff_nodes.iter().map(|&(node_id, _, _)| {
1693 let ffi = node_id.data().as_ffi();
1694 quote! {
1695 dfir_metrics.handoffs.insert(
1696 #root::slotmap::KeyData::from_ffi(#ffi).into(),
1697 ::std::default::Default::default(),
1698 );
1699 }
1700 });
1701 let subgraph_inits = all_subgraphs.iter().map(|&(sg_id, _)| {
1702 let ffi = sg_id.data().as_ffi();
1703 quote! {
1704 dfir_metrics.subgraphs.insert(
1705 #root::slotmap::KeyData::from_ffi(#ffi).into(),
1706 ::std::default::Default::default(),
1707 );
1708 }
1709 });
1710 handoff_inits.chain(subgraph_inits).collect::<Vec<_>>()
1711 };
1712
1713 let back_buffer_idents = back_buffer_idents_laziness
1715 .iter()
1716 .map(|(back_ident, _, _)| back_ident);
1717 let defer_tick_buf_idents = back_buffer_idents_laziness
1719 .iter()
1720 .map(|(_, buf_ident, _)| buf_ident);
1721 let non_lazy_buf_idents = back_buffer_idents_laziness
1725 .iter()
1726 .filter_map(|(_, buf_ident, is_lazy)| (!is_lazy).then_some(buf_ident));
1727
1728 Ok(quote! {
1731 {
1732 #prefix
1733
1734 use #root::{var_expr, var_args};
1735
1736 let __dfir_wake_state = ::std::sync::Arc::new(
1737 #root::scheduled::context::WakeState::default()
1738 );
1739
1740 let __dfir_metrics = {
1741 let mut dfir_metrics = #root::scheduled::metrics::DfirMetrics::default();
1742 #( #metrics_init_code )*
1743 ::std::rc::Rc::new(dfir_metrics)
1744 };
1745
1746 #[allow(unused_mut)]
1747 let mut #df = #root::scheduled::context::Context::new(
1748 ::std::clone::Clone::clone(&__dfir_wake_state),
1749 __dfir_metrics,
1750 );
1751
1752 #( #op_prologue_code )*
1753
1754 #( let mut #back_buffer_idents = ::std::vec::Vec::new(); )*
1758 #( let mut #defer_tick_buf_idents = ::std::vec::Vec::new(); )*
1759
1760 let mut #bump_ident = #root::bumpalo::Bump::new();
1762
1763 let mut __dfir_work_done = true;
1768 #[allow(unused_qualifications, unused_mut, unused_variables, clippy::await_holding_refcell_ref, clippy::deref_addrof)]
1769 let __dfir_inline_tick = async move |#df: &mut #root::scheduled::context::Context| {
1770 #bump_ident.reset();
1772
1773 {
1774 let __dfir_metrics = #df.metrics();
1775
1776 #( #subgraph_blocks )*
1777
1778 if false #( || !#non_lazy_buf_idents.is_empty() )* {
1781 #df.schedule_subgraph(true);
1782 }
1783
1784 #( #back_edge_swap_code )*
1787 }
1788
1789 #( #op_tick_end_code )*
1791
1792 #df.__end_tick();
1793
1794 ::std::mem::take(&mut __dfir_work_done)
1795 };
1796 #root::scheduled::context::Dfir::new(
1797 __dfir_inline_tick,
1798 #df,
1799 #meta_graph_arg,
1800 #diagnostics_arg,
1801 )
1802 }
1803 })
1804 }
1805
1806 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1809 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1810 .node_ids()
1811 .filter_map(|node_id| {
1812 let op_color = self.node_color(node_id)?;
1813 Some((node_id, op_color))
1814 })
1815 .collect();
1816
1817 for sg_nodes in self.subgraph_nodes.values() {
1819 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1820
1821 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1822 let is_pull = idx < pull_to_push_idx;
1823 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1824 }
1825 }
1826
1827 node_color_map
1828 }
1829
1830 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1832 let mut output = String::new();
1833 self.write_mermaid(&mut output, write_config).unwrap();
1834 output
1835 }
1836
1837 pub fn write_mermaid(
1839 &self,
1840 output: impl std::fmt::Write,
1841 write_config: &WriteConfig,
1842 ) -> std::fmt::Result {
1843 let mut graph_write = Mermaid::new(output);
1844 self.write_graph(&mut graph_write, write_config)
1845 }
1846
1847 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1849 let mut output = String::new();
1850 let mut graph_write = Dot::new(&mut output);
1851 self.write_graph(&mut graph_write, write_config).unwrap();
1852 output
1853 }
1854
1855 pub fn write_dot(
1857 &self,
1858 output: impl std::fmt::Write,
1859 write_config: &WriteConfig,
1860 ) -> std::fmt::Result {
1861 let mut graph_write = Dot::new(output);
1862 self.write_graph(&mut graph_write, write_config)
1863 }
1864
/// Walks the graph and emits it through `graph_write`, honoring `write_config`.
///
/// Emission order is: prologue, node definitions, edges, handoff-reference edges (unless
/// `no_references`), then the nested loop / subgraph / varname grouping of nodes, and
/// finally the epilogue. Any error returned by `graph_write` aborts the walk and is
/// returned unchanged.
pub(crate) fn write_graph<W>(
    &self,
    mut graph_write: W,
    write_config: &WriteConfig,
) -> Result<(), W::Err>
where
    W: GraphWrite,
{
    /// Builds the edge label from the source and destination ports.
    ///
    /// `Path` and `Int` ports contribute their text; any other port kind contributes
    /// nothing. When both ends contribute, they are joined with a newline (source first).
    fn helper_edge_label(
        src_port: &PortIndexValue,
        dst_port: &PortIndexValue,
    ) -> Option<String> {
        let src_label = match src_port {
            PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
            PortIndexValue::Int(index) => Some(index.value.to_string()),
            _ => None,
        };
        let dst_label = match dst_port {
            PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
            PortIndexValue::Int(index) => Some(index.value.to_string()),
            _ => None,
        };
        let label = match (src_label, dst_label) {
            (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
            (Some(l1), None) => Some(l1),
            (None, Some(l2)) => Some(l2),
            (None, None) => None,
        };
        label
    }

    // Pull/push coloring for every node; only consulted when `no_pull_push` is unset.
    let node_color_map = self.node_color_map();

    // Write prologue.
    graph_write.write_prologue()?;

    // Define nodes. Handoffs hidden by `no_handoffs` are remembered so that the edge
    // passes below can route around them.
    let mut skipped_handoffs = BTreeSet::new();
    for (node_id, node) in self.nodes() {
        if matches!(node, GraphNode::Handoff { .. }) && write_config.no_handoffs {
            skipped_handoffs.insert(node_id);
            continue;
        }
        graph_write.write_node_definition(
            node_id,
            // Node text: `op_short_text` wins over `op_text_no_imports`; otherwise the
            // full pretty-printed text is used.
            &if write_config.op_short_text {
                node.to_name_string()
            } else if write_config.op_text_no_imports {
                // Drop every line whose trimmed text starts with "use".
                // NOTE(review): this is a plain prefix test, so a line starting with an
                // identifier such as `user` is dropped as well.
                let full_text = node.to_pretty_string();
                let mut output = String::new();
                for sentence in full_text.split('\n') {
                    if sentence.trim().starts_with("use") {
                        continue;
                    }
                    // Each kept line is preceded by '\n', so the result starts with a
                    // newline whenever at least one line is kept.
                    output.push('\n');
                    output.push_str(sentence);
                }
                output.into()
            } else {
                node.to_pretty_string()
            },
            // Color is suppressed entirely when `no_pull_push` is set.
            if write_config.no_pull_push {
                None
            } else {
                node_color_map.get(node_id).copied()
            },
        )?;
    }

    // Write edges.
    for (edge_id, (src_id, mut dst_id)) in self.edges() {
        // Edges leaving a skipped handoff are not written on their own; they are folded
        // into the edge that enters the handoff (see below).
        if skipped_handoffs.contains(&src_id) {
            continue;
        }

        let (src_port, mut dst_port) = self.edge_ports(edge_id);
        if skipped_handoffs.contains(&dst_id) {
            // Redirect the edge past the skipped handoff to the handoff's successor.
            let mut handoff_succs = self.node_successors(dst_id);
            // A handoff without any successor has nowhere to redirect to: drop the edge.
            if handoff_succs.len() == 0 {
                continue;
            }
            // Only the first successor is used as the new destination.
            let (succ_edge, succ_node) = handoff_succs.next().unwrap();
            dst_id = succ_node;
            dst_port = self.edge_ports(succ_edge).1;
        }

        let label = helper_edge_label(src_port, dst_port);
        // Delay type is looked up on the (possibly redirected) destination; it is `None`
        // when `node_op_inst` yields nothing for that node.
        let delay_type = self
            .node_op_inst(dst_id)
            .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
        // The trailing `false` is the flag that is `true` for the reference edges below.
        graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
    }

    // Write reference edges.
    if !write_config.no_references {
        for dst_id in self.node_ids() {
            // Only references that resolved to a concrete node are drawn.
            for src_ref_id in self
                .node_handoff_references(dst_id)
                .iter()
                .filter_map(|r| r.node_id)
            {
                // If the referenced handoff is hidden, draw from its first predecessor
                // instead; with no predecessor the reference edge is dropped.
                let resolved_src = if skipped_handoffs.contains(&src_ref_id) {
                    self.node_predecessor_nodes(src_ref_id).next()
                } else {
                    Some(src_ref_id)
                };
                let Some(resolved_src) = resolved_src else {
                    continue;
                };
                let label = None;
                graph_write.write_edge(resolved_src, dst_id, None, label, true)?;
            }
        }
    }

    // Group subgraphs by loop. With `no_loops` every subgraph lands in the `None` group.
    // `into_group_map` returns a `BTreeMap`, so groups are visited in key order.
    let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
        let loop_id = if write_config.no_loops {
            None
        } else {
            self.subgraph_loop(sg_id)
        };
        (loop_id, sg_id)
    });
    let loop_subgraphs = into_group_map(loop_subgraphs);
    for (loop_id, subgraph_ids) in loop_subgraphs {
        if let Some(loop_id) = loop_id {
            graph_write.write_loop_start(loop_id)?;
        }

        // Within the loop, group nodes by subgraph. With `no_subgraphs` all nodes share
        // the `None` key and no subgraph start/end is written.
        let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
            self.subgraph(sg_id).iter().copied().map(move |node_id| {
                let opt_sg_id = if write_config.no_subgraphs {
                    None
                } else {
                    Some(sg_id)
                };
                (opt_sg_id, (self.node_varname(node_id), node_id))
            })
        });
        let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
        for (sg_id, varnames) in subgraph_varnames_nodes {
            if let Some(sg_id) = sg_id {
                graph_write.write_subgraph_start(sg_id)?;
            }

            // Within the subgraph, group nodes by variable name. With `no_varnames` every
            // node is placed in the `None` group and no varname start/end is written.
            let varname_nodes = varnames.into_iter().map(|(varname, node)| {
                let varname = if write_config.no_varnames {
                    None
                } else {
                    varname
                };
                (varname, node)
            });
            let varname_nodes = into_group_map(varname_nodes);
            for (varname, node_ids) in varname_nodes {
                if let Some(varname) = varname {
                    graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
                }

                // Write all nodes belonging to this (varname) group.
                for node_id in node_ids {
                    graph_write.write_node(node_id)?;
                }

                if varname.is_some() {
                    graph_write.write_varname_end()?;
                }
            }

            if sg_id.is_some() {
                graph_write.write_subgraph_end()?;
            }
        }

        if loop_id.is_some() {
            graph_write.write_loop_end()?;
        }
    }

    // Write epilogue.
    graph_write.write_epilogue()?;

    Ok(())
}
2067
2068 pub fn surface_syntax_string(&self) -> String {
2070 let mut string = String::new();
2071 self.write_surface_syntax(&mut string).unwrap();
2072 string
2073 }
2074
2075 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2077 for (key, node) in self.nodes.iter() {
2078 match node {
2079 GraphNode::Operator(op) => {
2080 writeln!(write, "_{:?} = {};", key.data(), op.to_token_stream())?;
2081 }
2082 GraphNode::Handoff {
2083 kind: HandoffKind::Vec,
2084 ..
2085 } => {
2086 writeln!(write, "_{:?} = handoff();", key.data())?;
2087 }
2088 GraphNode::Handoff {
2089 kind: HandoffKind::Singleton,
2090 ..
2091 } => {
2092 writeln!(write, "_{:?} = singleton();", key.data())?;
2093 }
2094 GraphNode::Handoff {
2095 kind: HandoffKind::Optional,
2096 ..
2097 } => {
2098 writeln!(write, "_{:?} = optional();", key.data())?;
2099 }
2100 GraphNode::ModuleBoundary { .. } => panic!(),
2101 }
2102 }
2103 writeln!(write)?;
2104 for (e, (src_key, dst_key)) in self.graph.edges() {
2105 let (src_port, dst_port) = self.edge_ports(e);
2106 let src_port_str = if src_port.is_specified() {
2107 format!("[{}]", src_port)
2108 } else {
2109 String::new()
2110 };
2111 let dst_port_str = if dst_port.is_specified() {
2112 format!("[{}]", dst_port)
2113 } else {
2114 String::new()
2115 };
2116 writeln!(
2117 write,
2118 "_{:?}{} -> {}_{:?};",
2119 src_key.data(),
2120 src_port_str,
2121 dst_port_str,
2122 dst_key.data()
2123 )?;
2124 }
2125 Ok(())
2126 }
2127
2128 pub fn mermaid_string_flat(&self) -> String {
2130 let mut string = String::new();
2131 self.write_mermaid_flat(&mut string).unwrap();
2132 string
2133 }
2134
2135 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2137 writeln!(write, "flowchart TB")?;
2138 for (key, node) in self.nodes.iter() {
2139 match node {
2140 GraphNode::Operator(operator) => writeln!(
2141 write,
2142 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
2143 span = PrettySpan(node.span()),
2144 id = key.data(),
2145 row_col = PrettyRowCol(node.span()),
2146 code = operator
2147 .to_token_stream()
2148 .to_string()
2149 .replace('&', "&")
2150 .replace('<', "<")
2151 .replace('>', ">")
2152 .replace('"', """)
2153 .replace('\n', "<br>"),
2154 ),
2155 GraphNode::Handoff {
2156 kind: HandoffKind::Vec,
2157 ..
2158 } => {
2159 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
2160 }
2161 GraphNode::Handoff {
2162 kind: HandoffKind::Singleton | HandoffKind::Optional,
2163 ..
2164 } => {
2165 writeln!(
2166 write,
2167 r#" {:?}{{"{}"}}"#,
2168 key.data(),
2169 SINGLETON_SLOT_NODE_STR
2170 )
2171 }
2172 GraphNode::ModuleBoundary { .. } => {
2173 writeln!(
2174 write,
2175 r#" {:?}{{"{}"}}"#,
2176 key.data(),
2177 MODULE_BOUNDARY_NODE_STR
2178 )
2179 }
2180 }?;
2181 }
2182 writeln!(write)?;
2183 for (_e, (src_key, dst_key)) in self.graph.edges() {
2184 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
2185 }
2186 Ok(())
2187 }
2188}
2189
2190impl DfirGraph {
2192 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
2194 self.loop_nodes.keys()
2195 }
2196
2197 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
2199 self.loop_nodes.iter()
2200 }
2201
2202 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
2204 let loop_id = self.loop_nodes.insert(Vec::new());
2205 self.loop_children.insert(loop_id, Vec::new());
2206 if let Some(parent_loop) = parent_loop {
2207 self.loop_parent.insert(loop_id, parent_loop);
2208 self.loop_children
2209 .get_mut(parent_loop)
2210 .unwrap()
2211 .push(loop_id);
2212 } else {
2213 self.root_loops.push(loop_id);
2214 }
2215 loop_id
2216 }
2217
2218 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
2220 self.node_loops.get(node_id).copied()
2221 }
2222
2223 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
2225 let &node_id = self.subgraph(subgraph_id).first().unwrap();
2226 let out = self.node_loop(node_id);
2227 debug_assert!(
2228 self.subgraph(subgraph_id)
2229 .iter()
2230 .all(|&node_id| self.node_loop(node_id) == out),
2231 "Subgraph nodes should all have the same loop context."
2232 );
2233 out
2234 }
2235
2236 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
2238 self.loop_parent.get(loop_id).copied()
2239 }
2240
2241 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
2243 self.loop_children.get(loop_id).unwrap()
2244 }
2245}
2246
/// Configuration for writing graphs: each flag switches off or alters one aspect of the
/// rendered output. All flags default to `false`.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Do not group nodes by subgraph.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Do not group nodes by variable name.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Do not color nodes as pull or push.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Hide handoff nodes, connecting their neighbors directly.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Do not draw handoff reference edges.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Do not group subgraphs by loop.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Label operators with their short name instead of their full text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Label operators with their full text minus lines starting with `use`.
    /// Ignored when `op_short_text` is set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
2277
/// Selects the output format for a rendered graph.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid diagram output.
    Mermaid,
    /// DOT graph output.
    Dot,
}
2287
/// Groups `(key, value)` pairs by key into a [`BTreeMap`].
///
/// Keys come out in ascending `Ord` order; within each key, values keep the order in
/// which they appeared in `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}