aboutsummaryrefslogtreecommitdiff
path: root/src/protocol/message
diff options
context:
space:
mode:
Diffstat (limited to 'src/protocol/message')
-rw-r--r--src/protocol/message/handshake.rs94
-rw-r--r--src/protocol/message/handshake/types.rs44
2 files changed, 89 insertions, 49 deletions
diff --git a/src/protocol/message/handshake.rs b/src/protocol/message/handshake.rs
index b346d8e..7713ed9 100644
--- a/src/protocol/message/handshake.rs
+++ b/src/protocol/message/handshake.rs
@@ -1,5 +1,7 @@
+use std::result::Result;
+
+use crate::protocol::error::ErrorKind;
use crate::protocol::primitive::{String, StringList, Variant, VariantList};
-use crate::protocol::primitive::{serialize, deserialize, qread};
mod types;
pub use types::{VariantMap, HandshakeDeserialize, HandshakeSerialize, HandshakeQRead};
@@ -14,7 +16,7 @@ pub struct ClientInit {
}
impl HandshakeSerialize for ClientInit {
- fn serialize(&self) -> Vec<u8> {
+ fn serialize(&self) -> Result<Vec<u8>, ErrorKind> {
let mut values: VariantMap = VariantMap::with_capacity(5);
values.insert("MsgType".to_string(), Variant::String("ClientInit".to_string()));
values.insert("ClientVersion".to_string(), Variant::String(self.client_version.clone()));
@@ -26,15 +28,27 @@ impl HandshakeSerialize for ClientInit {
}
impl HandshakeDeserialize for ClientInit {
- fn parse(b: &[u8]) -> (usize, Self) {
- let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b);
-
- return (len, Self {
- client_version: match_variant!(values, Variant::String, "ClientVersion"),
- client_date: match_variant!(values, Variant::String, "ClientDate"),
- feature_list: match_variant!(values, Variant::StringList, "FeatureList"),
- client_features: match_variant!(values, Variant::u32, "Features")
- });
+ fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> {
+ let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?;
+
+ let msgtypev = &values["MsgType"];
+ let msgtype;
+ match msgtypev {
+ Variant::String(x) => msgtype = x,
+ Variant::StringUTF8(x) => msgtype = x,
+ _ => return Err(ErrorKind::WrongVariant)
+ };
+
+ if msgtype == "ClientInit" {
+ return Ok((len, Self {
+ client_version: match_variant!(values, Variant::String, "ClientVersion"),
+ client_date: match_variant!(values, Variant::String, "ClientDate"),
+ feature_list: match_variant!(values, Variant::StringList, "FeatureList"),
+ client_features: match_variant!(values, Variant::u32, "Features")
+ }));
+ } else {
+ return Err(ErrorKind::WrongMsgType);
+ }
}
}
@@ -44,7 +58,7 @@ pub struct ClientInitReject {
}
impl HandshakeSerialize for ClientInitReject {
- fn serialize(&self) -> Vec<u8> {
+ fn serialize(&self) -> Result<Vec<u8>, ErrorKind> {
let mut values: VariantMap = VariantMap::with_capacity(2);
values.insert("MsgType".to_string(), Variant::String("ClientInitReject".to_string()));
values.insert("ErrorString".to_string(), Variant::String(self.error_string.clone()));
@@ -53,12 +67,24 @@ impl HandshakeSerialize for ClientInitReject {
}
impl HandshakeDeserialize for ClientInitReject {
- fn parse(b: &[u8]) -> (usize, Self) {
- let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b);
+ fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> {
+ let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?;
- return (len, Self {
- error_string: match_variant!(values, Variant::String, "ErrorString")
- });
+ let msgtypev = &values["MsgType"];
+ let msgtype;
+ match msgtypev {
+ Variant::String(x) => msgtype = x,
+ Variant::StringUTF8(x) => msgtype = x,
+ _ => return Err(ErrorKind::WrongVariant)
+ };
+
+ if msgtype == "ClientInitReject" {
+ return Ok((len, Self {
+ error_string: match_variant!(values, Variant::String, "ErrorString")
+ }));
+ } else {
+ return Err(ErrorKind::WrongMsgType);
+ }
}
}
@@ -72,7 +98,7 @@ pub struct ClientInitAck {
}
impl HandshakeSerialize for ClientInitAck {
- fn serialize(&self) -> Vec<u8> {
+ fn serialize(&self) -> Result<Vec<u8>, ErrorKind> {
let mut values: VariantMap = VariantMap::with_capacity(2);
values.insert("MsgType".to_string(), Variant::String("ClientInitAck".to_string()));
values.insert("CoreFeatures".to_string(), Variant::u32(self.core_features));
@@ -85,15 +111,27 @@ impl HandshakeSerialize for ClientInitAck {
}
impl HandshakeDeserialize for ClientInitAck {
- fn parse(b: &[u8]) -> (usize, Self) {
- let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b);
-
- return (len, Self {
- core_features: 0x00008000,
- core_configured: match_variant!(values, Variant::bool, "Configured"),
- storage_backends: match_variant!(values, Variant::VariantList, "StorageBackends"),
- authenticators: match_variant!(values, Variant::VariantList, "Authenticators"),
- feature_list: match_variant!(values, Variant::StringList, "FeatureList")
- });
+ fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> {
+ let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?;
+
+ let msgtypev = &values["MsgType"];
+ let msgtype;
+ match msgtypev {
+ Variant::String(x) => msgtype = x,
+ Variant::StringUTF8(x) => msgtype = x,
+ _ => return Err(ErrorKind::WrongVariant)
+ };
+
+ if msgtype == "ClientInitAck" {
+ return Ok((len, Self {
+ core_features: 0x00008000,
+ core_configured: match_variant!(values, Variant::bool, "Configured"),
+ storage_backends: match_variant!(values, Variant::VariantList, "StorageBackends"),
+ authenticators: match_variant!(values, Variant::VariantList, "Authenticators"),
+ feature_list: match_variant!(values, Variant::StringList, "FeatureList")
+ }));
+ } else {
+ return Err(ErrorKind::WrongMsgType);
+ }
}
}
diff --git a/src/protocol/message/handshake/types.rs b/src/protocol/message/handshake/types.rs
index dadd058..be290e9 100644
--- a/src/protocol/message/handshake/types.rs
+++ b/src/protocol/message/handshake/types.rs
@@ -1,5 +1,6 @@
use std::io::Read;
use std::vec::Vec;
+use std::result::Result;
use std::convert::TryInto;
use std::collections::HashMap;
@@ -8,82 +9,83 @@ use crate::protocol::primitive::{String, Variant};
use crate::protocol::primitive::serialize::Serialize;
use crate::protocol::primitive::deserialize::Deserialize;
use crate::protocol::primitive::qread::QRead;
+use crate::protocol::error::ErrorKind;
pub trait HandshakeSerialize {
- fn serialize(&self) -> Vec<u8>;
+ fn serialize(&self) -> Result<Vec<u8>, ErrorKind>;
}
pub trait HandshakeDeserialize {
- fn parse(b: &[u8]) -> (usize, Self);
+ fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> where Self: std::marker::Sized ;
}
pub trait HandshakeQRead {
- fn read<T: Read>(stream: &mut T, buf: &mut [u8]) -> usize;
+ fn read<T: Read>(stream: &mut T, buf: &mut [u8]) -> Result<usize, ErrorKind>;
}
pub type VariantMap = HashMap<String, Variant>;
impl HandshakeSerialize for VariantMap {
- fn serialize<'a>(&'a self) -> Vec<u8> {
+ fn serialize<'a>(&'a self) -> Result<Vec<u8>, ErrorKind> {
let mut res: Vec<u8> = Vec::new();
for (k, v) in self {
let key = Variant::String(k.clone());
- res.extend(key.serialize());
- res.extend(v.serialize());
+ res.extend(key.serialize()?);
+ res.extend(v.serialize()?);
}
util::insert_bytes(0, &mut res, &mut [0, 0, 0, 10]);
- let len: i32 = res.len().try_into().unwrap();
+ let len: i32 = res.len().try_into()?;
util::insert_bytes(0, &mut res, &mut ((len).to_be_bytes()));
- return res;
+ return Ok(res);
}
}
impl HandshakeDeserialize for VariantMap {
- fn parse(b: &[u8]) -> (usize, Self) {
- let (_, len) = i32::parse(&b[0..4]);
+ fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> {
+ let (_, len) = i32::parse(&b[0..4])?;
let mut pos: usize = 8;
let mut map = VariantMap::new();
let ulen: usize = len as usize;
loop {
if (pos) >= ulen { break; }
- let (nlen, name) = Variant::parse(&b[pos..]);
+ let (nlen, name) = Variant::parse(&b[pos..])?;
pos += nlen;
- let (vlen, value) = Variant::parse(&b[pos..]);
+ let (vlen, value) = Variant::parse(&b[pos..])?;
pos += vlen;
match name {
Variant::String(x) => map.insert(x, value),
Variant::StringUTF8(x) => map.insert(x, value),
- _ => panic!()
+ _ => return Err(ErrorKind::WrongVariant)
};
}
- return (pos, map);
+ return Ok((pos, map));
}
}
impl HandshakeQRead for VariantMap {
- fn read<T: Read>(s: &mut T, b: &mut [u8]) -> usize {
- s.read(&mut b[0..4]).unwrap();
- let (_, len) = i32::parse(&b[0..4]);
+ fn read<T: Read>(s: &mut T, b: &mut [u8]) -> Result<usize, ErrorKind> {
+ s.read(&mut b[0..4])?;
+ let (_, len) = i32::parse(&b[0..4])?;
// Read the 00 00 00 0a VariantType bytes and discard
- s.read(&mut b[4..8]).unwrap();
+ s.read(&mut b[4..8])?;
let mut pos = 8;
let len: usize = len as usize;
loop {
if pos >= (len - 4) { break; }
- pos += Variant::read(s, &mut b[pos..]);
- pos += Variant::read(s, &mut b[pos..]);
+ pos += Variant::read(s, &mut b[pos..])?;
+ pos += Variant::read(s, &mut b[pos..])?;
}
- return pos;
+ return Ok(pos);
}
}