Przeglądaj źródła

Slightly refactor the low-level network client code.

Mathias Gottschlag 5 lat temu
rodzic
commit
9a68d46fc0
1 zmienionych plików z 137 dodań i 111 usunięć
  1. 137
    111
      src/network/client.rs

+ 137
- 111
src/network/client.rs Wyświetl plik

@@ -1,9 +1,9 @@
1 1
 use std::collections::BTreeMap;
2
-use std::fmt;
3 2
 use std::time::Duration;
3
+use std::{fmt, mem};
4 4
 
5 5
 use core::pin::Pin;
6
-use futures::future::{self, Either, Future};
6
+use futures::future::Future;
7 7
 use futures::sink::{Sink, SinkExt};
8 8
 use futures::stream::{Stream, StreamExt};
9 9
 use futures::task::{Context, Poll};
@@ -37,7 +37,8 @@ where
37 37
     let (cmd_send, cmd_receive) = mpsc::unbounded_channel();
38 38
     let (event_send, event_receive) = mpsc::channel(1);
39 39
     tokio::spawn(async move {
40
-        client_connection_task(stream, cmd_receive, event_send).await;
40
+        let task = ClientConnectionTask::new(stream, cmd_receive, event_send);
41
+        task.run().await;
41 42
     });
42 43
 
