use std::collections::BTreeMap;
use std::time::Duration;
use std::{fmt, mem};

use core::pin::Pin;
use futures::future::Future;
use futures::sink::{Sink, SinkExt};
use futures::stream::{Stream, StreamExt};
use futures::task::{Context, Poll};
use log::error;
use pin_project::pin_project;
use rmps::Serializer;
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::sync::{mpsc, oneshot};
use tokio::time::{delay_for, Delay};

use super::idalloc::IdAllocator;
use super::{Packet, SERVER_EVENT_ID};

/// Starts the client side of an RPC connection over `stream`.
///
/// Spawns a background task (via `tokio::spawn`, so this must be called from within a Tokio
/// runtime) that owns `stream`. The task serializes outgoing calls, matches incoming replies to
/// calls by packet ID, and forwards packets carrying `SERVER_EVENT_ID` to the event stream.
///
/// Returns the handle used to issue calls, and the stream of asynchronous server events. The
/// event stream ends once the connection task stops.
pub fn create(
    stream: StreamType,
) -> (
    RPCInterface,
    ServerEventStream,
)
where
    StreamType: Stream, ErrorType>> + Sink, Error = ErrorType> + Unpin + Send + 'static,
    PacketType: DeserializeOwned + Serialize + Send + 'static,
    ErrorType: fmt::Debug + Send + 'static,
{
    // TODO: Does this channel have to be unbounded? We can just drop the sender, and the receiver
    // will notice.
    let (cmd_send, cmd_receive) = mpsc::unbounded_channel();
    // NOTE(review): capacity 1. `received_packet` awaits `send` on this channel from inside the
    // connection task. If the caller holds the event stream but never polls it, a second server
    // event blocks the task, and every outstanding call stalls with it.
    let (event_send, event_receive) = mpsc::channel(1);
    tokio::spawn(async move {
        let task = ClientConnectionTask::new(stream, cmd_receive, event_send);
        task.run().await;
    });
    (RPCInterface { cmd_send }, Box::pin(event_receive))
}

/// Commands sent from an `RPCInterface` to its connection task.
enum ClientTaskCommand {
    /// Sent by `RPCInterface::drop`: stop the task.
    Close,
    /// Send a request and deliver the reply (or error) through the contained oneshot sender.
    SendPacket(CallData),
}

/// State owned by the background task that drives one connection.
struct ClientConnectionTask {
    // Allocator for request IDs. IDs are freed when the matching reply arrives.
    free_ids: IdAllocator,
    // Request ID -> oneshot sender that completes the caller's `RPC` future.
    calls_in_progress: BTreeMap>>>,
    // The transport: a stream of incoming frames and a sink for outgoing frames.
    stream: StreamType,
    // Commands from the `RPCInterface`.
    cmd_receive: mpsc::UnboundedReceiver>,
    // Forwards server events to the stream returned by `create`.
    event_send: mpsc::Sender,
}

