浏览代码

Add first parts of the server side of the RPC protocol.

Mathias Gottschlag 4 年前
父节点
当前提交
74009e8cab
共有 8 个文件被更改,包括 283 次插入35 次删除
  1. 3
    4
      src/bin/rpc_integration_test.rs
  2. 14
    0
      src/bin/server.rs
  3. 0
    11
      src/network/client.rs
  4. 64
    0
      src/network/integration_test.rs
  5. 1
    0
      src/network/mod.rs
  6. 179
    18
      src/network/server.rs
  7. 1
    2
      src/network/websocket.rs
  8. 21
    0
      tests/rpc.rs

src/bin/low_level_integration_test.rs → src/bin/rpc_integration_test.rs 查看文件

@@ -1,8 +1,7 @@
1 1
 use structopt::StructOpt;
2 2
 use tokio::signal;
3 3
 
4
-use twfss::network::client::low_level_integration_test_client;
5
-use twfss::network::server::LowLevelIntegrationTestServer;
4
+use twfss::network::integration_test::{rpc_integration_test_client, RpcIntegrationTestServer};
6 5
 
7 6
 #[derive(Debug, StructOpt)]
8 7
 #[structopt(about = "Integration tester for the low-level packet protocol.")]
@@ -15,7 +14,7 @@ enum Options {
15 14
 async fn main() {
16 15
     match Options::from_args() {
17 16
         Options::Server { bind_address } => {
18
-            let server = LowLevelIntegrationTestServer::start(&bind_address).await;
17
+            let server = RpcIntegrationTestServer::start(&bind_address).await;
19 18
             println!("Listening at {}...", bind_address);
20 19
             signal::ctrl_c().await.unwrap();
21 20
             println!("Ctrl+C received, stopping...");
@@ -23,7 +22,7 @@ async fn main() {
23 22
         }
24 23
         Options::Client { connect_address } => {
25 24
             println!("Connecting to {}...", connect_address);
26
-            low_level_integration_test_client(connect_address).await;
25
+            rpc_integration_test_client(&connect_address).await;
27 26
         }
28 27
     }
29 28
 }

+ 14
- 0
src/bin/server.rs 查看文件

@@ -9,6 +9,8 @@ use log::*;
9 9
 use tokio::net::{TcpListener, TcpStream};
10 10
 use tokio::stream::StreamExt;
11 11
 
12
+use twfss::network::server::{RPCResponse, RPCServer};
13
+use twfss::protocol::PacketPayload;
12 14
 use twfss::{paths, Database};
13 15
 
14 16
 #[tokio::main]
@@ -63,3 +65,15 @@ async fn accept_connection(peer: SocketAddr, stream: TcpStream) {
63 65
         }
64 66
     }
65 67
 }
68
+
69
+struct ServerConnection {
70
+    // TODO
71
+}
72
+
73
+impl RPCServer for ServerConnection {
74
+    type PacketType = PacketPayload;
75
+
76
+    fn incoming_call(&mut self, call: Self::PacketType, response: RPCResponse<Self::PacketType>) {
77
+        // TODO
78
+    }
79
+}

+ 0
- 11
src/network/client.rs 查看文件

@@ -435,14 +435,3 @@ mod tests {
435 435
     // - Test whether the client correctly closes the connection and whether futures return an error
436 436
     // in this case.
437 437
 }
438
-
439
-pub async fn low_level_integration_test_client(_address: String) {
440
-    /*use super::websocket::WebSocketConnection;
441
-
442
-    let url = Url::parse(&address).unwrap();
443
-    let (mut ws_stream, _) = connect_async(url).await.unwrap();
444
-    let mut conn = WebSocketConnection::new(ws_stream);
445
-    let (rpc_interface, server_events) = create::<_, _, u32>(conn);*/
446
-    // TODO
447
-    panic!("Not yet implemented.");
448
-}

+ 64
- 0
src/network/integration_test.rs 查看文件

