Skip to main content

kinetic_core/
shutdown.rs

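//! Graceful shutdown coordination for Kinetic components.
//!
//! Uses a token-based approach: sources and long-running tasks each hold a shutdown
//! token (a `ShutdownSignal`). When the coordinator initiates shutdown, every token is
//! notified so the component can drain in-flight work; `ShutdownCoordinator::wait`
//! resolves once every token has been dropped. This module does not itself enforce a
//! grace-period deadline or force-terminate components that fail to drain.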
6
use tokio::sync::{broadcast, mpsc, watch};
8
/// Why a shutdown was requested.
///
/// Sent to components by [`ShutdownCoordinator::broadcast`] and observed through
/// [`ShutdownSignal::recv`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    /// Shut down gracefully, draining in-flight work first (e.g. on SIGINT/SIGTERM).
    Graceful,
    /// Shut down right away, skipping the drain phase (e.g. on SIGQUIT).
    Immediate,
    /// The configuration is being reloaded; a component may be stopped so that it
    /// can be replaced.
    Reload,
}
19
20/// A handle held by a component that can be used to listen for shutdown requests.
21#[derive(Debug)]
22pub struct ShutdownSignal {
23    rx: broadcast::Receiver<ShutdownReason>,
24    _ack: mpsc::Sender<()>,
25}
26
27impl Clone for ShutdownSignal {
28    fn clone(&self) -> Self {
29        Self {
30            rx: self.rx.resubscribe(),
31            _ack: self._ack.clone(),
32        }
33    }
34}
35
36impl ShutdownSignal {
37    /// Wait for a shutdown signal.
38    pub async fn recv(&mut self) -> Option<ShutdownReason> {
39        self.rx.recv().await.ok()
40    }
41
42    /// Returns true if a shutdown signal has been received.
43    pub fn is_shutdown(&self) -> bool {
44        !self.rx.is_empty()
45    }
46}
47
48/// Coordinates the shutdown of multiple components.
49#[derive(Debug)]
50pub struct ShutdownCoordinator {
51    tx: broadcast::Sender<ShutdownReason>,
52    // Channel used to wait for all components to drop their ShutdownSignal
53    ack_tx: mpsc::Sender<()>,
54    ack_rx: mpsc::Receiver<()>,
55}
56
57impl Default for ShutdownCoordinator {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl ShutdownCoordinator {
64    /// Create a new shutdown coordinator.
65    pub fn new() -> Self {
66        let (tx, _) = broadcast::channel(1);
67        let (ack_tx, ack_rx) = mpsc::channel(1);
68        Self { tx, ack_tx, ack_rx }
69    }
70
71    /// Register a new component to be managed by this coordinator.
72    /// Returns a `ShutdownSignal` that the component should hold and listen to.
73    pub fn register(&self) -> ShutdownSignal {
74        ShutdownSignal {
75            rx: self.tx.subscribe(),
76            _ack: self.ack_tx.clone(),
77        }
78    }
79
80    /// Broadcast a shutdown signal to all registered components.
81    pub fn broadcast(&self, reason: ShutdownReason) {
82        let _ = self.tx.send(reason);
83    }
84
85    /// Wait for all registered components to finish shutting down.
86    ///
87    /// This consumes the coordinator, dropping the `ack_tx` sender so that
88    /// the `ack_rx` receiver will return `None` once all `ShutdownSignal`s are dropped.
89    pub async fn wait(mut self) {
90        // Drop our own sender so the receiver can complete when children drop theirs
91        drop(self.ack_tx);
92
93        // Wait for all acks to drop
94        while self.ack_rx.recv().await.is_some() {}
95    }
96}
97
#[cfg(test)]
#[allow(clippy::unwrap_used)] // a panicking join is exactly the failure we want to surface
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let coordinator = ShutdownCoordinator::new();
        let (mut first, mut second) = (coordinator.register(), coordinator.register());

        // The first component pretends to do some cleanup after seeing the signal.
        let slow = tokio::spawn(async move {
            assert_eq!(first.recv().await, Some(ShutdownReason::Graceful));
            tokio::time::sleep(Duration::from_millis(10)).await;
        });
        let fast = tokio::spawn(async move {
            assert_eq!(second.recv().await, Some(ShutdownReason::Graceful));
        });

        coordinator.broadcast(ShutdownReason::Graceful);

        // Resolves only after both spawned tasks have dropped their signals.
        coordinator.wait().await;

        for handle in [slow, fast] {
            handle.await.unwrap();
        }
    }
}