impl ClientConnectionTask
where
    PacketType: DeserializeOwned + Serialize,
    StreamType: Stream, ErrorType>> + Sink, Error = ErrorType> + Unpin,
    ErrorType: fmt::Debug,
{
    /// Creates the task state with no calls in flight.
    fn new(
        stream: StreamType,
        cmd_receive: mpsc::UnboundedReceiver>,
        event_send: mpsc::Sender,
    ) -> Self {
        Self {
            free_ids: IdAllocator::new(),
            calls_in_progress: BTreeMap::new(),
            stream: stream,
            cmd_receive: cmd_receive,
            event_send,
        }
    }

    /// Main loop: multiplexes incoming packets and commands from the interface.
    ///
    /// The loop ends when the peer closes the stream, or when the interface sends `Close`.
    /// After that, the event stream is ended and every pending call fails with
    /// `RPCError::Closed`.
    async fn run(mut self) {
        let mut connection_closed = false;
        loop {
            tokio::select! {
                packet = &mut self.stream.next() => {
                    match packet {
                        Some(Ok(data)) => {
                            self.received_packet(&data).await;
                        }
                        Some(Err(e)) => {
                            // Transport errors are only logged. The loop keeps running and
                            // pending calls stay outstanding.
                            error!("Network error: {:?}", e);
                        }
                        None => {
                            // The peer closed the stream. Leave the loop: pending calls are
                            // failed below, and `client_connection_closed` then answers every
                            // further command with `RPCError::Closed` until the interface is
                            // dropped.
                            connection_closed = true;
                            break;
                        }
                    }
                }
                cmd = &mut self.cmd_receive.next() => {
                    match cmd {
                        Some(ClientTaskCommand::Close) => {
                            // The RPC interface for this connection was dropped, so
                            // we just exit the loop and close the connection.
                            break;
                        }
                        Some(ClientTaskCommand::SendPacket(call)) => {
                            self.send_packet(call).await;
                        }
                        None => {
                            // Not expected: `RPCInterface::drop` sends `Close` before the
                            // sender goes away. Exit defensively anyway.
                            break;
                        }
                    }
                }
            }
        }
        // Dropping the sender ends the event stream (its consumer sees `None`).
        drop(self.event_send);
        // Return errors for all pending packets as there will be no responses.
        let mut calls_in_progress = BTreeMap::new();
        mem::swap(&mut calls_in_progress, &mut self.calls_in_progress);
        for (_, result) in calls_in_progress {
            result.send(Err(RPCError::Closed)).ok();
        }
        if connection_closed {
            // The stream was closed, but the command channel is still open. Keep answering
            // commands, failing every new call with `RPCError::Closed`, until `Close` arrives.
            client_connection_closed(self.cmd_receive).await;
        } else {
            // The command channel was closed, but the stream was not. Gracefully
            // close the stream. (Currently the stream is simply dropped when `run` returns.)
            // TODO
        }
    }

    /// Sends one request and registers it as in progress.
    ///
    /// Allocates a request ID, serializes the packet as MessagePack (struct-map, string
    /// variants) and writes it to the sink. Nothing is returned. On success the call's result
    /// sender is stored until a reply with the same ID arrives. If the sink fails, the call is
    /// completed immediately with `RPCError::Stream(e)` rather than `RPCError::Closed`.
    ///
    /// NOTE(review): the serialization `unwrap()` panics the connection task if `call.data`
    /// cannot be encoded. That would fail every pending call and make later
    /// `RPCInterface::call`s panic.
    async fn send_packet(&mut self, call: CallData) {
        // Allocate an ID so the reply can be matched to this call.
        let id = self.free_ids.alloc();
        let packet = Packet {
            id: id,
            payload: call.data,
        };
        let mut serialized = Vec::new();
        packet
            .serialize(
                &mut Serializer::new(&mut serialized)
                    .with_struct_map()
                    .with_string_variants(),
            )
            .unwrap();
        // Send the serialized packet.
        // TODO: Send without flushing?
        match self.stream.send(serialized).await {
            Ok(()) => {
                self.calls_in_progress.insert(id, call.result);
            }
            Err(e) => {
                // We potentially leak the ID here, but that is
                // okay as the state of the call is not well
                // defined anymore.
                call.result.send(Err(RPCError::Stream(e))).ok();
            }
        }
    }

    /// Decodes one incoming frame and dispatches it.
    ///
    /// Packets with `SERVER_EVENT_ID` go to the event stream. Any other ID completes the
    /// matching pending call and frees the ID. Unknown IDs and undecodable frames are logged
    /// and dropped.
    async fn received_packet(&mut self, data: &[u8]) {
        match rmps::decode::from_slice(data) as Result, _> {
            Ok(deserialized) => {
                if deserialized.id == SERVER_EVENT_ID {
                    // We received an asynchronous event from the server. This waits for room
                    // in the capacity-1 event channel (see `create`). Errors (the receiver was
                    // dropped) are ignored.
                    self.event_send.send(deserialized.payload).await.ok();
                } else {
                    // We received a response to a call - find the call and forward
                    // the reply.
                    // NOTE(review): an `RPC` future that timed out or was dropped still has its
                    // entry here. The entry (and its ID) is only released if a reply eventually
                    // arrives.
                    if let Some(call) = self.calls_in_progress.remove(&deserialized.id) {
                        self.free_ids.free(deserialized.id);
                        call.send(Ok(deserialized.payload)).ok();
                    } else {
                        error!(
                            "Received a reply for an unknown call ({:?})!",
                            deserialized.id
                        );
                    }
                }
            }
            Err(e) => {
                // We received an invalid packet. We cannot determine which call it belongs
                // to, so we ignore it. That call stays pending until the connection is
                // dropped, unless its `RPC` future has a timeout.
                error!("Invalid packet received: {:?}", e);
            }
        };
    }
}

