two-way file system sync
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

client.rs 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. use std::collections::BTreeMap;
  2. use std::time::Duration;
  3. use std::{fmt, mem};
  4. use core::pin::Pin;
  5. use futures::future::Future;
  6. use futures::sink::{Sink, SinkExt};
  7. use futures::stream::{Stream, StreamExt};
  8. use futures::task::{Context, Poll};
  9. use log::error;
  10. use pin_project::pin_project;
  11. use rmps::Serializer;
  12. use serde::de::DeserializeOwned;
  13. use serde::Serialize;
  14. use tokio::sync::{mpsc, oneshot};
  15. use tokio::time::{delay_for, Delay};
  16. use super::idalloc::IdAllocator;
  17. use super::{Packet, SERVER_EVENT_ID};
  18. pub fn create<StreamType, ErrorType, PacketType>(
  19. stream: StreamType,
  20. ) -> (
  21. RPCInterface<PacketType, ErrorType>,
  22. ServerEventStream<PacketType>,
  23. )
  24. where
  25. StreamType: Stream<Item = Result<Vec<u8>, ErrorType>>
  26. + Sink<Vec<u8>, Error = ErrorType>
  27. + Unpin
  28. + Send
  29. + 'static,
  30. PacketType: DeserializeOwned + Serialize + Send + 'static,
  31. ErrorType: fmt::Debug + Send + 'static,
  32. {
  33. // TODO: Does this channel have to be unbounded? We can just drop the sender, and the receiver
  34. // will notice.
  35. let (cmd_send, cmd_receive) = mpsc::unbounded_channel();
  36. let (event_send, event_receive) = mpsc::channel(1);
  37. tokio::spawn(async move {
  38. let task = ClientConnectionTask::new(stream, cmd_receive, event_send);
  39. task.run().await;
  40. });
  41. (RPCInterface { cmd_send }, Box::pin(event_receive))
  42. }
  43. enum ClientTaskCommand<PacketType, ErrorType> {
  44. Close,
  45. SendPacket(CallData<PacketType, ErrorType>),
  46. }
