| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447 |
- use std::collections::BTreeMap;
- use std::time::Duration;
- use std::{fmt, mem};
-
- use core::pin::Pin;
- use futures::future::Future;
- use futures::sink::{Sink, SinkExt};
- use futures::stream::{Stream, StreamExt};
- use futures::task::{Context, Poll};
- use log::error;
- use pin_project::pin_project;
- use rmps::Serializer;
- use serde::de::DeserializeOwned;
- use serde::Serialize;
- use tokio::sync::{mpsc, oneshot};
- use tokio::time::{delay_for, Delay};
-
- use super::idalloc::IdAllocator;
- use super::{Packet, SERVER_EVENT_ID};
-
- pub fn create<StreamType, ErrorType, PacketType>(
- stream: StreamType,
- ) -> (
- RPCInterface<PacketType, ErrorType>,
- ServerEventStream<PacketType>,
- )
- where
- StreamType: Stream<Item = Result<Vec<u8>, ErrorType>>
- + Sink<Vec<u8>, Error = ErrorType>
- + Unpin
- + Send
- + 'static,
- PacketType: DeserializeOwned + Serialize + Send + 'static,
- ErrorType: fmt::Debug + Send + 'static,
- {
- // TODO: Does this channel have to be unbounded? We can just drop the sender, and the receiver
- // will notice.
- let (cmd_send, cmd_receive) = mpsc::unbounded_channel();
- let (event_send, event_receive) = mpsc::channel(1);
- tokio::spawn(async move {
- let task = ClientConnectionTask::new(stream, cmd_receive, event_send);
- task.run().await;
- });
-
- (RPCInterface { cmd_send }, Box::pin(event_receive))
- }
-
- enum ClientTaskCommand<PacketType, ErrorType> {
- Close,
- SendPacket(CallData<PacketType, ErrorType>),
- }
-
/// State owned by the background task that drives one client connection.
struct ClientConnectionTask<StreamType, ErrorType, PacketType> {
    /// Allocates the IDs attached to outgoing calls; an ID is freed again once its reply has
    /// been received.
    free_ids: IdAllocator,
    /// Calls that were written to the stream and still await a reply, keyed by call ID.
    calls_in_progress: BTreeMap<u32, oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>>,
    /// The underlying connection: a stream of incoming byte packets and a sink for outgoing ones.
    stream: StreamType,
    /// Commands sent by the `RPCInterface`.
    cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
    /// Forwards packets carrying `SERVER_EVENT_ID` to the `ServerEventStream`.
    event_send: mpsc::Sender<PacketType>,
}
-
- impl<StreamType, ErrorType, PacketType> ClientConnectionTask<StreamType, ErrorType, PacketType>
- where
- PacketType: DeserializeOwned + Serialize,
- StreamType:
- Stream<Item = Result<Vec<u8>, ErrorType>> + Sink<Vec<u8>, Error = ErrorType> + Unpin,
- ErrorType: fmt::Debug,
- {
- fn new(
- stream: StreamType,
- cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
- event_send: mpsc::Sender<PacketType>,
- ) -> Self {
- Self {
- free_ids: IdAllocator::new(),
- calls_in_progress: BTreeMap::new(),
- stream: stream,
- cmd_receive: cmd_receive,
- event_send,
- }
- }
-
- async fn run(mut self) {
- let mut connection_closed = false;
- loop {
- tokio::select! {
- packet = &mut self.stream.next() => {
- match packet {
- Some(Ok(data)) => {
- self.received_packet(&data).await;
- }
- Some(Err(e)) => {
- error!("Network error: {:?}", e);
- }
- None => {
- // Stream closed. We still need to return errors for any
- // packet sent by the client, so do not exit the loop
- // immediately but rather return an error for any
- // incoming packet.
- connection_closed = true;
- break;
- }
- }
- }
- cmd = &mut self.cmd_receive.next() => {
- match cmd {
- Some(ClientTaskCommand::Close) => {
- // The RPC interface for this connection was dropped, so
- // we just exit the loop and close the connection.
- break;
- }
- Some(ClientTaskCommand::SendPacket(call)) => {
- self.send_packet(call).await;
- }
- None => {
- // This should never happen.
- break;
- }
- }
- }
- }
- }
-
- drop(self.event_send);
-
- // Return errors for all pending packets as there will be no responses.
- let mut calls_in_progress = BTreeMap::new();
- mem::swap(&mut calls_in_progress, &mut self.calls_in_progress);
- for (_, result) in calls_in_progress {
- result.send(Err(RPCError::Closed)).ok();
- }
-
- if connection_closed {
- // The stream was closed, but the command channel was not. We still need
- // to return errors for any packet sent by the client, so do not exit
- // the loop immediately but rather return an error for any incoming
- // packet.
- client_connection_closed(self.cmd_receive).await;
- } else {
- // The command channel was closed, but the stream was not. Gracefully
- // close the stream.
- // TODO
- }
- }
-
- /// Sends a packet.
- ///
- /// # Errors
- ///
- /// If the connection is closed, the function does *not* return `RPCError::Closed`, but rather
- /// returns the corresponding stream error.
- async fn send_packet(&mut self, call: CallData<PacketType, ErrorType>) {
- // Add an ID for the outgoing packet.
- let id = self.free_ids.alloc();
- let packet = Packet {
- id: id,
- payload: call.data,
- };
- let mut serialized = Vec::new();
- packet
- .serialize(
- &mut Serializer::new(&mut serialized)
- .with_struct_map()
- .with_string_variants(),
- )
- .unwrap();
-
- // Send the serialized packet.
- // TODO: Send without flushing?
- match self.stream.send(serialized).await {
- Ok(()) => {
- self.calls_in_progress.insert(id, call.result);
- }
- Err(e) => {
- // We potentially leak the ID here, but that is
- // okay as the state of the call is not well
- // defined anymore.
- call.result.send(Err(RPCError::Stream(e))).ok();
- }
- }
- }
-
- async fn received_packet(&mut self, data: &[u8]) {
- match rmps::decode::from_slice(data) as Result<Packet<PacketType>, _> {
- Ok(deserialized) => {
- if deserialized.id == SERVER_EVENT_ID {
- // We received an asynchronous event from the server.
- self.event_send.send(deserialized.payload).await.ok();
- } else {
- // We received a response to a call - find the call and forward
- // the reply.
- if let Some(call) = self.calls_in_progress.remove(&deserialized.id) {
- self.free_ids.free(deserialized.id);
- call.send(Ok(deserialized.payload)).ok();
- } else {
- error!(
- "Received a reply for an unknown call ({:?})!",
- deserialized.id
- );
- }
- }
- }
- Err(e) => {
- // We received an invalid packet. We cannot determine for which
- // call the packet was, so we just ignore it. The corresponding
- // call will freeze until the connection is dropped.
- error!("Invalid packet received: {:?}", e);
- }
- };
- }
- }
-
- async fn client_connection_closed<PacketType, ErrorType>(
- mut cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
- ) {
- loop {
- match cmd_receive.recv().await {
- Some(ClientTaskCommand::Close) => {
- // The connection was already closed.
- break;
- }
- Some(ClientTaskCommand::SendPacket(call)) => {
- // The connection was closed, so return an error for any
- // packet sent.
- call.result.send(Err(RPCError::Closed)).ok();
- }
- None => {
- // This should never happen.
- break;
- }
- }
- }
- }
-
/// Handle for issuing calls over a connection created with [`create`].
///
/// Dropping the handle tells the connection task to shut down.
pub struct RPCInterface<PacketType, ErrorType> {
    /// Command queue of the connection task.
    cmd_send: mpsc::UnboundedSender<ClientTaskCommand<PacketType, ErrorType>>,
}
-
- impl<PacketType, ErrorType> RPCInterface<PacketType, ErrorType> {
- pub fn call(&mut self, call: PacketType) -> RPC<PacketType, ErrorType> {
- let (sender, receiver) = oneshot::channel();
- let call_data = CallData {
- data: call,
- result: sender,
- };
- if let Err(_) = self.cmd_send.send(ClientTaskCommand::SendPacket(call_data)) {
- panic!("could not send packet to connection task");
- }
-
- RPC {
- result: receiver,
- timeout: None,
- }
- }
- }
-
- impl<PacketType, ErrorType> Drop for RPCInterface<PacketType, ErrorType> {
- fn drop(&mut self) {
- // Stop the connection loop and close the connection.
- // send() only fails if the other end has already been closed - we must
- // not unwrap() or expect() because then nested panics might occur.
- self.cmd_send.send(ClientTaskCommand::Close).ok();
- }
- }
-
/// A single call travelling from the `RPCInterface` to the connection task.
struct CallData<PacketType, ErrorType> {
    /// The request payload to send.
    data: PacketType,
    /// Completes the caller's `RPC` future with the reply or an error.
    result: oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>,
}
-
/// Future returned by [`RPCInterface::call`], resolving to the server's reply.
///
/// The request was already queued for sending when this value was created; dropping it only
/// discards the reply.
#[pin_project]
pub struct RPC<PacketType, ErrorType> {
    /// Receives the reply (or an error) from the connection task.
    #[pin]
    result: oneshot::Receiver<Result<PacketType, RPCError<ErrorType>>>,
    /// Optional deadline configured with [`RPC::with_timeout`].
    #[pin]
    timeout: Option<Delay>,
}
-
- impl<PacketType, ErrorType> RPC<PacketType, ErrorType> {
- pub fn with_timeout(mut self, timeout: Duration) -> Self {
- self.timeout = Some(delay_for(timeout));
- self
- }
- }
-
- impl<PacketType, ErrorType> Future for RPC<PacketType, ErrorType> {
- type Output = Result<PacketType, RPCError<ErrorType>>;
-
- fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
- let this = self.project();
-
- // Try receiving from the receiver.
- match this.result.poll(cx) {
- Poll::Ready(Ok(result)) => return Poll::Ready(result),
- Poll::Ready(Err(_)) => {
- // The channel was closed.
- return Poll::Ready(Err(RPCError::Closed));
- }
- Poll::Pending => (),
- }
-
- // Nothing received, check the timeout instead.
- if this.timeout.is_some() {
- match this.timeout.as_pin_mut().unwrap().poll(cx) {
- Poll::Ready(()) => return Poll::Ready(Err(RPCError::Timeout)),
- Poll::Pending => (),
- }
- }
-
- Poll::Pending
- }
- }
-
- /// Stream of asynchronous packets received from the server.
- type ServerEventStream<PacketType> = Pin<Box<mpsc::Receiver<PacketType>>>;
-
/// Error returned by an [`RPC`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RPCError<StreamError> {
    /// The connection was closed before a reply could be delivered.
    Closed,
    /// No reply arrived before the timeout configured with [`RPC::with_timeout`] expired.
    Timeout,
    /// Writing the request to the underlying stream failed with the contained error.
    Stream(StreamError),
}

