1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// Kind of delay an operator requires on one of its inputs.
///
/// Produced by [`OperatorConstraints::input_delaytype_fn`] for a given input port.
// `Ord` is derived, so the declaration order of the variants below defines their ordering;
// reordering them changes comparison results.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum DelayType {
    // NOTE(review): per-variant semantics are not visible in this file — confirm before documenting.
    Stratum,
    MonotoneAccum,
    Tick,
    TickLazy,
}
35
/// Port names an operator accepts on one side (returned by `ports_inn` / `ports_out` in
/// [`OperatorConstraints`]).
pub enum PortListSpec {
    /// No fixed list of port names (presumably any number of ports is accepted — confirm).
    Variadic,
    /// Exactly the listed, comma-separated port indices.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
43
/// Static description of one built-in operator: its name, arity/argument constraints, port
/// specs, and code generator. Each operator module contributes one of these to `OPERATORS`.
pub struct OperatorConstraints {
    /// Operator name as written in the graph; the key used by [`operator_lookup`].
    pub name: &'static str,
    /// Categories this operator is listed under (see [`OperatorCategory`]).
    pub categories: &'static [OperatorCategory],

    // NOTE(review): presumably a "hard" range violation is an error while a "soft" range
    // violation is only a warning — confirm against the graph checker.
    /// Hard constraint on the number of inputs.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Soft constraint on the number of inputs.
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Hard constraint on the number of outputs.
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Soft constraint on the number of outputs.
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of (non-generic) arguments the operator takes.
    pub num_args: usize,
    /// Allowed number of persistence generic arguments (e.g. `'tick`).
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Allowed number of type generic arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// Whether this operator is an external input.
    // NOTE(review): the exact consequences of this flag are not visible here.
    pub is_external_input: bool,
    /// Whether this operator exposes a singleton output
    /// (cf. `WriteContextArgs::singleton_output_ident`).
    pub has_singleton_output: bool,
    /// Special flow role of this operator, if any (see [`FloType`]).
    pub flo_type: Option<FloType>,