@@ -0,0 +1,64 @@
1
+use std::net::SocketAddr;
2
+
3
+use async_tungstenite::accept_async;
4
+use async_tungstenite::tokio::TokioAdapter;
5
+use futures::SinkExt;
6
+use log::*;
7
+use tokio::net::{TcpListener, TcpStream};
8
+use tokio::stream::StreamExt;
9
+
10
+pub async fn rpc_integration_test_client(_address: &str) {
11
+    /*use super::websocket::WebSocketConnection;
12
+
13
+    let url = Url::parse(&address).unwrap();
14
+    let (mut ws_stream, _) = connect_async(url).await.unwrap();
15
+    let mut conn = WebSocketConnection::new(ws_stream);
16
+    let (rpc_interface, server_events) = create::<_, _, u32>(conn);*/
17
+    // TODO
18
+    panic!("Not yet implemented.");
19
+}
20
+
21
+pub struct RpcIntegrationTestServer {
22
+    // TODO
23
+}
24
+
25
+impl RpcIntegrationTestServer {
26
+    pub async fn start(_bind_address: &str) -> RpcIntegrationTestServer {
27
+        // TODO
28
+        let addr = "127.0.0.1:12345";
29
+        let mut listener = TcpListener::bind(&addr)
30
+            .await
31
+            .expect("cannot bind to server port");
32
+        let mut incoming = listener.incoming();
33
+        info!("Listening on {}.", addr);
34
+
35
+        while let Some(Ok(stream)) = incoming.next().await {
36
+            let peer = stream
37
+                .peer_addr()
38
+                .expect("connected streams should have a peer address");
39
+
40
+            tokio::spawn(Self::accept_connection(peer, stream));
41
+        }
42
+        RpcIntegrationTestServer {}
43
+    }
44
+
45
+    pub async fn stop(self) {
46
+        // TODO
47
+    }
48
+
49
+    async fn accept_connection(peer: SocketAddr, stream: TcpStream) {
50
+        let mut ws_stream = accept_async(TokioAdapter(stream))
51
+            .await
52
+            .expect("Failed to accept");
53
+
54
+        info!("New WebSocket connection from {}.", peer);
55
+
56
+        while let Some(msg) = ws_stream.next().await {
57
+            let msg = msg.expect("Failed to get request");
58
+            // TODO
59
+            if msg.is_text() || msg.is_binary() {
60
+                ws_stream.send(msg).await.expect("Failed to send response");
61
+            }
62
+        }
63
+    }
64
+}

+ 1
- 0
src/network/mod.rs 查看文件

@@ -29,6 +29,7 @@
29 29
 
30 30
 pub mod client;
31 31
 pub mod idalloc;
32
+pub mod integration_test;
32 33
 pub mod server;
33 34
 pub mod websocket;
34 35
 

+ 179
- 18
src/network/server.rs 查看文件