/// State owned by the background task that drives a single client connection.
struct ClientConnectionTask<StreamType, ErrorType, PacketType> {
    /// Allocator for the IDs attached to outgoing calls.
    free_ids: IdAllocator,
    /// Reply channels of calls that were sent and still await a response, keyed by call ID.
    calls_in_progress: BTreeMap<u32, oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>>,
    /// The connection itself: a stream of incoming frames and a sink for outgoing frames.
    stream: StreamType,
    /// Commands (new calls, close requests) coming from the `RPCInterface`.
    cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
    /// Forwards packets tagged with `SERVER_EVENT_ID` to the user's event stream.
    event_send: mpsc::Sender<PacketType>,
}
  54. impl<StreamType, ErrorType, PacketType> ClientConnectionTask<StreamType, ErrorType, PacketType>
  55. where
  56. PacketType: DeserializeOwned + Serialize,
  57. StreamType:
  58. Stream<Item = Result<Vec<u8>, ErrorType>> + Sink<Vec<u8>, Error = ErrorType> + Unpin,
  59. ErrorType: fmt::Debug,
  60. {
  61. fn new(
  62. stream: StreamType,
  63. cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
  64. event_send: mpsc::Sender<PacketType>,
  65. ) -> Self {
  66. Self {
  67. free_ids: IdAllocator::new(),
  68. calls_in_progress: BTreeMap::new(),
  69. stream: stream,
  70. cmd_receive: cmd_receive,
  71. event_send,
  72. }
  73. }
  74. async fn run(mut self) {
  75. let mut connection_closed = false;
  76. loop {
  77. tokio::select! {
  78. packet = &mut self.stream.next() => {
  79. match packet {
  80. Some(Ok(data)) => {
  81. self.received_packet(&data).await;
  82. }
  83. Some(Err(e)) => {
  84. error!("Network error: {:?}", e);
  85. }
  86. None => {
  87. // Stream closed. We still need to return errors for any
  88. // packet sent by the client, so do not exit the loop
  89. // immediately but rather return an error for any
  90. // incoming packet.
  91. connection_closed = true;
  92. break;
  93. }
  94. }
  95. }
  96. cmd = &mut self.cmd_receive.next() => {
  97. match cmd {
  98. Some(ClientTaskCommand::Close) => {
  99. // The RPC interface for this connection was dropped, so
  100. // we just exit the loop and close the connection.
  101. break;
  102. }
  103. Some(ClientTaskCommand::SendPacket(call)) => {
  104. self.send_packet(call).await;
  105. }
  106. None => {
  107. // This should never happen.
  108. break;
  109. }
  110. }
  111. }
  112. }
  113. }
  114. drop(self.event_send);
  115. // Return errors for all pending packets as there will be no responses.
  116. let mut calls_in_progress = BTreeMap::new();
  117. mem::swap(&mut calls_in_progress, &mut self.calls_in_progress);
  118. for (_, result) in calls_in_progress {
  119. result.send(Err(RPCError::Closed)).ok();
  120. }
  121. if connection_closed {
  122. // The stream was closed, but the command channel was not. We still need
  123. // to return errors for any packet sent by the client, so do not exit
  124. // the loop immediately but rather return an error for any incoming
  125. // packet.
  126. client_connection_closed(self.cmd_receive).await;
  127. } else {
  128. // The command channel was closed, but the stream was not. Gracefully
  129. // close the stream.
  130. // TODO
  131. }
  132. }
  133. /// Sends a packet.
  134. ///
  135. /// # Errors
  136. ///
  137. /// If the connection is closed, the function does *not* return `RPCError::Closed`, but rather
  138. /// returns the corresponding stream error.
  139. async fn send_packet(&mut self, call: CallData<PacketType, ErrorType>) {
  140. // Add an ID for the outgoing packet.
  141. let id = self.free_ids.alloc();
  142. let packet = Packet {
  143. id: id,
  144. payload: call.data,
  145. };
  146. let mut serialized = Vec::new();
  147. packet
  148. .serialize(
  149. &mut Serializer::new(&mut serialized)
  150. .with_struct_map()
  151. .with_string_variants(),
  152. )
  153. .unwrap();
  154. // Send the serialized packet.
  155. // TODO: Send without flushing?
  156. match self.stream.send(serialized).await {
  157. Ok(()) => {
  158. self.calls_in_progress.insert(id, call.result);
  159. }
  160. Err(e) => {
  161. // We potentially leak the ID here, but that is
  162. // okay as the state of the call is not well
  163. // defined anymore.
  164. call.result.send(Err(RPCError::Stream(e))).ok();
  165. }
  166. }
  167. }
  168. async fn received_packet(&mut self, data: &[u8]) {
  169. match rmps::decode::from_slice(data) as Result<Packet<PacketType>, _> {
  170. Ok(deserialized) => {
  171. if deserialized.id == SERVER_EVENT_ID {
  172. // We received an asynchronous event from the server.
  173. self.event_send.send(deserialized.payload).await.ok();
  174. } else {
  175. // We received a response to a call - find the call and forward
  176. // the reply.
  177. if let Some(call) = self.calls_in_progress.remove(&deserialized.id) {
  178. self.free_ids.free(deserialized.id);
  179. call.send(Ok(deserialized.payload)).ok();
  180. } else {
  181. error!(
  182. "Received a reply for an unknown call ({:?})!",
  183. deserialized.id
  184. );
  185. }
  186. }
  187. }
  188. Err(e) => {
  189. // We received an invalid packet. We cannot determine for which
  190. // call the packet was, so we just ignore it. The corresponding
  191. // call will freeze until the connection is dropped.
  192. error!("Invalid packet received: {:?}", e);
  193. }
  194. };
  195. }
  196. }
  197. async fn client_connection_closed<PacketType, ErrorType>(
  198. mut cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
  199. ) {
  200. loop {
  201. match cmd_receive.recv().await {
  202. Some(ClientTaskCommand::Close) => {
  203. // The connection was already closed.
  204. break;
  205. }
  206. Some(ClientTaskCommand::SendPacket(call)) => {
  207. // The connection was closed, so return an error for any
  208. // packet sent.
  209. call.result.send(Err(RPCError::Closed)).ok();
  210. }
  211. None => {
  212. // This should never happen.
  213. break;
  214. }
  215. }
  216. }
  217. }
