diff options
Diffstat (limited to 'src/protocol/message')
| -rw-r--r-- | src/protocol/message/handshake.rs | 94 | ||||
| -rw-r--r-- | src/protocol/message/handshake/types.rs | 44 |
2 files changed, 89 insertions, 49 deletions
diff --git a/src/protocol/message/handshake.rs b/src/protocol/message/handshake.rs index b346d8e..7713ed9 100644 --- a/src/protocol/message/handshake.rs +++ b/src/protocol/message/handshake.rs @@ -1,5 +1,7 @@ +use std::result::Result; + +use crate::protocol::error::ErrorKind; use crate::protocol::primitive::{String, StringList, Variant, VariantList}; -use crate::protocol::primitive::{serialize, deserialize, qread}; mod types; pub use types::{VariantMap, HandshakeDeserialize, HandshakeSerialize, HandshakeQRead}; @@ -14,7 +16,7 @@ pub struct ClientInit { } impl HandshakeSerialize for ClientInit { - fn serialize(&self) -> Vec<u8> { + fn serialize(&self) -> Result<Vec<u8>, ErrorKind> { let mut values: VariantMap = VariantMap::with_capacity(5); values.insert("MsgType".to_string(), Variant::String("ClientInit".to_string())); values.insert("ClientVersion".to_string(), Variant::String(self.client_version.clone())); @@ -26,15 +28,27 @@ impl HandshakeSerialize for ClientInit { } impl HandshakeDeserialize for ClientInit { - fn parse(b: &[u8]) -> (usize, Self) { - let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b); - - return (len, Self { - client_version: match_variant!(values, Variant::String, "ClientVersion"), - client_date: match_variant!(values, Variant::String, "ClientDate"), - feature_list: match_variant!(values, Variant::StringList, "FeatureList"), - client_features: match_variant!(values, Variant::u32, "Features") - }); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> { + let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?; + + let msgtypev = &values["MsgType"]; + let msgtype; + match msgtypev { + Variant::String(x) => msgtype = x, + Variant::StringUTF8(x) => msgtype = x, + _ => return Err(ErrorKind::WrongVariant) + }; + + if msgtype == "ClientInit" { + return Ok((len, Self { + client_version: match_variant!(values, Variant::String, "ClientVersion"), + client_date: match_variant!(values, Variant::String, "ClientDate"), + feature_list: match_variant!(values, Variant::StringList, "FeatureList"), + client_features: match_variant!(values, Variant::u32, "Features") + })); + } else { + return Err(ErrorKind::WrongMsgType); + } } } @@ -44,7 +58,7 @@ pub struct ClientInitReject { } impl HandshakeSerialize for ClientInitReject { - fn serialize(&self) -> Vec<u8> { + fn serialize(&self) -> Result<Vec<u8>, ErrorKind> { let mut values: VariantMap = VariantMap::with_capacity(2); values.insert("MsgType".to_string(), Variant::String("ClientInitReject".to_string())); values.insert("ErrorString".to_string(), Variant::String(self.error_string.clone())); @@ -53,12 +67,24 @@ impl HandshakeSerialize for ClientInitReject { } impl HandshakeDeserialize for ClientInitReject { - fn parse(b: &[u8]) -> (usize, Self) { - let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> { + let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?; - return (len, Self { - error_string: match_variant!(values, Variant::String, "ErrorString") - }); + let msgtypev = &values["MsgType"]; + let msgtype; + match msgtypev { + Variant::String(x) => msgtype = x, + Variant::StringUTF8(x) => msgtype = x, + _ => return Err(ErrorKind::WrongVariant) + }; + + if msgtype == "ClientInitReject" { + return Ok((len, Self { + error_string: match_variant!(values, Variant::String, "ErrorString") + })); + } else { + return Err(ErrorKind::WrongMsgType); + } } } @@ -72,7 +98,7 @@ pub struct ClientInitAck { } impl HandshakeSerialize for ClientInitAck { - fn serialize(&self) -> Vec<u8> { + fn serialize(&self) -> Result<Vec<u8>, ErrorKind> { let mut values: VariantMap = VariantMap::with_capacity(2); values.insert("MsgType".to_string(), Variant::String("ClientInitAck".to_string())); values.insert("CoreFeatures".to_string(), Variant::u32(self.core_features)); @@ -85,15 +111,27 @@ impl HandshakeSerialize for ClientInitAck { } impl HandshakeDeserialize for ClientInitAck { - fn parse(b: &[u8]) -> (usize, Self) { - let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b); - - return (len, Self { - core_features: 0x00008000, - core_configured: match_variant!(values, Variant::bool, "Configured"), - storage_backends: match_variant!(values, Variant::VariantList, "StorageBackends"), - authenticators: match_variant!(values, Variant::VariantList, "Authenticators"), - feature_list: match_variant!(values, Variant::StringList, "FeatureList") - }); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> { + let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?; + + let msgtypev = &values["MsgType"]; + let msgtype; + match msgtypev { + Variant::String(x) => msgtype = x, + Variant::StringUTF8(x) => msgtype = x, + _ => return Err(ErrorKind::WrongVariant) + }; + + if msgtype == "ClientInitAck" { + return Ok((len, Self { + core_features: 0x00008000, + core_configured: match_variant!(values, Variant::bool, "Configured"), + storage_backends: match_variant!(values, Variant::VariantList, "StorageBackends"), + authenticators: match_variant!(values, Variant::VariantList, "Authenticators"), + feature_list: match_variant!(values, Variant::StringList, "FeatureList") + })); + } else { + return Err(ErrorKind::WrongMsgType); + } } } diff --git a/src/protocol/message/handshake/types.rs b/src/protocol/message/handshake/types.rs index dadd058..be290e9 100644 --- a/src/protocol/message/handshake/types.rs +++ b/src/protocol/message/handshake/types.rs @@ -1,5 +1,6 @@ use std::io::Read; use std::vec::Vec; +use std::result::Result; use std::convert::TryInto; use std::collections::HashMap; @@ -8,82 +9,83 @@ use crate::protocol::primitive::{String, Variant}; use crate::protocol::primitive::serialize::Serialize; use crate::protocol::primitive::deserialize::Deserialize; use crate::protocol::primitive::qread::QRead; +use crate::protocol::error::ErrorKind; pub trait HandshakeSerialize { - fn serialize(&self) -> Vec<u8>; + fn serialize(&self) -> Result<Vec<u8>, ErrorKind>; } pub trait HandshakeDeserialize { - fn parse(b: &[u8]) -> (usize, Self); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> where Self: std::marker::Sized ; } pub trait HandshakeQRead { - fn read<T: Read>(stream: &mut T, buf: &mut [u8]) -> usize; + fn read<T: Read>(stream: &mut T, buf: &mut [u8]) -> Result<usize, ErrorKind>; } pub type VariantMap = HashMap<String, Variant>; impl HandshakeSerialize for VariantMap { - fn serialize<'a>(&'a self) -> Vec<u8> { + fn serialize<'a>(&'a self) -> Result<Vec<u8>, ErrorKind> { let mut res: Vec<u8> = Vec::new(); for (k, v) in self { let key = Variant::String(k.clone()); - res.extend(key.serialize()); - res.extend(v.serialize()); + res.extend(key.serialize()?); + res.extend(v.serialize()?); } util::insert_bytes(0, &mut res, &mut [0, 0, 0, 10]); - let len: i32 = res.len().try_into().unwrap(); + let len: i32 = res.len().try_into()?; util::insert_bytes(0, &mut res, &mut ((len).to_be_bytes())); - return res; + return Ok(res); } } impl HandshakeDeserialize for VariantMap { - fn parse(b: &[u8]) -> (usize, Self) { - let (_, len) = i32::parse(&b[0..4]); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> { + let (_, len) = i32::parse(&b[0..4])?; let mut pos: usize = 8; let mut map = VariantMap::new(); let ulen: usize = len as usize; loop { if (pos) >= ulen { break; } - let (nlen, name) = Variant::parse(&b[pos..]); + let (nlen, name) = Variant::parse(&b[pos..])?; pos += nlen; - let (vlen, value) = Variant::parse(&b[pos..]); + let (vlen, value) = Variant::parse(&b[pos..])?; pos += vlen; match name { Variant::String(x) => map.insert(x, value), Variant::StringUTF8(x) => map.insert(x, value), - _ => panic!() + _ => return Err(ErrorKind::WrongVariant) }; } - return (pos, map); + return Ok((pos, map)); } } impl HandshakeQRead for VariantMap { - fn read<T: Read>(s: &mut T, b: &mut [u8]) -> usize { - s.read(&mut b[0..4]).unwrap(); - let (_, len) = i32::parse(&b[0..4]); + fn read<T: Read>(s: &mut T, b: &mut [u8]) -> Result<usize, ErrorKind> { + s.read(&mut b[0..4])?; + let (_, len) = i32::parse(&b[0..4])?; // Read the 00 00 00 0a VariantType bytes and discard - s.read(&mut b[4..8]).unwrap(); + s.read(&mut b[4..8])?; let mut pos = 8; let len: usize = len as usize; loop { if pos >= (len - 4) { break; } - pos += Variant::read(s, &mut b[pos..]); - pos += Variant::read(s, &mut b[pos..]); + pos += Variant::read(s, &mut b[pos..])?; + pos += Variant::read(s, &mut b[pos..])?; } - return pos; + return Ok(pos); } } |
