浏览代码

Implement client RPC logic.

父节点
当前提交
e136c9d21a
共有 3 个文件被更改,包括 156 次插入72 次删除
  1. 153
    47
      src/network/client.rs
  2. 0
    16
      src/network/packet.rs
  3. 3
    9
      src/network/server.rs

+ 153
- 47
src/network/client.rs 查看文件

@@ -1,57 +1,106 @@
1
+use std::collections::BTreeMap;
2
+use std::fmt;
1 3
 use std::time::Duration;
2 4
 
3
-use async_tungstenite::{tokio::connect_async, tungstenite::Message, WebSocketStream};
4 5
 use core::pin::Pin;
5 6
 use futures::future::{self, Either, Future};
7
+use futures::sink::{Sink, SinkExt};
8
+use futures::stream::{Stream, StreamExt};
6 9
 use futures::task::{Context, Poll};
7
-use tokio::stream::Stream;
8
-use tokio::stream::StreamExt;
10
+use rmps::Serializer;
11
+use serde::de::DeserializeOwned;
12
+use serde::Serialize;
9 13
 use tokio::sync::{mpsc, oneshot};
10 14
 use tokio::time::{delay_for, Delay};
11
-use url::Url;
12 15
 
13
-use super::packet::IncomingPacket;
16
+use super::idalloc::IdAllocator;
17
+use super::packet::{Packet, SERVER_EVENT_ID};
14 18
 
15 19
 pub fn create<StreamType, ErrorType, PacketType>(
16 20
     stream: StreamType,
17 21
 ) -> (
18
-    RPCInterface<StreamType, PacketType>,
22
+    RPCInterface<PacketType, ErrorType>,
19 23
     ServerEventStream<PacketType>,
20 24
 )
21 25
 where
22
-    StreamType: Stream<Item = Result<Vec<u8>, ErrorType>> + Unpin + Send + 'static,
26
+    StreamType: Stream<Item = Result<Vec<u8>, ErrorType>>
27
+        + Sink<Vec<u8>, Error = ErrorType>
28
+        + Unpin
29
+        + Send
30
+        + 'static,
31
+    PacketType: DeserializeOwned + Serialize + Send + 'static,
32
+    ErrorType: fmt::Debug + Send + 'static,
23 33
 {
24
-    let (cmd_send, cmd_receive) = mpsc::channel(1);
34
+    // TODO: Does this channel have to be unbounded? We can just drop the sender, and the receiver
35
+    // will notice.
36
+    let (cmd_send, cmd_receive) = mpsc::unbounded_channel();
37
+    let (event_send, event_receive) = mpsc::channel(1);
25 38
     tokio::spawn(async move {
26
-        // TODO: Remove type annotations.
27
-        client_connection_task::<StreamType, ErrorType, PacketType>(stream, cmd_receive).await;
39
+        client_connection_task(stream, cmd_receive, event_send).await;
28 40
     });
29
-    // TODO
30
-    panic!("Not yet implemented.");
31
-}
32 41
 
33
-enum ClientTaskCommand {
34
-    Close,
35
-    SendPackets,
42
+    (
43
+        RPCInterface { cmd_send },
44
+        ServerEventStream {
45
+            packets: Box::pin(event_receive),
46
+        },
47
+    )
36 48
 }
37 49
 