/// Handle for issuing RPC calls over a connection set up by `create`.
///
/// Dropping the handle asks the connection task to shut down.
pub struct RPCInterface<PacketType, ErrorType> {
    /// Command channel to the connection task.
    cmd_send: mpsc::UnboundedSender<ClientTaskCommand<PacketType, ErrorType>>,
}
  221. impl<PacketType, ErrorType> RPCInterface<PacketType, ErrorType> {
  222. pub fn call(&mut self, call: PacketType) -> RPC<PacketType, ErrorType> {
  223. let (sender, receiver) = oneshot::channel();
  224. let call_data = CallData {
  225. data: call,
  226. result: sender,
  227. };
  228. if let Err(_) = self.cmd_send.send(ClientTaskCommand::SendPacket(call_data)) {
  229. panic!("could not send packet to connection task");
  230. }
  231. RPC {
  232. result: receiver,
  233. timeout: None,
  234. }
  235. }
  236. }
  237. impl<PacketType, ErrorType> Drop for RPCInterface<PacketType, ErrorType> {
  238. fn drop(&mut self) {
  239. // Stop the connection loop and close the connection.
  240. // send() only fails if the other end has already been closed - we must
  241. // not unwrap() or expect() because then nested panics might occur.
  242. self.cmd_send.send(ClientTaskCommand::Close).ok();
  243. }
  244. }