impl<StreamError: fmt::Display> fmt::Display for RPCError<StreamError> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RPCError::Closed => f.write_str("connection closed"),
            RPCError::Timeout => f.write_str("RPC call timed out"),
            RPCError::Stream(e) => write!(f, "stream error: {}", e),
        }
    }
}

impl<StreamError> std::error::Error for RPCError<StreamError>
where
    StreamError: std::error::Error + 'static,
{
    /// Exposes the underlying stream error, if any, as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            RPCError::Stream(e) => Some(e),
            RPCError::Closed | RPCError::Timeout => None,
        }
    }
}
-
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::test_utils::{panic_after, Pipe, PipeError};
    use crate::network::Packet;

    type TestPayload = u32;

    /// Serializes `payload` under call `id` with the same settings the client uses, so that the
    /// tests can compare the raw bytes on the wire.
    fn serialize_packet<T>(id: u32, payload: T) -> Vec<u8>
    where
        T: Serialize,
    {
        let packet = Packet { id, payload };
        let mut serialized = Vec::new();
        packet
            .serialize(
                &mut Serializer::new(&mut serialized)
                    .with_struct_map()
                    .with_string_variants(),
            )
            .unwrap();
        serialized
    }

    /// Test the basic RPC interface.
    #[tokio::test]
    async fn test_rpc_client() {
        panic_after(Duration::from_secs(1), async move {
            let (pipe1, mut pipe2) = Pipe::new();
            // The event stream is unused here but kept alive for the whole test.
            let (mut client, _events) = create::<_, _, TestPayload>(pipe1);

            let tasks = (
                // Client task.
                tokio::spawn(async move {
                    assert_eq!(client.call(42).await, Ok(1337));
                    let delayed = client.call(43);
                    assert_eq!(client.call(44).await, Ok(1339));
                    assert_eq!(delayed.await, Ok(1338));
                }),
                // Server task, driving the other end of the pipe.
                tokio::spawn(async move {
                    // Invalid ID, should be ignored.
                    pipe2.send(serialize_packet(42, 42)).await.unwrap();

                    assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(0, 42))));
                    pipe2.send(serialize_packet(0, 1337)).await.unwrap();

                    // Network errors shall be ignored if the stream is not closed.
                    pipe2.inject_error();

                    // The ID has to be reused, and a second concurrent call has to have a
                    // different ID.
                    assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(0, 43))));
                    assert_eq!(pipe2.next().await, Some(Ok(serialize_packet(1, 44))));
                    pipe2.send(serialize_packet(1, 1339)).await.unwrap();
                    pipe2.send(serialize_packet(0, 1338)).await.unwrap();
                }),
            );

            tasks.0.await.unwrap();
            tasks.1.await.unwrap();
        })
        .await;
    }

    /// Test correct behaviour if the connection is closed during a call.
    #[tokio::test]
    async fn test_connection_closed() {
        panic_after(Duration::from_secs(1), async move {
            let (pipe1, pipe2) = Pipe::new();
            let (mut client, _events) = create::<_, _, TestPayload>(pipe1);

            let tasks = (
                // Client task.
                tokio::spawn(async move {
                    let status = client.call(42).await;
                    assert!(
                        status == Err(RPCError::Closed)
                            || status == Err(RPCError::Stream(PipeError::Closed))
                    );
                }),
                // Server task: closes its end of the connection.
                tokio::spawn(async move {
                    drop(pipe2);
                }),
            );

            tasks.0.await.unwrap();
            tasks.1.await.unwrap();
        })
        .await;
    }

    /// Test server events.
    #[tokio::test]
    async fn test_events() {
        panic_after(Duration::from_secs(1), async move {
            let (pipe1, mut pipe2) = Pipe::new();
            // The interface must stay alive: dropping it would stop the connection task.
            let (_client, mut events) = create::<_, _, TestPayload>(pipe1);

            let tasks = (
                // Event consumer.
                tokio::spawn(async move {
                    assert_eq!(events.next().await, Some(42));
                    assert_eq!(events.next().await, None);
                }),
                // Server task.
                tokio::spawn(async move {
                    pipe2
                        .send(serialize_packet(SERVER_EVENT_ID, 42))
                        .await
                        .unwrap();
                    drop(pipe2);
                }),
            );

            tasks.0.await.unwrap();
            tasks.1.await.unwrap();
        })
        .await;
    }

    // TODO:
    // - Test for unparseable response.
    // - Test whether the client correctly closes the connection and whether futures return an
    //   error in this case.
}
|