43 44
     (
@@ -53,131 +54,156 @@ enum ClientTaskCommand<PacketType, ErrorType> {
53 54
     SendPacket(CallData<PacketType, ErrorType>),
54 55
 }
55 56
 
56
-async fn client_connection_task<StreamType, ErrorType, PacketType>(
57
-    mut stream: StreamType,
58
-    mut cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
59
-    mut event_send: mpsc::Sender<PacketType>,
60
-) where
57
+struct ClientConnectionTask<StreamType, ErrorType, PacketType> {
58
+    free_ids: IdAllocator,
59
+    calls_in_progress: BTreeMap<u32, oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>>,
60
+    stream: StreamType,
61
+    cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
62
+    event_send: mpsc::Sender<PacketType>,
63
+}
64
+
65
+impl<StreamType, ErrorType, PacketType> ClientConnectionTask<StreamType, ErrorType, PacketType>
66
+where
61 67
     PacketType: DeserializeOwned + Serialize,
62 68
     StreamType:
63 69
         Stream<Item = Result<Vec<u8>, ErrorType>> + Sink<Vec<u8>, Error = ErrorType> + Unpin,
64 70
     ErrorType: fmt::Debug,
65 71
 {
66
-    let mut free_ids = IdAllocator::new();
67
-    let mut calls_in_progress =
68
-        BTreeMap::<u32, oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>>::new();
69
-
70
-    // TODO: Refactor this function.
72
+    fn new(
73
+        stream: StreamType,
74
+        cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
75
+        event_send: mpsc::Sender<PacketType>,
76
+    ) -> Self {
77
+        Self {
78
+            free_ids: IdAllocator::new(),
79
+            calls_in_progress: BTreeMap::new(),
80
+            stream: stream,
81
+            cmd_receive: cmd_receive,
82
+            event_send,
83
+        }
84
+    }
71 85
 
72
-    let mut connection_closed = false;
73
-    loop {
74
-        match future::select(Box::pin(stream.next()), Box::pin(cmd_receive.recv())).await {
75
-            Either::Left((packet, _cmd_future)) => {
76
-                match packet {
77
-                    Some(Ok(data)) => {
78
-                        match rmps::decode::from_slice(&data) as Result<Packet<PacketType>, _> {
79
-                            Ok(deserialized) => {
80
-                                if deserialized.id == SERVER_EVENT_ID {
81
-                                    // We received an asynchronous event from the server.
82
-                                    event_send.send(deserialized.payload).await.ok();
83
-                                } else {
84
-                                    // We received a response to a call - find the call and forward
85
-                                    // the reply.
86
-                                    if let Some(call) = calls_in_progress.remove(&deserialized.id) {
87
-                                        free_ids.free(deserialized.id);
88
-                                        call.send(Ok(deserialized.payload)).ok();
89
-                                    } else {
90
-                                        error!(
91
-                                            "Received a reply for an unknown call ({:?})!",
92
-                                            deserialized.id
93
-                                        );
94
-                                    }
95
-                                }
96
-                            }
97
-                            Err(e) => {
98
-                                // We received an invalid packet. We cannot determine for which
99
-                                // call the packet was, so we just ignore it. The corresponding
100
-                                // call will freeze until the connection is dropped.
101
-                                error!("Invalid packet received: {:?}", e);
102
-                            }
103
-                        };
104
-                    }
105
-                    Some(Err(e)) => {
106
-                        error!("Network error: {:?}", e);
107
-                    }
108
-                    None => {
109
-                        // Stream closed. We still need to return errors for any
110
-                        // packet sent by the client, so do not exit the loop
111
-                        // immediately but rather return an error for any
112
-                        // incoming packet.
113
-                        connection_closed = true;
114
-                        break;
86
+    async fn run(mut self) {
87
+        let mut connection_closed = false;
88
+        loop {
89
+            tokio::select! {
90
+                packet = &mut self.stream.next() => {
91
+                    match packet {
92
+                        Some(Ok(data)) => {
93
+                            self.received_packet(&data).await;
94
+                        }
95
+                        Some(Err(e)) => {
96
+                            error!("Network error: {:?}", e);
97
+                        }
98
+                        None => {
99
+                            // Stream closed. We still need to return errors for any
100
+                            // packet sent by the client, so do not exit the loop
101
+                            // immediately but rather return an error for any
102
+                            // incoming packet.
103
+                            connection_closed = true;
104
+                            break;
105
+                        }
115 106
                     }
116 107
                 }
117
-            }
118
-            Either::Right((cmd, _stream_future)) => {
119
-                match cmd {
120
-                    Some(ClientTaskCommand::Close) => {
121
-                        // The RPC interface for this connection was dropped, so
122
-                        // we just exit the loop and close the connection.
123
-                        break;
124
-                    }
125
-                    Some(ClientTaskCommand::SendPacket(call)) => {
126
-                        // Add an ID for the outgoing packet.
127
-                        let id = free_ids.alloc();
128
-                        let packet = Packet {
129
-                            id: id,
130
-                            payload: call.data,
131
-                        };
132
-                        let mut serialized = Vec::new();
133
-                        packet
134
-                            .serialize(
135
-                                &mut Serializer::new(&mut serialized)
136
-                                    .with_struct_map()
137
-                                    .with_string_variants(),
138
-                            )
139
-                            .unwrap();
140
-
141
-                        // Send the serialized packet.
142
-                        // TODO: Send without flushing?
143
-                        match stream.send(serialized).await {
144
-                            Ok(()) => {
145
-                                calls_in_progress.insert(id, call.result);
146
-                            }
147
-                            Err(e) => {
148
-                                // We potentially leak the ID here, but that is
149
-                                // okay as the state of the call is not well
150
-                                // defined anymore.
151
-                                call.result.send(Err(RPCError::Stream(e))).ok();
152
-                            }
108
+                cmd = &mut self.cmd_receive.next() => {
109
+                    match cmd {
110
+                        Some(ClientTaskCommand::Close) => {
111
+                            // The RPC interface for this connection was dropped, so
112
+                            // we just exit the loop and close the connection.
113
+                            break;
114
+                        }
115
+                        Some(ClientTaskCommand::SendPacket(call)) => {
116
+                            self.send_packet(call).await;
117
+                        }
118
+                        None => {
119
+                            // This should never happen.
120
+                            break;
153 121
                         }
154
-                    }
155
-                    None => {
156
-                        // This should never happen.
157
-                        break;
158 122
                     }
159 123
                 }
160 124
             }
161
-        };
125
+        }
126
+
127
+        drop(self.event_send);
128
+
129
+        // Return errors for all pending packets as there will be no responses.
130
+        let mut calls_in_progress = BTreeMap::new();
131
+        mem::swap(&mut calls_in_progress, &mut self.calls_in_progress);
132
+        for (_, result) in calls_in_progress {
133
+            result.send(Err(RPCError::Closed)).ok();
134
+        }
135
+
136
+        if connection_closed {
137
+            // The stream was closed, but the command channel was not. We still need
138
+            // to return errors for any packet sent by the client, so do not exit
139
+            // the loop immediately but rather return an error for any incoming
140
+            // packet.
141
+            client_connection_closed(self.cmd_receive).await;
142
+        } else {
143
+            // The command channel was closed, but the stream was not. Gracefully
144
+            // close the stream.
145
+            // TODO
146
+        }
162 147
     }
163 148
 
164
-    drop(event_send);
149
+    async fn send_packet(&mut self, call: CallData<PacketType, ErrorType>) {
150
+        // Add an ID for the outgoing packet.
151
+        let id = self.free_ids.alloc();
152
+        let packet = Packet {
153
+            id: id,
154
+            payload: call.data,
155
+        };
156
+        let mut serialized = Vec::new();
157
+        packet
158
+            .serialize(
159
+                &mut Serializer::new(&mut serialized)
160
+                    .with_struct_map()
161
+                    .with_string_variants(),
162
+            )
163
+            .unwrap();
165 164
 
166
-    // Return errors for all pending packets as there will be no responses.
167
-    for (_, result) in calls_in_progress {
168
-        result.send(Err(RPCError::Closed)).ok();
165
+        // Send the serialized packet.
166
+        // TODO: Send without flushing?
167
+        match self.stream.send(serialized).await {
168
+            Ok(()) => {
169
+                self.calls_in_progress.insert(id, call.result);
170
+            }
171
+            Err(e) => {
172
+                // We potentially leak the ID here, but that is
173
+                // okay as the state of the call is not well
174
+                // defined anymore.
175
+                call.result.send(Err(RPCError::Stream(e))).ok();
176
+            }
177
+        }
169 178
     }
170 179
 
171
-    if connection_closed {
172
-        // The stream was closed, but the command channel was not. We still need
173
-        // to return errors for any packet sent by the client, so do not exit
174
-        // the loop immediately but rather return an error for any incoming
175
-        // packet.
176
-        client_connection_closed(cmd_receive).await;
177
-    } else {
178
-        // The command channel was closed, but the stream was not. Gracefully
179
-        // close the stream.
180
-        // TODO
180
+    async fn received_packet(&mut self, data: &[u8]) {
181
+        match rmps::decode::from_slice(data) as Result<Packet<PacketType>, _> {
182
+            Ok(deserialized) => {
183
+                if deserialized.id == SERVER_EVENT_ID {
184
+                    // We received an asynchronous event from the server.
185
+                    self.event_send.send(deserialized.payload).await.ok();
186
+                } else {
187
+                    // We received a response to a call - find the call and forward
188
+                    // the reply.
189
+                    if let Some(call) = self.calls_in_progress.remove(&deserialized.id) {
190
+                        self.free_ids.free(deserialized.id);
191
+                        call.send(Ok(deserialized.payload)).ok();
192
+                    } else {
193
+                        error!(
194
+                            "Received a reply for an unknown call ({:?})!",
195
+                            deserialized.id
196
+                        );
197
+                    }
198
+                }
199
+            }
200
+            Err(e) => {
201
+                // We received an invalid packet. We cannot determine for which
202
+                // call the packet was, so we just ignore it. The corresponding
203
+                // call will freeze until the connection is dropped.
204
+                error!("Invalid packet received: {:?}", e);
205
+            }
206
+        };
181 207
     }
182 208
 }
183 209
 

Ładowanie…
Anuluj
Zapisz