38
-struct OutgoingCallMap {
39
-    // TODO
50
+enum ClientTaskCommand<PacketType, ErrorType> {
51
+    Close,
52
+    SendPacket(CallData<PacketType, ErrorType>),
40 53
 }
41 54
 
42 55
 async fn client_connection_task<StreamType, ErrorType, PacketType>(
43 56
     mut stream: StreamType,
44
-    mut cmd_receive: mpsc::Receiver<ClientTaskCommand>,
57
+    mut cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
58
+    mut event_send: mpsc::Sender<PacketType>,
45 59
 ) where
46
-    StreamType: Stream<Item = Result<Vec<u8>, ErrorType>> + Unpin,
60
+    PacketType: DeserializeOwned + Serialize,
61
+    StreamType:
62
+        Stream<Item = Result<Vec<u8>, ErrorType>> + Sink<Vec<u8>, Error = ErrorType> + Unpin,
63
+    ErrorType: fmt::Debug,
47 64
 {
65
+    let mut free_ids = IdAllocator::new();
66
+    let mut calls_in_progress = BTreeMap::<u32, oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>>::new();
67
+
68
+    // TODO: Refactor this function.
69
+
48 70
     let mut connection_closed = false;
49 71
     loop {
50 72
         match future::select(Box::pin(stream.next()), Box::pin(cmd_receive.recv())).await {
51 73
             Either::Left((packet, _cmd_future)) => {
52 74
                 match packet {
53
-                    Some(data) => {
54
-                        // TODO
75
+                    Some(Ok(data)) => {
76
+                        match rmps::decode::from_slice(&data) as Result<Packet<PacketType>, _> {
77
+                            Ok(deserialized) => {
78
+                                if deserialized.id == SERVER_EVENT_ID {
79
+                                    // We received an asynchronous event from the server.
80
+                                    event_send.send(deserialized.payload).await.ok();
81
+                                } else {
82
+                                    // We received a response to a call - find the call and forward
83
+                                    // the reply.
84
+                                    if let Some(call) = calls_in_progress.remove(&deserialized.id) {
85
+                                        call.send(Ok(deserialized.payload)).ok();
86
+                                    } else {
87
+                                        // TODO: Use proper logging functionality.
88
+                                        eprintln!("Received a reply for an unknown call ({:?})!", deserialized.id);
89
+                                    }
90
+                                }
91
+                            }
92
+                            Err(e) => {
93
+                                // We received an invalid packet. We cannot determine for which
94
+                                // call the packet was, so we just ignore it. The corresponding
95
+                                // call will freeze until the connection is dropped.
96
+                                // TODO: Use proper logging functionality.
97
+                                eprintln!("Invalid packet received: {:?}", e);
98
+                            }
99
+                        };
100
+                    }
101
+                    Some(Err(e)) => {
102
+                        // TODO: Use proper logging functionality.
103
+                        eprintln!("Network error: {:?}", e);
55 104
                     }
56 105
                     None => {
57 106
                         // Stream closed. We still need to return errors for any
@@ -68,14 +117,38 @@ async fn client_connection_task<StreamType, ErrorType, PacketType>(
68 117
                     Some(ClientTaskCommand::Close) => {
69 118
                         // The RPC interface for this connection was dropped, so
70 119
                         // we just exit the loop and close the connection.
71
-
72
-                        // We need to return errors for all pending packets
73
-                        // first, though.
74
-                        // TODO
75 120
                         break;
76 121
                     }
77
-                    Some(ClientTaskCommand::SendPackets) => {
78
-                        // TODO
122
+                    Some(ClientTaskCommand::SendPacket(call)) => {
123
+                        // Add an ID for the outgoing packet.
124
+                        let id = free_ids.alloc();
125
+                        let packet = Packet {
126
+                            id: id,
127
+                            response: false,
128
+                            payload: call.data,
129
+                        };
130
+                        let mut serialized = Vec::new();
131
+                        packet
132
+                            .serialize(
133
+                                &mut Serializer::new(&mut serialized)
134
+                                    .with_struct_map()
135
+                                    .with_string_variants(),
136
+                            )
137
+                            .unwrap();
138
+
139
+                        // Send the serialized packet.
140
+                        // TODO: Send without flushing?
141
+                        match stream.send(serialized).await {
142
+                            Ok(()) => {
143
+                                calls_in_progress.insert(id, call.result);
144
+                            }
145
+                            Err(e) => {
146
+                                // We potentially leak the ID here, but that is
147
+                                // okay as the state of the call is not well
148
+                                // defined anymore.
149
+                                call.result.send(Err(RPCError::Stream(e))).ok();
150
+                            }
151
+                        }
79 152
                     }
80 153
                     None => {
81 154
                         // This should never happen.
@@ -86,6 +159,11 @@ async fn client_connection_task<StreamType, ErrorType, PacketType>(
86 159
         };
87 160
     }
88 161
 
162
+    // Return errors for all pending packets as there will be no responses.
163
+    for (_, result) in calls_in_progress {
164
+        result.send(Err(RPCError::Closed)).ok();
165
+    }
166
+
89 167
     if connection_closed {
90 168
         // The stream was closed, but the command channel was not. We still need
91 169
         // to return errors for any packet sent by the client, so do not exit
@@ -99,17 +177,19 @@ async fn client_connection_task<StreamType, ErrorType, PacketType>(
99 177
     }
100 178
 }
101 179
 
102
-async fn client_connection_closed(mut cmd_receive: mpsc::Receiver<ClientTaskCommand>) {
180
+async fn client_connection_closed<PacketType, ErrorType>(
181
+    mut cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
182
+) {
103 183
     loop {
104 184
         match cmd_receive.recv().await {
105 185
             Some(ClientTaskCommand::Close) => {
106 186
                 // The connection was already closed.
107 187
                 break;
108 188
             }
109
-            Some(ClientTaskCommand::SendPackets) => {
189
+            Some(ClientTaskCommand::SendPacket(call)) => {
110 190
                 // The connection was closed, so return an error for any
111 191
                 // packet sent.
112
-                // TODO
192
+                call.result.send(Err(RPCError::Closed)).ok();
113 193
             }
114 194
             None => {
115 195
                 // This should never happen.
@@ -119,22 +199,47 @@ async fn client_connection_closed(mut cmd_receive: mpsc::Receiver<ClientTaskComm
119 199
     }
120 200
 }
121 201
 
122
-pub struct RPCInterface<StreamType, PacketType> {
123
-    stream: StreamType,
124
-    _packet: Option<PacketType>, // TODO: Remove.
202
+pub struct RPCInterface<PacketType, ErrorType> {
203
+    cmd_send: mpsc::UnboundedSender<ClientTaskCommand<PacketType, ErrorType>>,
125 204
 }
126 205
 
127
-impl<StreamType, PacketType> RPCInterface<StreamType, PacketType> {
128
-    fn call(_call: &PacketType) -> RPC<StreamType, PacketType> {
129
-        // TODO
130
-        panic!("Not yet implemented.");
206
+impl<PacketType, ErrorType> RPCInterface<PacketType, ErrorType> {
207
+    pub async fn call(&mut self, call: PacketType) -> RPC<PacketType, ErrorType> {
208
+        let (sender, receiver) = oneshot::channel();
209
+        let call_data = CallData {
210
+            data: call,
211
+            result: sender,
212
+        };
213
+        if let Err(_) = self.cmd_send.send(ClientTaskCommand::SendPacket(call_data)) {
214
+            panic!("could not send packet to connection task");
215
+        }
216
+
217
+        RPC {
218
+            result: Box::pin(receiver),
219
+            timeout: None,
220
+        }
131 221
     }
132 222
 }
133 223
 
224
+impl<PacketType, ErrorType> Drop for RPCInterface<PacketType, ErrorType> {
225
+    fn drop(&mut self) {
226
+        // Stop the connection loop and close the connection.
227
+        if let Err(_) = self.cmd_send.send(ClientTaskCommand::Close) {
228
+            panic!("could not send close signal to connection task");
229
+        }
230
+    }
231
+}
232
+
233
+struct CallData<PacketType, ErrorType> {
234
+    data: PacketType,
235
+    result: oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>,
236
+}
237
+
134 238
 pub struct RPC<PacketType, ErrorType> {
135
-    result: Pin<Box<oneshot::Receiver<Result<IncomingPacket<PacketType>, ErrorType>>>>,
239
+    result: Pin<Box<oneshot::Receiver<Result<PacketType, RPCError<ErrorType>>>>>,
136 240
     timeout: Option<Pin<Box<Delay>>>,
137 241
 }
242
+
138 243
 impl<PacketType, ErrorType> RPC<PacketType, ErrorType> {
139 244
     pub fn with_timeout(mut self, timeout: Duration) -> Self {
140 245
         self.timeout = Some(Box::pin(delay_for(timeout)));
@@ -143,7 +248,7 @@ impl<PacketType, ErrorType> RPC<PacketType, ErrorType> {
143 248
 }
144 249
 
145 250
 impl<PacketType, ErrorType> Future for RPC<PacketType, ErrorType> {
146
-    type Output = Result<IncomingPacket<PacketType>, RPCError<ErrorType>>;
251
+    type Output = Result<PacketType, RPCError<ErrorType>>;
147 252
 
148 253
     fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
149 254
         // Safe, as we will not move self_.
@@ -151,7 +256,7 @@ impl<PacketType, ErrorType> Future for RPC<PacketType, ErrorType> {
151 256
 
152 257
         // Try receiving from the receiver.
153 258
         match Pin::as_mut(&mut self_.result).poll(cx) {
154
-            Poll::Ready(Ok(result)) => return Poll::Ready(result.map_err(|x| RPCError::Stream(x))),
259
+            Poll::Ready(Ok(result)) => return Poll::Ready(result),
155 260
             Poll::Ready(Err(_)) => {
156 261
                 // The channel was closed.
157 262
                 return Poll::Ready(Err(RPCError::Closed));
@@ -173,11 +278,11 @@ impl<PacketType, ErrorType> Future for RPC<PacketType, ErrorType> {
173 278
 
174 279
 // TODO: Do we even need this type?
175 280
 pub struct ServerEventStream<PacketType> {
176
-    packets: Pin<Box<mpsc::Receiver<IncomingPacket<PacketType>>>>,
281
+    packets: Pin<Box<mpsc::Receiver<PacketType>>>,
177 282
 }
178 283
 
179 284
 impl<PacketType> Stream for ServerEventStream<PacketType> {
180
-    type Item = IncomingPacket<PacketType>;
285
+    type Item = PacketType;
181 286
 
182 287
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
183 288
         // Safe, as we will not move self_.
@@ -187,19 +292,20 @@ impl<PacketType> Stream for ServerEventStream<PacketType> {
187 292
     }
188 293
 }
189 294
 
295
+#[derive(Debug)]
190 296
 pub enum RPCError<StreamError> {
191 297
     Closed,
192 298
     Timeout,
193
-    Stream(StreamError)
299
+    Stream(StreamError),
194 300
 }
195 301
 
196
-pub async fn low_level_integration_test_client(address: String) {
197
-    use super::websocket::WebSocketConnection;
302
+pub async fn low_level_integration_test_client(_address: String) {
303
+    /*use super::websocket::WebSocketConnection;
198 304
 
199 305
     let url = Url::parse(&address).unwrap();
200 306
     let (mut ws_stream, _) = connect_async(url).await.unwrap();
201 307
     let mut conn = WebSocketConnection::new(ws_stream);
202
-    let (rpc_interface, server_events) = create::<_, _, u32>(conn);
308
+    let (rpc_interface, server_events) = create::<_, _, u32>(conn);*/
203 309
     // TODO
204 310
     panic!("Not yet implemented.");
205 311
 }

+ 0
- 16
src/network/packet.rs 查看文件

@@ -27,26 +27,10 @@
27 27
 //! a way that no such temporal information is necessary to process the events.
28 28
 //! TODO: Example to demonstrate the problem.
29 29
 
30
-//use futures::sink::Sink;
31
-//use futures::stream::Stream;
32
-//use serde::de::DeserializeOwned;
33 30
 use serde::{Deserialize, Serialize};
34 31
 
35 32
 pub const SERVER_EVENT_ID: u32 = u32::max_value();
36 33
 
37
-//pub trait PacketConnection: Sink<&Self::Payload> + Stream<IncomingPacket<Payload>> {
38
-//    type Payload: Serialize + DeserializeOwned;
39
-//}
40
-
41
-/// Type which holds a byte vector received via the network and which allows access to the
42
-/// deserialized packet.
43
-///
44
-/// The type can be used for efficient zero-copy deserialization of large packets.
45
-pub struct IncomingPacket<Payload> {
46
-    // TODO
47
-    _packet: Payload,
48
-}
49
-
50 34
 #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
51 35
 pub struct Packet<Payload> {
52 36
     pub id: u32,

+ 3
- 9
src/network/server.rs 查看文件

@@ -1,15 +1,10 @@
1
-use super::packet::IncomingPacket;
2
-
3 1
 // Trait that is implemented by any RPC server to process incoming calls.
4 2
 trait RPCServer {
5 3
     type PacketType;
6 4
     type NetworkError;
7 5
 
8
-    fn incoming_call<Response>(
9
-        &mut self,
10
-        call: IncomingPacket<Self::PacketType>,
11
-        response: Response,
12
-    ) where
6
+    fn incoming_call<Response>(&mut self, call: Self::PacketType, response: Response)
7
+    where
13 8
         Response: RPCResponse<PacketType = Self::PacketType, NetworkError = Self::NetworkError>;
14 9
 }
15 10
 
@@ -29,11 +24,10 @@ pub struct LowLevelIntegrationTestServer {
29 24
 impl LowLevelIntegrationTestServer {
30 25
     pub async fn start(_bind_address: &str) -> LowLevelIntegrationTestServer {
31 26
         // TODO
32
-        LowLevelIntegrationTestServer{}
27
+        LowLevelIntegrationTestServer {}
33 28
     }
34 29
 
35 30
     pub async fn stop(self) {
36 31
         // TODO
37 32
     }
38 33
 }
39
-

正在加载...
取消
保存