@@ -1,9 +1,9 @@
 use std::collections::BTreeMap;
-use std::fmt;
 use std::time::Duration;
+use std::{fmt, mem};
 
 use core::pin::Pin;
-use futures::future::{self, Either, Future};
+use futures::future::Future;
 use futures::sink::{Sink, SinkExt};
 use futures::stream::{Stream, StreamExt};
 use futures::task::{Context, Poll};
@@ -37,7 +37,8 @@ where
     let (cmd_send, cmd_receive) = mpsc::unbounded_channel();
     let (event_send, event_receive) = mpsc::channel(1);
     tokio::spawn(async move {
-        client_connection_task(stream, cmd_receive, event_send).await;
+        let task = ClientConnectionTask::new(stream, cmd_receive, event_send);
+        task.run().await;
     });
 
     (
@@ -53,131 +54,156 @@ enum ClientTaskCommand<PacketType, ErrorType> {
     SendPacket(CallData<PacketType, ErrorType>),
 }
 
-async fn client_connection_task<StreamType, ErrorType, PacketType>(
-    mut stream: StreamType,
-    mut cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
-    mut event_send: mpsc::Sender<PacketType>,
-) where
+struct ClientConnectionTask<StreamType, ErrorType, PacketType> {
+    free_ids: IdAllocator,
+    calls_in_progress: BTreeMap<u32, oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>>,
+    stream: StreamType,
+    cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
+    event_send: mpsc::Sender<PacketType>,
+}
+
+impl<StreamType, ErrorType, PacketType> ClientConnectionTask<StreamType, ErrorType, PacketType>
+where
     PacketType: DeserializeOwned + Serialize,
     StreamType:
         Stream<Item = Result<Vec<u8>, ErrorType>> + Sink<Vec<u8>, Error = ErrorType> + Unpin,
     ErrorType: fmt::Debug,
 {
-    let mut free_ids = IdAllocator::new();
-    let mut calls_in_progress =
-        BTreeMap::<u32, oneshot::Sender<Result<PacketType, RPCError<ErrorType>>>>::new();
-
-    // TODO: Refactor this function.
+    fn new(
+        stream: StreamType,
+        cmd_receive: mpsc::UnboundedReceiver<ClientTaskCommand<PacketType, ErrorType>>,
+        event_send: mpsc::Sender<PacketType>,
+    ) -> Self {
+        Self {
+            free_ids: IdAllocator::new(),
+            calls_in_progress: BTreeMap::new(),
+            stream,
+            cmd_receive,
+            event_send,
+        }
+    }
 
-    let mut connection_closed = false;
-    loop {
-        match future::select(Box::pin(stream.next()), Box::pin(cmd_receive.recv())).await {
-            Either::Left((packet, _cmd_future)) => {
-                match packet {
-                    Some(Ok(data)) => {
-                        match rmps::decode::from_slice(&data) as Result<Packet<PacketType>, _> {
-                            Ok(deserialized) => {
-                                if deserialized.id == SERVER_EVENT_ID {
-                                    // We received an asynchronous event from the server.
-                                    event_send.send(deserialized.payload).await.ok();
-                                } else {
-                                    // We received a response to a call - find the call and forward
-                                    // the reply.
-                                    if let Some(call) = calls_in_progress.remove(&deserialized.id) {
-                                        free_ids.free(deserialized.id);
-                                        call.send(Ok(deserialized.payload)).ok();
-                                    } else {
-                                        error!(
-                                            "Received a reply for an unknown call ({:?})!",
-                                            deserialized.id
-                                        );
-                                    }
-                                }
-                            }
-                            Err(e) => {
-                                // We received an invalid packet. We cannot determine for which
-                                // call the packet was, so we just ignore it. The corresponding
-                                // call will freeze until the connection is dropped.
-                                error!("Invalid packet received: {:?}", e);
-                            }
-                        };
-                    }
-                    Some(Err(e)) => {
-                        error!("Network error: {:?}", e);
-                    }
-                    None => {
-                        // Stream closed. We still need to return errors for any
-                        // packet sent by the client, so do not exit the loop
-                        // immediately but rather return an error for any
-                        // incoming packet.
-                        connection_closed = true;
-                        break;
+    async fn run(mut self) {
+        let mut connection_closed = false;
+        loop {
+            tokio::select! {
+                packet = self.stream.next() => {
+                    match packet {
+                        Some(Ok(data)) => {
+                            self.received_packet(&data).await;
+                        }
+                        Some(Err(e)) => {
+                            error!("Network error: {:?}", e);
+                        }
+                        None => {
+                            // Stream closed. We still need to return errors for any
+                            // packet sent by the client, so do not exit the loop
+                            // immediately but rather return an error for any
+                            // incoming packet.
+                            connection_closed = true;
+                            break;
+                        }
                     }
                 }
-            }
-            Either::Right((cmd, _stream_future)) => {
-                match cmd {
-                    Some(ClientTaskCommand::Close) => {
-                        // The RPC interface for this connection was dropped, so
-                        // we just exit the loop and close the connection.
-                        break;
-                    }
-                    Some(ClientTaskCommand::SendPacket(call)) => {
-                        // Add an ID for the outgoing packet.
-                        let id = free_ids.alloc();
-                        let packet = Packet {
-                            id: id,
-                            payload: call.data,
-                        };
-                        let mut serialized = Vec::new();
-                        packet
-                            .serialize(
-                                &mut Serializer::new(&mut serialized)
-                                    .with_struct_map()
-                                    .with_string_variants(),
-                            )
-                            .unwrap();
-
-                        // Send the serialized packet.
-                        // TODO: Send without flushing?
-                        match stream.send(serialized).await {
-                            Ok(()) => {
-                                calls_in_progress.insert(id, call.result);
-                            }
-                            Err(e) => {
-                                // We potentially leak the ID here, but that is
-                                // okay as the state of the call is not well
-                                // defined anymore.
-                                call.result.send(Err(RPCError::Stream(e))).ok();
-                            }
+                cmd = self.cmd_receive.next() => {
+                    match cmd {
+                        Some(ClientTaskCommand::Close) => {
+                            // The RPC interface for this connection was dropped, so
+                            // we just exit the loop and close the connection.
+                            break;
+                        }
+                        Some(ClientTaskCommand::SendPacket(call)) => {
+                            self.send_packet(call).await;
+                        }
+                        None => {
+                            // This should never happen.
+                            break;
+                        }
-                    }
-                    None => {
-                        // This should never happen.
-                        break;
                     }
                 }
             }
-        };
+        }
+
+        drop(self.event_send);
+
+        // Return errors for all pending packets as there will be no responses.
+        let mut calls_in_progress = BTreeMap::new();
+        mem::swap(&mut calls_in_progress, &mut self.calls_in_progress);
+        for (_, result) in calls_in_progress {
+            result.send(Err(RPCError::Closed)).ok();
+        }
+
+        if connection_closed {
+            // The stream was closed, but the command channel was not. We still need
+            // to return errors for any packet sent by the client, so do not exit
+            // the loop immediately but rather return an error for any incoming
+            // packet.
+            client_connection_closed(self.cmd_receive).await;
+        } else {
+            // The command channel was closed, but the stream was not. Gracefully
+            // close the stream.
+            // TODO
+        }
     }
 
-    drop(event_send);
+    async fn send_packet(&mut self, call: CallData<PacketType, ErrorType>) {
+        // Add an ID for the outgoing packet.
+        let id = self.free_ids.alloc();
+        let packet = Packet {
+            id: id,
+            payload: call.data,
+        };
+        let mut serialized = Vec::new();
+        packet
+            .serialize(
+                &mut Serializer::new(&mut serialized)
+                    .with_struct_map()
+                    .with_string_variants(),
+            )
+            .unwrap();
 
-    // Return errors for all pending packets as there will be no responses.
-    for (_, result) in calls_in_progress {
-        result.send(Err(RPCError::Closed)).ok();
+        // Send the serialized packet.
+        // TODO: Send without flushing?
+        match self.stream.send(serialized).await {
+            Ok(()) => {
+                self.calls_in_progress.insert(id, call.result);
+            }
+            Err(e) => {
+                // We potentially leak the ID here, but that is
+                // okay as the state of the call is not well
+                // defined anymore.
+                call.result.send(Err(RPCError::Stream(e))).ok();
+            }
+        }
     }
 
-    if connection_closed {
-        // The stream was closed, but the command channel was not. We still need
-        // to return errors for any packet sent by the client, so do not exit
-        // the loop immediately but rather return an error for any incoming
-        // packet.
-        client_connection_closed(cmd_receive).await;
-    } else {
-        // The command channel was closed, but the stream was not. Gracefully
-        // close the stream.
-        // TODO
+    async fn received_packet(&mut self, data: &[u8]) {
+        match rmps::decode::from_slice(data) as Result<Packet<PacketType>, _> {
+            Ok(deserialized) => {
+                if deserialized.id == SERVER_EVENT_ID {
+                    // We received an asynchronous event from the server.
+                    self.event_send.send(deserialized.payload).await.ok();
+                } else {
+                    // We received a response to a call - find the call and forward
+                    // the reply.
+                    if let Some(call) = self.calls_in_progress.remove(&deserialized.id) {
+                        self.free_ids.free(deserialized.id);
+                        call.send(Ok(deserialized.payload)).ok();
+                    } else {
+                        error!(
+                            "Received a reply for an unknown call ({:?})!",
+                            deserialized.id
+                        );
+                    }
+                }
+            }
+            Err(e) => {
+                // We received an invalid packet. We cannot determine for which
+                // call the packet was, so we just ignore it. The corresponding
+                // call will freeze until the connection is dropped.
+                error!("Invalid packet received: {:?}", e);
+            }
+        };
     }
 }
 