/// A call request passed from `RPCInterface::call` to the connection task.
struct CallData<PacketType, ErrorType> {
    /// Payload to send to the server.
    data: PacketType,
    /// Completed with the server's reply or with an error.
    result: oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>,
}
/// Future returned by `RPCInterface::call` that resolves to the server's reply.
#[pin_project]
pub struct RPC<PacketType, ErrorType> {
    /// Receives the reply (or an error) from the connection task.
    #[pin]
    result: oneshot::Receiver<Result<PacketType, RPCError<ErrorType>>>,
    /// Optional deadline set by `with_timeout`; `None` means wait indefinitely.
    #[pin]
    timeout: Option<Delay>,
}
  256. impl<PacketType, ErrorType> RPC<PacketType, ErrorType> {
  257. pub fn with_timeout(mut self, timeout: Duration) -> Self {
  258. self.timeout = Some(delay_for(timeout));
  259. self
  260. }
  261. }
  262. impl<PacketType, ErrorType> Future for RPC<PacketType, ErrorType> {
  263. type Output = Result<PacketType, RPCError<ErrorType>>;
  264. fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
  265. let this = self.project();
  266. // Try receiving from the receiver.
  267. match this.result.poll(cx) {
  268. Poll::Ready(Ok(result)) => return Poll::Ready(result),
  269. Poll::Ready(Err(_)) => {
  270. // The channel was closed.
  271. return Poll::Ready(Err(RPCError::Closed));
  272. }
  273. Poll::Pending => (),
  274. }
  275. // Nothing received, check the timeout instead.
  276. if this.timeout.is_some() {
  277. match this.timeout.as_pin_mut().unwrap().poll(cx) {
  278. Poll::Ready(()) => return Poll::Ready(Err(RPCError::Timeout)),
  279. Poll::Pending => (),
  280. }
  281. }
  282. Poll::Pending
  283. }
  284. }
  285. /// Stream of asynchronous packets received from the server.
  286. type ServerEventStream<PacketType> = Pin<Box<mpsc::Receiver<PacketType>>>;
  287. #[derive(Debug, PartialEq)]
  288. pub enum RPCError<StreamError> {
  289. Closed,
  290. Timeout,
  291. Stream(StreamError),
  292. }
  293. #[cfg(test)]
  294. mod tests {
  295. use super::*;
  296. use crate::network::test_utils::{panic_after, Pipe, PipeError};
  297. use crate::network::Packet;
  298. type TestPayload = u32;
  299. fn serialize_packet<T>(id: u32, payload: T) -> Vec<u8>
  300. where
  301. T: Serialize,
  302. {
  303. let packet = Packet { id, payload };
  304. let mut serialized = Vec::new();
  305. packet
  306. .serialize(
  307. &mut Serializer::new(&mut serialized)
  308. .with_struct_map()
  309. .with_string_variants(),
  310. )
  311. .unwrap();
  312. serialized
  313. }
  314. /// Test the basic RPC interface.
  315. #[tokio::test]
  316. async fn test_rpc_client() {
  317. panic_after(Duration::from_secs(1), async move {
  318. let (pipe1, mut pipe2) = Pipe::new();
  319. let (mut client, events) = create::<_, _, TestPayload>(pipe1);
  320. println!("init");
  321. // Client task.
  322. let tasks = (
  323. tokio::spawn(async move {
  324. assert_eq!(client.call(42).await, Ok(1337));
  325. let delayed = client.call(43);
  326. assert_eq!(client.call(44).await, Ok(1339));
  327. assert_eq!(delayed.await, Ok(1338));
  328. }),
  329. tokio::spawn(async move {
  330. // Invalid ID, should be ignored.
  331. pipe2.send(serialize_packet(42, 42)).await.unwrap();
  332. assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(0, 42))));
  333. pipe2.send(serialize_packet(0, 1337)).await.unwrap();
  334. // Network errors shall be ignored if the stream is not closed.
  335. pipe2.inject_error();
  336. // The ID has to be reused, and a second concurrent call has to have a
  337. // different ID.
  338. assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(0, 43))));
  339. assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(1, 44))));
  340. pipe2.send(serialize_packet(1, 1339)).await.unwrap();
  341. pipe2.send(serialize_packet(0, 1338)).await.unwrap();
  342. }),
  343. );
  344. tasks.0.await.unwrap();
  345. tasks.1.await.unwrap();
  346. })
  347. .await;
  348. }
  349. /// Test correct behaviour if the connection is closed during a call.
  350. #[tokio::test]
  351. async fn test_connection_closed() {
  352. panic_after(Duration::from_secs(1), async move {
  353. let (pipe1, pipe2) = Pipe::new();
  354. let (mut client, _events) = create::<_, _, TestPayload>(pipe1);
  355. // Client task.
  356. let tasks = (
  357. tokio::spawn(async move {
  358. let status = client.call(42).await;
  359. assert!(
  360. status == Err(RPCError::Closed)
  361. || status == Err(RPCError::Stream(PipeError::Closed))
  362. );
  363. }),
  364. tokio::spawn(async move {
  365. drop(pipe2);
  366. }),
  367. );
  368. tasks.0.await.unwrap();
  369. tasks.1.await.unwrap();
  370. })
  371. .await;
  372. }
  373. /// Test server events.
  374. #[tokio::test]
  375. async fn test_events() {
  376. panic_after(Duration::from_secs(1), async move {
  377. let (pipe1, mut pipe2) = Pipe::new();
  378. let (client, mut events) = create::<_, _, TestPayload>(pipe1);
  379. // Client task.
  380. let tasks = (
  381. tokio::spawn(async move {
  382. assert_eq!(events.next().await, Some(42));
  383. assert_eq!(events.next().await, None);
  384. }),
  385. tokio::spawn(async move {
  386. pipe2
  387. .send(serialize_packet(SERVER_EVENT_ID, 42))
  388. .await
  389. .unwrap();
  390. drop(pipe2);
  391. }),
  392. );
  393. tasks.0.await.unwrap();
  394. tasks.1.await.unwrap();
  395. })
  396. .await;
  397. }
  398. // TODO:
  399. // - Test for unparseable response.
  400. // - Test whether the client correctly closes the connection and whether futures return an error
  401. // in this case.
  402. }