    /// If set, returns the input port names this operator accepts.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// If set, returns the output port names this operator accepts.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// For an input port, returns the [`DelayType`] required on it, or `None` for no delay.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// Generates this operator's code (see [`WriteFn`]).
    pub write_fn: WriteFn,
}
88
/// Code-generation callback of an operator: given the [`WriteContextArgs`] of one operator
/// instance, returns the generated [`OperatorWriteOutput`], pushing problems into
/// [`Diagnostics`].
// NOTE(review): `Err(())` presumably means generation failed and the reason has already been
// pushed into `diagnostics` — confirm.
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
/// Code emitted by an operator's [`WriteFn`]; every field defaults to an empty token stream.
///
/// `#[non_exhaustive]` means code outside this crate cannot use a struct literal; it must start
/// from `Default::default()` and assign fields.
#[derive(Default)]
#[non_exhaustive]
pub struct OperatorWriteOutput {
    // NOTE(review): the exact placement of the prologue/"after" sections in the generated
    // program is decided by the caller and is not visible here — confirm before relying on it.
    /// Prologue code for this operator.
    pub write_prologue: TokenStream,
    /// Code emitted after the prologue section.
    pub write_prologue_after: TokenStream,
    /// Main body code; binds the operator's `ident` (e.g. `let #ident = ...;`).
    pub write_iterator: TokenStream,
    /// Code emitted after the main body.
    pub write_iterator_after: TokenStream,
}
134
/// Any count (`0..`). Note that `human_string()` renders this as "at least 0".
pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
/// Exactly zero (`0..=0`).
pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
/// Exactly one (`1..=1`).
pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
141
142pub fn identity_write_iterator_fn(
145 &WriteContextArgs {
146 root,
147 op_span,
148 ident,
149 inputs,
150 outputs,
151 is_pull,
152 op_inst:
153 OperatorInstance {
154 generics: OpInstGenerics { type_args, .. },
155 ..
156 },
157 ..
158 }: &WriteContextArgs,
159) -> TokenStream {
160 let generic_type = type_args
161 .first()
162 .map(quote::ToTokens::to_token_stream)
163 .unwrap_or(quote_spanned!(op_span=> _));
164
165 if is_pull {
166 let input = &inputs[0];
167 quote_spanned! {op_span=>
168 let #ident = {
169 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
170 where
171 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
172 {
173 pull
174 }
175 check_input::<_, #generic_type>(#input)
176 };
177 }
178 } else {
179 let output = &outputs[0];
180 quote_spanned! {op_span=>
181 let #ident = {
182 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
183 where
184 Psh: #root::dfir_pipes::push::Push<Item, ()>,
185 {
186 push
187 }
188 check_output::<_, #generic_type>(#output)
189 };
190 }
191 }
192}
193
194pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
196 let write_iterator = identity_write_iterator_fn(write_context_args);
197 Ok(OperatorWriteOutput {
198 write_iterator,
199 ..Default::default()
200 })
201};
202
203pub fn null_write_iterator_fn(
206 &WriteContextArgs {
207 root,
208 op_span,
209 ident,
210 inputs,
211 outputs,
212 is_pull,
213 op_inst:
214 OperatorInstance {
215 generics: OpInstGenerics { type_args, .. },
216 ..
217 },
218 ..
219 }: &WriteContextArgs,
220) -> TokenStream {
221 let default_type = parse_quote_spanned! {op_span=> _};
222 let iter_type = type_args.first().unwrap_or(&default_type);
223
224 if is_pull {
225 quote_spanned! {op_span=>
226 let #ident = #root::dfir_pipes::pull::poll_fn({
227 #(
228 let mut #inputs = ::std::boxed::Box::pin(#inputs);
229 )*
230 move |_cx| {
231 #(
235 let #inputs = #root::dfir_pipes::pull::Pull::pull(
236 ::std::pin::Pin::as_mut(&mut #inputs),
237 <_ as #root::dfir_pipes::Context>::from_task(_cx),
238 );
239 )*
240 #(
241 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
242 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
243 }
244 )*
245 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
246 }
247 });
248 }
249 } else {
250 quote_spanned! {op_span=>
251 #[allow(clippy::let_unit_value)]
252 let _ = (#(#outputs),*);
253 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
254 }
255 }
256}
257
258pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
261 let write_iterator = null_write_iterator_fn(write_context_args);
262 Ok(OperatorWriteOutput {
263 write_iterator,
264 ..Default::default()
265 })
266};
267
// Declares each operator's module (`pub(crate) mod $mod;`) and gathers the listed constants into
// the `OPERATORS` slice, in the order given. Entry syntax is `module::CONSTANT,` — the matcher
// requires a trailing comma after every entry, including the last.
macro_rules! declare_ops {
    ( $( $mod:ident :: $op:ident, )* ) => {
        $( pub(crate) mod $mod; )*
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $mod :: $op, )*
        ];
    };
}
// Registry of all built-in operators. Each `module::CONSTANT` entry declares `mod module;` and
// appends `module::CONSTANT` to `OPERATORS`. Name lookup goes through `operator_lookup()`, which
// builds a `HashMap` from this list — if two constants ever shared a `name`, the later entry
// would silently win.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flat_map_stream_blocking::FLAT_MAP_STREAM_BLOCKING,
    flatten::FLATTEN,
    flatten_stream_blocking::FLATTEN_STREAM_BLOCKING,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    join_multiset_half::JOIN_MULTISET_HALF,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    scan_async_blocking::SCAN_ASYNC_BLOCKING,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