/// Drains the command channel after the stream has closed.
///
/// Every `SendPacket` is answered with `RPCError::Closed`. Returns on `Close` or when the
/// channel ends.
async fn client_connection_closed(
    mut cmd_receive: mpsc::UnboundedReceiver>,
) {
    loop {
        match cmd_receive.recv().await {
            Some(ClientTaskCommand::Close) => {
                // The connection was already closed.
                break;
            }
            Some(ClientTaskCommand::SendPacket(call)) => {
                // The connection was closed, so return an error for any
                // packet sent.
                call.result.send(Err(RPCError::Closed)).ok();
            }
            None => {
                // Not expected: `RPCInterface::drop` sends `Close` first.
                break;
            }
        }
    }
}

/// Handle for issuing calls on a connection created by `create`.
///
/// Dropping it sends `Close` to the connection task, which then stops.
pub struct RPCInterface {
    // Sole sender of commands to the connection task.
    cmd_send: mpsc::UnboundedSender>,
}

impl RPCInterface {
    /// Issues a call and returns a future that resolves to the server's reply.
    ///
    /// The request is queued for the connection task immediately, before the returned future is
    /// polled, so calls are sent in the order `call` is invoked.
    ///
    /// # Panics
    ///
    /// Panics if the connection task is no longer receiving commands (e.g. it panicked or the
    /// runtime shut down).
    pub fn call(&mut self, call: PacketType) -> RPC {
        let (sender, receiver) = oneshot::channel();
        let call_data = CallData {
            data: call,
            result: sender,
        };
        if let Err(_) = self.cmd_send.send(ClientTaskCommand::SendPacket(call_data)) {
            panic!("could not send packet to connection task");
        }
        RPC {
            result: receiver,
            timeout: None,
        }
    }
}

impl Drop for RPCInterface {
    fn drop(&mut self) {
        // Stop the connection loop and close the connection.
        // send() only fails if the other end has already been closed - we must
        // not unwrap() or expect() because then nested panics might occur.
        self.cmd_send.send(ClientTaskCommand::Close).ok();
    }
}

/// A queued call: the request payload plus the channel that receives its outcome.
struct CallData {
    data: PacketType,
    result: oneshot::Sender>>,
}

/// Future returned by `RPCInterface::call`, resolving to the reply or an `RPCError`.
#[pin_project]
pub struct RPC {
    // Completed by the connection task with the reply or an error.
    #[pin]
    result: oneshot::Receiver>>,
    // Optional deadline set by `with_timeout`.
    #[pin]
    timeout: Option,
}

impl RPC {
    /// Makes the future fail with `RPCError::Timeout` if no reply arrives within `timeout`.
    ///
    /// The timer starts when this method is called, not when the future is first polled.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(delay_for(timeout));
        self
    }
}

impl Future for RPC {
    type Output = Result>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll {
        let this = self.project();
        // Try receiving from the receiver first, so a ready reply wins over an expired timeout.
        match this.result.poll(cx) {
            Poll::Ready(Ok(result)) => return Poll::Ready(result),
            Poll::Ready(Err(_)) => {
                // The sender was dropped without an answer: the connection task is gone.
                return Poll::Ready(Err(RPCError::Closed));
            }
            Poll::Pending => (),
        }
        // Nothing received, check the timeout instead.
        if this.timeout.is_some() {
            match this.timeout.as_pin_mut().unwrap().poll(cx) {
                Poll::Ready(()) => return Poll::Ready(Err(RPCError::Timeout)),
                Poll::Pending => (),
            }
        }
        Poll::Pending
    }
}

/// Stream of asynchronous events (packets carrying `SERVER_EVENT_ID`) received from the server.
/// The stream ends (yields `None`) once the connection task stops.
type ServerEventStream = Pin>>;

/// Ways an `RPC` call can fail.
#[derive(Debug, PartialEq)]
pub enum RPCError {
    /// The connection was closed, or the connection task went away, before a reply arrived.
    Closed,
    /// The deadline set with `RPC::with_timeout` elapsed before a reply arrived.
    Timeout,
    /// The transport reported this error while the request was being sent.
    Stream(StreamError),
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::test_utils::{panic_after, Pipe, PipeError};
    use crate::network::Packet;

