1use tokio::sync::{broadcast, mpsc};
8
/// Reason attached to a shutdown broadcast.
///
/// Only the variant names are defined here; what each reason triggers is up
/// to the code that reacts to a [`ShutdownSignal`]. NOTE(review): the variant
/// descriptions below are inferred from the names — confirm with callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    /// Presumably: let in-flight work finish before stopping.
    Graceful,
    /// Presumably: stop as soon as possible.
    Immediate,
    /// Presumably: stop so the component can be reloaded/restarted.
    Reload,
}
19
20#[derive(Debug)]
22pub struct ShutdownSignal {
23 rx: broadcast::Receiver<ShutdownReason>,
24 _ack: mpsc::Sender<()>,
25}
26
27impl Clone for ShutdownSignal {
28 fn clone(&self) -> Self {
29 Self {
30 rx: self.rx.resubscribe(),
31 _ack: self._ack.clone(),
32 }
33 }
34}
35
36impl ShutdownSignal {
37 pub async fn recv(&mut self) -> Option<ShutdownReason> {
39 self.rx.recv().await.ok()
40 }
41
42 pub fn is_shutdown(&self) -> bool {
44 !self.rx.is_empty()
45 }
46}
47
48#[derive(Debug)]
50pub struct ShutdownCoordinator {
51 tx: broadcast::Sender<ShutdownReason>,
52 ack_tx: mpsc::Sender<()>,
54 ack_rx: mpsc::Receiver<()>,
55}
56
57impl Default for ShutdownCoordinator {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl ShutdownCoordinator {
64 pub fn new() -> Self {
66 let (tx, _) = broadcast::channel(1);
67 let (ack_tx, ack_rx) = mpsc::channel(1);
68 Self { tx, ack_tx, ack_rx }
69 }
70
71 pub fn register(&self) -> ShutdownSignal {
74 ShutdownSignal {
75 rx: self.tx.subscribe(),
76 _ack: self.ack_tx.clone(),
77 }
78 }
79
80 pub fn broadcast(&self, reason: ShutdownReason) {
82 let _ = self.tx.send(reason);
83 }
84
85 pub async fn wait(mut self) {
90 drop(self.ack_tx);
92
93 while self.ack_rx.recv().await.is_some() {}
95 }
96}
97
#[cfg(test)]
// Panicking on a failed join/unwrap is the intended failure mode in tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound for any `wait()` in these tests, so a regression fails the
    /// test instead of hanging the whole suite.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let coord = ShutdownCoordinator::new();
        let mut sig1 = coord.register();
        let mut sig2 = coord.register();

        let t1 = tokio::spawn(async move {
            assert_eq!(sig1.recv().await, Some(ShutdownReason::Graceful));
            // Simulate cleanup work before the signal is dropped.
            tokio::time::sleep(Duration::from_millis(10)).await;
        });

        let t2 = tokio::spawn(async move {
            assert_eq!(sig2.recv().await, Some(ShutdownReason::Graceful));
        });

        coord.broadcast(ShutdownReason::Graceful);

        assert!(
            tokio::time::timeout(WAIT_LIMIT, coord.wait()).await.is_ok(),
            "wait() did not return after all signals were dropped"
        );

        t1.await.unwrap();
        t2.await.unwrap();
    }

    #[tokio::test]
    async fn test_wait_without_signals_returns() {
        let coord = ShutdownCoordinator::new();
        assert!(tokio::time::timeout(WAIT_LIMIT, coord.wait()).await.is_ok());
    }

    #[tokio::test]
    async fn test_wait_returns_after_signals_and_clones_dropped() {
        let coord = ShutdownCoordinator::new();
        let sig = coord.register();
        let cloned = sig.clone();
        drop(sig);
        drop(cloned);
        assert!(tokio::time::timeout(WAIT_LIMIT, coord.wait()).await.is_ok());
    }

    #[tokio::test]
    async fn test_is_shutdown_while_reason_pending() {
        let coord = ShutdownCoordinator::new();
        let mut sig = coord.register();
        assert!(!sig.is_shutdown());

        coord.broadcast(ShutdownReason::Reload);
        assert!(sig.is_shutdown());
        assert_eq!(sig.recv().await, Some(ShutdownReason::Reload));
    }

    #[tokio::test]
    async fn test_clone_receives_later_broadcast() {
        let coord = ShutdownCoordinator::new();
        let original = coord.register();
        let mut cloned = original.clone();

        coord.broadcast(ShutdownReason::Immediate);
        assert_eq!(cloned.recv().await, Some(ShutdownReason::Immediate));
    }

    #[tokio::test]
    async fn test_recv_none_after_coordinator_dropped() {
        let coord = ShutdownCoordinator::new();
        let mut sig = coord.register();
        drop(coord);
        assert_eq!(sig.recv().await, None);
    }
}