362
363pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
365 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
366 OnceLock::new();
367 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
368}
369pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
371 if let GraphNode::Operator(operator) = node {
372 find_op_op_constraints(operator)
373 } else {
374 None
375 }
376}
377pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
379 let name = &*operator.name_string();
380 operator_lookup().get(name).copied()
381}
382
/// Everything an operator's [`WriteFn`] needs to generate code for one operator instance.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Path to the runtime crate root, interpolated as `#root::...` in generated code.
    pub root: &'a TokenStream,
    // NOTE(review): the meaning of `context`, `df_ident`, `work_fn` and `work_fn_async` in the
    // generated program is not visible in this file — confirm before documenting further.
    /// Identifier of the context value in generated code.
    pub context: &'a Ident,
    /// Identifier of the dataflow graph value in generated code.
    pub df_ident: &'a Ident,
    /// Subgraph containing this operator (part of [`Self::make_ident`] names).
    pub subgraph_id: GraphSubgraphId,
    /// This operator's node (part of [`Self::make_ident`] names).
    pub node_id: GraphNodeId,
    /// Enclosing loop, if the operator is inside one.
    pub loop_id: Option<GraphLoopId>,
    /// Span of the operator in user code; generated tokens are spanned with it.
    pub op_span: Span,
    /// Optional tag for this operator.
    pub op_tag: Option<String>,
    /// Identifier of the work function.
    pub work_fn: &'a Ident,
    /// Identifier of the async work function.
    pub work_fn_async: &'a Ident,

    /// Identifier the generated code must bind (`let #ident = ...;`).
    pub ident: &'a Ident,
    /// `true` if generating pull-based code (reading from `inputs`), `false` for push-based
    /// code (writing to `outputs`).
    pub is_pull: bool,
    /// Identifiers of the input handles.
    pub inputs: &'a [Ident],
    /// Identifiers of the output handles.
    pub outputs: &'a [Ident],
    /// Identifier for this operator's singleton output.
    pub singleton_output_ident: &'a Ident,

    /// Operator name, used in diagnostics.
    pub op_name: &'static str,
    /// Parsed operator instance, including its generics (persistence and type arguments).
    pub op_inst: &'a OperatorInstance,
    /// Operator arguments as written.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
    /// Operator arguments in a second form.
    // NOTE(review): presumably `arguments` with singleton references rewritten to handles —
    // confirm against the caller that builds this struct.
    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
434impl WriteContextArgs<'_> {
435 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
441 Ident::new(
442 &format!(
443 "sg_{:?}_node_{:?}_{}",
444 self.subgraph_id.data(),
445 self.node_id.data(),
446 suffix.as_ref(),
447 ),
448 self.op_span,
449 )
450 }
451
452 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
455 let root = self.root;
456 let variant =
457 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
458 Some(quote_spanned! {self.op_span=>
459 #root::scheduled::StateLifespan::#variant
460 })
461 }
462
463 pub fn persistence_args_disallow_mutable<const N: usize>(
465 &self,
466 diagnostics: &mut Diagnostics,
467 ) -> [Persistence; N] {
468 let len = self.op_inst.generics.persistence_args.len();
469 if 0 != len && 1 != len && N != len {
470 diagnostics.push(Diagnostic::spanned(
471 self.op_span,
472 Level::Error,
473 format!(
474 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
475 self.op_name, N
476 ),
477 ));
478 }
479
480 let default_persistence = if self.loop_id.is_some() {
481 Persistence::None
482 } else {
483 Persistence::Tick
484 };
485 let mut out = [default_persistence; N];
486 self.op_inst
487 .generics
488 .persistence_args
489 .iter()
490 .copied()
491 .cycle() .take(N)
493 .enumerate()
494 .filter(|&(_i, p)| {
495 if p == Persistence::Mutable {
496 diagnostics.push(Diagnostic::spanned(
497 self.op_span,
498 Level::Error,
499 format!(
500 "An implementation of `'{}` does not exist",
501 p.to_str_lowercase()
502 ),
503 ));
504 false
505 } else {
506 true
507 }
508 })
509 .for_each(|(i, p)| {
510 out[i] = p;
511 });
512 out
513 }
514}
515
/// Object-safe range abstraction, so count constraints can be stored as
/// `&'static dyn RangeTrait<usize>`. Implemented for every [`RangeBounds`] type by the blanket
/// impl below.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range (see [`RangeBounds::start_bound`]).
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range (see [`RangeBounds::end_bound`]).
    fn end_bound(&self) -> Bound<&T>;
    /// Whether `item` lies within the range (see [`RangeBounds::contains`]).
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// English description of the range for diagnostics, e.g. `"exactly 1"`,
    /// `"at least 2"`, or `"at least 1 and less than 3"`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let upper = <Self as RangeTrait<T>>::end_bound(self);
        match <Self as RangeTrait<T>>::start_bound(self) {
            Bound::Unbounded => match upper {
                Bound::Unbounded => "any number of".to_owned(),
                Bound::Included(x) => format!("at most {}", x),
                Bound::Excluded(x) => format!("less than {}", x),
            },
            Bound::Included(n) => match upper {
                // A closed range whose ends coincide names a single exact count.
                Bound::Included(x) if n == x => format!("exactly {}", n),
                Bound::Included(x) => format!("at least {} and at most {}", n, x),
                Bound::Excluded(x) => format!("at least {} and less than {}", n, x),
                Bound::Unbounded => format!("at least {}", n),
            },
            Bound::Excluded(n) => match upper {
                Bound::Included(x) => format!("more than {} and at most {}", n, x),
                Bound::Excluded(x) => format!("more than {} and less than {}", n, x),
                Bound::Unbounded => format!("more than {}", n),
            },
        }
    }
}