@@ -1,33 +1,194 @@
1
-// Trait that is implemented by any RPC server to process incoming calls.
2
-trait RPCServer {
1
+use std::fmt;
2
+
3
+use futures::sink::{Sink, SinkExt};
4
+use futures::stream::{Stream, StreamExt};
5
+use log::{error, info};
6
+use rmps::Serializer;
7
+use serde::de::DeserializeOwned;
8
+use serde::Serialize;
9
+use tokio::sync::mpsc;
10
+
11
+use super::{Packet, SERVER_EVENT_ID};
12
+
13
+/// Trait that is implemented by any RPC server to process incoming calls.
14
+pub trait RPCServer {
3 15
     type PacketType;
4
-    type NetworkError;
5 16
 
6
-    fn incoming_call<Response>(&mut self, call: Self::PacketType, response: Response)
7
-    where
8
-        Response: RPCResponse<PacketType = Self::PacketType, NetworkError = Self::NetworkError>;
17
+    fn incoming_call(&mut self, call: Self::PacketType, response: RPCResponse<Self::PacketType>);
9 18
 }
10 19
 
11 20
 // TODO: Document that types implementing this trait should also implement Drop and should return
12 21
 // an error if the type is dropped without send() being called.
13
-trait RPCResponse {
14
-    type PacketType;
15
-    type NetworkError;
22
+pub struct RPCResponse<PacketType> {
23
+    send_packet: mpsc::UnboundedSender<Packet<PacketType>>,
24
+    id: u32,
25
+    error_value: Option<PacketType>,
26
+}
16 27
 
17
-    fn send(self, packet: &Self::PacketType) -> Result<(), Self::NetworkError>;
28
+impl<PacketType> RPCResponse<PacketType>
29
+where
30
+    PacketType: Serialize,
31
+{
32
+    pub fn send(self, packet: PacketType) {
33
+        // We ignore errors here, as they mean that the connection was just closed.
34
+        self.send_packet
35
+            .send(Packet {
36
+                payload: packet,
37
+                id: self.id,
38
+            })
39
+            .ok();
40
+    }
18 41
 }
19 42
 
20
-pub struct LowLevelIntegrationTestServer {
21
-    // TODO
43
+impl<PacketType> Drop for RPCResponse<PacketType> {
44
+    fn drop(&mut self) {
45
+        // Send an error response?
46
+        self.send_packet
47
+            .send(Packet {
48
+                payload: self.error_value.take().unwrap(),
49
+                id: self.id,
50
+            })
51
+            .ok();
52
+    }
22 53
 }
23 54
 
24
-impl LowLevelIntegrationTestServer {
25
-    pub async fn start(_bind_address: &str) -> LowLevelIntegrationTestServer {
26
-        // TODO
27
-        LowLevelIntegrationTestServer {}
55
+/// Wraps an `RPCServer` and adds the network protocol implementation around it.
56
+pub struct RPCServerWrapper<Server: RPCServer, Connection> {
57
+    server: Server,
58
+    connection: Connection,
59
+    server_events: ServerEventStream<Server::PacketType>,
60
+    send_packet: mpsc::UnboundedSender<Packet<Server::PacketType>>,
61
+    sent_packets: mpsc::UnboundedReceiver<Packet<Server::PacketType>>,
62
+    error_value: Server::PacketType,
63
+}
64
+
65
+impl<Server: RPCServer, Connection, ConnectionError> RPCServerWrapper<Server, Connection>
66
+where
67
+    Server::PacketType: DeserializeOwned + Serialize + Clone + Send + 'static,
68
+    Connection: Stream<Item = Vec<u8>> + Sink<Vec<u8>, Error = ConnectionError> + Unpin,
69
+    ConnectionError: fmt::Debug,
70
+{
71
+    pub fn new(
72
+        server: Server,
73
+        connection: Connection,
74
+        server_events: ServerEventStream<Server::PacketType>,
75
+        error_value: Server::PacketType,
76
+    ) -> Self {
77
+        let (send_packet, sent_packets) = mpsc::unbounded_channel();
78
+        Self {
79
+            server,
80
+            connection,
81
+            server_events,
82
+            send_packet,
83
+            sent_packets,
84
+            error_value,
85
+        }
28 86
     }
29 87
 
30
-    pub async fn stop(self) {
31
-        // TODO
88
+    pub async fn run(&mut self) {
89
+        loop {
90
+            tokio::select! {
91
+                packet = &mut self.connection.next() => {
92
+                    match packet {
93
+                        Some(data) => {
94
+                            if !self.packet_received(data) {
95
+                                // Disconnect on errors.
96
+                                break;
97
+                            }
98
+                        }
99
+                        None => {
100
+                            // Stream closed.
101
+                            println!("Disconnected.");
102
+                            break;
103
+                        }
104
+                    }
105
+                }
106
+                cmd = &mut self.sent_packets.next() => {
107
+                    match cmd {
108
+                        Some(packet) => {
109
+                            if !self.send_packet(packet).await {
110
+                                // Disconnect on errors.
111
+                                break;
112
+                            }
113
+                        }
114
+                        None => {
115
+                            // This should never happen.
116
+                            break;
117
+                        }
118
+                    }
119
+                }
120
+                cmd = &mut self.server_events.receive.next() => {
121
+                    match cmd {
122
+                        Some(packet) => {
123
+                            if !self.send_packet(Packet{payload: packet, id: SERVER_EVENT_ID}).await {
124
+                                // Disconnect on errors.
125
+                                break;
126
+                            }
127
+                        }
128
+                        None => {
129
+                            // This should never happen.
130
+                            break;
131
+                        }
132
+                    }
133
+                }
134
+            }
135
+        }
32 136
     }
137
+
138
+    fn packet_received(&mut self, packet: Vec<u8>) -> bool {
139
+        match rmps::decode::from_slice(&packet) as Result<Packet<Server::PacketType>, _> {
140
+            Ok(deserialized) => {
141
+                if deserialized.id == SERVER_EVENT_ID {
142
+                    // Clients must never send server events.
143
+                    info!("Client sent SERVER_EVENT_ID!");
144
+                    return false;
145
+                } else {
146
+                    let response = RPCResponse {
147
+                        send_packet: self.send_packet.clone(),
148
+                        id: deserialized.id,
149
+                        error_value: Some(self.error_value.clone()),
150
+                    };
151
+                    self.server.incoming_call(deserialized.payload, response);
152
+                }
153
+            }
154
+            Err(e) => {
155
+                error!("Invalid packet received: {:?}", e);
156
+                return false;
157
+            }
158
+        };
159
+        return true;
160
+    }
161
+
162
+    async fn send_packet(&mut self, packet: Packet<Server::PacketType>) -> bool {
163
+        let mut serialized = Vec::new();
164
+        packet
165
+            .serialize(
166
+                &mut Serializer::new(&mut serialized)
167
+                    .with_struct_map()
168
+                    .with_string_variants(),
169
+            )
170
+            .unwrap();
171
+        match self.connection.send(serialized).await {
172
+            Ok(()) => true,
173
+            Err(e) => {
174
+                info!("Could not send packet: {:?}", e);
175
+                false
176
+            }
177
+        }
178
+    }
179
+}
180
+
181
+pub fn server_events<PacketType>() -> (ServerEventSink<PacketType>, ServerEventStream<PacketType>) {
182
+    let (send, receive) = mpsc::unbounded_channel();
183
+    (ServerEventSink { send }, ServerEventStream { receive })
184
+}
185
+
186
+/// Stream of asynchronous packets sent to the clients.
187
+#[derive(Clone)]
188
+pub struct ServerEventSink<PacketType> {
189
+    send: mpsc::UnboundedSender<PacketType>,
190
+}
191
+
192
+pub struct ServerEventStream<PacketType> {
193
+    receive: mpsc::UnboundedReceiver<PacketType>,
33 194
 }

+ 1
- 2
src/network/websocket.rs 查看文件

@@ -5,8 +5,7 @@ use async_tungstenite::WebSocketStream;
5 5
 use futures::prelude::*;
6 6
 use futures::sink::Sink;
7 7
 use futures::stream::Stream;
8
-use futures::task::Context;
9
-use futures::task::Poll;
8
+use futures::task::{Context, Poll};
10 9
 use pin_project::pin_project;
11 10
 
12 11
 #[pin_project]

+ 21
- 0
tests/rpc.rs 查看文件

@@ -0,0 +1,21 @@
1
+use std::net::SocketAddr;
2
+
3
+use async_tungstenite::accept_async;
4
+use async_tungstenite::tokio::TokioAdapter;
5
+use futures::SinkExt;
6
+use log::*;
7
+use tokio::net::{TcpListener, TcpStream};
8
+use tokio::stream::StreamExt;
9
+
10
+use twfss::network::integration_test::{rpc_integration_test_client, RpcIntegrationTestServer};
11
+
12
+#[tokio::test]
13
+async fn test_rpc() {
14
+    /*const ADDRESS: &str = "localhost:12345";
15
+    let server = RpcIntegrationTestServer::start(ADDRESS).await;
16
+
17
+    // TODO: Error handling, usable asserts.
18
+    rpc_integration_test_client(&ADDRESS).await;
19
+
20
+    server.stop().await;*/
21
+}

正在加载...
取消
保存