    type TestPayload = u32;

    /// Encodes a packet with the same MessagePack settings as `send_packet`, so tests can
    /// compare raw frames byte-for-byte.
    fn serialize_packet(id: u32, payload: T) -> Vec
    where
        T: Serialize,
    {
        let packet = Packet { id, payload };
        let mut serialized = Vec::new();
        packet
            .serialize(
                &mut Serializer::new(&mut serialized)
                    .with_struct_map()
                    .with_string_variants(),
            )
            .unwrap();
        serialized
    }

    /// Test the basic RPC interface: reply matching, ID reuse, and concurrent calls.
    #[tokio::test]
    async fn test_rpc_client() {
        panic_after(Duration::from_secs(1), async move {
            let (pipe1, mut pipe2) = Pipe::new();
            let (mut client, events) = create::<_, _, TestPayload>(pipe1);
            println!("init");

            // Client task.
            let tasks = (
                tokio::spawn(async move {
                    assert_eq!(client.call(42).await, Ok(1337));
                    // `call` queues the request immediately, so 43 is sent before 44 even
                    // though its future is awaited last.
                    let delayed = client.call(43);
                    assert_eq!(client.call(44).await, Ok(1339));
                    assert_eq!(delayed.await, Ok(1338));
                }),
                // Server side.
                tokio::spawn(async move {
                    // Reply for an unknown ID, should be logged and ignored.
                    pipe2.send(serialize_packet(42, 42)).await.unwrap();
                    assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(0, 42))));
                    pipe2.send(serialize_packet(0, 1337)).await.unwrap();
                    // Network errors shall be ignored if the stream is not closed.
                    pipe2.inject_error();
                    // The ID has to be reused, and a second concurrent call has to have a
                    // different ID.
                    assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(0, 43))));
                    assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(1, 44))));
                    // Replies arrive out of order and must still reach the right callers.
                    pipe2.send(serialize_packet(1, 1339)).await.unwrap();
                    pipe2.send(serialize_packet(0, 1338)).await.unwrap();
                }),
            );
            tasks.0.await.unwrap();
            tasks.1.await.unwrap();
        })
        .await;
    }

    /// Test correct behaviour if the connection is closed during a call.
    #[tokio::test]
    async fn test_connection_closed() {
        panic_after(Duration::from_secs(1), async move {
            let (pipe1, pipe2) = Pipe::new();
            let (mut client, _events) = create::<_, _, TestPayload>(pipe1);

            // Client task.
            let tasks = (
                tokio::spawn(async move {
                    let status = client.call(42).await;
                    // Which error is seen depends on timing: either the task notices the
                    // closed stream first (`Closed`), or the send itself fails (`Stream`).
                    assert!(
                        status == Err(RPCError::Closed)
                            || status == Err(RPCError::Stream(PipeError::Closed))
                    );
                }),
                tokio::spawn(async move {
                    drop(pipe2);
                }),
            );
            tasks.0.await.unwrap();
            tasks.1.await.unwrap();
        })
        .await;
    }

    /// Test server events: an event is delivered, and the event stream ends when the
    /// connection closes.
    #[tokio::test]
    async fn test_events() {
        panic_after(Duration::from_secs(1), async move {
            let (pipe1, mut pipe2) = Pipe::new();
            // `client` is kept alive so that the task stops because the stream closed, rather
            // than because `Close` was sent.
            let (client, mut events) = create::<_, _, TestPayload>(pipe1);

            // Client task.
            let tasks = (
                tokio::spawn(async move {
                    assert_eq!(events.next().await, Some(42));
                    assert_eq!(events.next().await, None);
                }),
                tokio::spawn(async move {
                    pipe2
                        .send(serialize_packet(SERVER_EVENT_ID, 42))
                        .await
                        .unwrap();
                    drop(pipe2);
                }),
            );
            tasks.0.await.unwrap();
            tasks.1.await.unwrap();
        })
        .await;
    }

    // TODO:
    // - Test for unparseable response.
    // - Test whether the client correctly closes the connection and whether futures return an
    //   error in this case.
}