/// Every thread-safe, debuggable [`RangeBounds`] type is a [`RangeTrait`].
///
/// The methods forward explicitly to `RangeBounds` via fully-qualified paths, which makes it
/// clear they do not recurse into `RangeTrait` itself.
impl<R, T> RangeTrait<T> for R
where
    R: RangeBounds<T> + Send + Sync + Debug,
{
    fn start_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::start_bound(self)
    }

    fn end_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::end_bound(self)
    }

    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>,
    {
        <R as RangeBounds<T>>::contains(self, item)
    }
}
580
/// Lifespan of an operator's state, written as a lifetime-style generic argument such as
/// `'tick` (the spellings are given by [`Persistence::to_str_lowercase`]).
// `Ord` is derived, so the declaration order below defines comparison order.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// Subgraph-scoped state (`StateLifespan::Subgraph`); the default used by
    /// `WriteContextArgs::persistence_args_disallow_mutable` inside a loop.
    None,
    /// Loop-scoped state (`StateLifespan::Loop`); only valid inside a loop.
    Loop,
    /// Tick-scoped state (`StateLifespan::Tick`); the default used by
    /// `WriteContextArgs::persistence_args_disallow_mutable` outside a loop.
    Tick,
    /// Has no `StateLifespan` variant (`as_state_lifespan_variant` returns `None`);
    /// presumably state that is never reset — confirm.
    Static,
    /// Has no `StateLifespan` variant; rejected with an error by
    /// `WriteContextArgs::persistence_args_disallow_mutable`.
    Mutable,
}
595impl Persistence {
596 pub fn as_state_lifespan_variant(
598 self,
599 subgraph_id: GraphSubgraphId,
600 loop_id: Option<GraphLoopId>,
601 span: Span,
602 ) -> Option<TokenStream> {
603 match self {
604 Persistence::None => {
605 let sg_ident = subgraph_id.as_ident(span);
606 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
607 }
608 Persistence::Loop => {
609 let loop_ident = loop_id
610 .expect("`Persistence::Loop` outside of a loop context.")
611 .as_ident(span);
612 Some(quote_spanned!(span=> Loop(#loop_ident)))
613 }
614 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
615 Persistence::Static => None,
616 Persistence::Mutable => None,
617 }
618 }
619
620 pub fn to_str_lowercase(self) -> &'static str {
622 match self {
623 Persistence::None => "none",
624 Persistence::Tick => "tick",
625 Persistence::Loop => "loop",
626 Persistence::Static => "static",
627 Persistence::Mutable => "mutable",
628 }
629 }
630}
631
632fn make_missing_runtime_msg(op_name: &str) -> Literal {
634 Literal::string(&format!(
635 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
636 op_name
637 ))
638}
639
/// Category an operator is listed under, via [`OperatorConstraints::categories`].
///
/// `name()` and `description()` split each variant's doc text (provided by the
/// `DocumentedVariants` derive) at its first `:`.
// NOTE(review): because of that split, the `///` doc comment on every variant is runtime data
// and must exist and have the form `Name: description`; `name()`/`description()` panic
// otherwise. Confirm every variant below carries such a doc comment.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
674impl OperatorCategory {
675 pub fn name(self) -> &'static str {
677 self.get_variant_docs().split_once(":").unwrap().0
678 }
679 pub fn description(self) -> &'static str {
681 self.get_variant_docs().split_once(":").unwrap().1
682 }
683}
684
/// Special flow role an operator plays, stored in [`OperatorConstraints::flo_type`]
/// (presumably `None` there for ordinary operators — confirm).
// `Ord` is derived, so the declaration order below defines comparison order.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    // NOTE(review): per-variant semantics are not visible in this file — confirm before documenting.
    Source,
    Windowing,
    Unwindowing,
    NextIteration,
}