diff options
Diffstat (limited to 'src/protocol/message/handshake.rs')
| -rw-r--r-- | src/protocol/message/handshake.rs | 94 |
1 files changed, 66 insertions, 28 deletions
diff --git a/src/protocol/message/handshake.rs b/src/protocol/message/handshake.rs index b346d8e..7713ed9 100644 --- a/src/protocol/message/handshake.rs +++ b/src/protocol/message/handshake.rs @@ -1,5 +1,7 @@ +use std::result::Result; + +use crate::protocol::error::ErrorKind; use crate::protocol::primitive::{String, StringList, Variant, VariantList}; -use crate::protocol::primitive::{serialize, deserialize, qread}; mod types; pub use types::{VariantMap, HandshakeDeserialize, HandshakeSerialize, HandshakeQRead}; @@ -14,7 +16,7 @@ pub struct ClientInit { } impl HandshakeSerialize for ClientInit { - fn serialize(&self) -> Vec<u8> { + fn serialize(&self) -> Result<Vec<u8>, ErrorKind> { let mut values: VariantMap = VariantMap::with_capacity(5); values.insert("MsgType".to_string(), Variant::String("ClientInit".to_string())); values.insert("ClientVersion".to_string(), Variant::String(self.client_version.clone())); @@ -26,15 +28,27 @@ impl HandshakeSerialize for ClientInit { } impl HandshakeDeserialize for ClientInit { - fn parse(b: &[u8]) -> (usize, Self) { - let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b); - - return (len, Self { - client_version: match_variant!(values, Variant::String, "ClientVersion"), - client_date: match_variant!(values, Variant::String, "ClientDate"), - feature_list: match_variant!(values, Variant::StringList, "FeatureList"), - client_features: match_variant!(values, Variant::u32, "Features") - }); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> { + let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?; + + let msgtypev = &values["MsgType"]; + let msgtype; + match msgtypev { + Variant::String(x) => msgtype = x, + Variant::StringUTF8(x) => msgtype = x, + _ => return Err(ErrorKind::WrongVariant) + }; + + if msgtype == "ClientInit" { + return Ok((len, Self { + client_version: match_variant!(values, Variant::String, "ClientVersion"), + client_date: match_variant!(values, Variant::String, "ClientDate"), + feature_list: match_variant!(values, Variant::StringList, "FeatureList"), + client_features: match_variant!(values, Variant::u32, "Features") + })); + } else { + return Err(ErrorKind::WrongMsgType); + } } } @@ -44,7 +58,7 @@ pub struct ClientInitReject { } impl HandshakeSerialize for ClientInitReject { - fn serialize(&self) -> Vec<u8> { + fn serialize(&self) -> Result<Vec<u8>, ErrorKind> { let mut values: VariantMap = VariantMap::with_capacity(2); values.insert("MsgType".to_string(), Variant::String("ClientInitReject".to_string())); values.insert("ErrorString".to_string(), Variant::String(self.error_string.clone())); @@ -53,12 +67,24 @@ impl HandshakeSerialize for ClientInitReject { } impl HandshakeDeserialize for ClientInitReject { - fn parse(b: &[u8]) -> (usize, Self) { - let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> { + let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?; - return (len, Self { - error_string: match_variant!(values, Variant::String, "ErrorString") - }); + let msgtypev = &values["MsgType"]; + let msgtype; + match msgtypev { + Variant::String(x) => msgtype = x, + Variant::StringUTF8(x) => msgtype = x, + _ => return Err(ErrorKind::WrongVariant) + }; + + if msgtype == "ClientInitReject" { + return Ok((len, Self { + error_string: match_variant!(values, Variant::String, "ErrorString") + })); + } else { + return Err(ErrorKind::WrongMsgType); + } } } @@ -72,7 +98,7 @@ pub struct ClientInitAck { } impl HandshakeSerialize for ClientInitAck { - fn serialize(&self) -> Vec<u8> { + fn serialize(&self) -> Result<Vec<u8>, ErrorKind> { let mut values: VariantMap = VariantMap::with_capacity(2); values.insert("MsgType".to_string(), Variant::String("ClientInitAck".to_string())); values.insert("CoreFeatures".to_string(), Variant::u32(self.core_features)); @@ -85,15 +111,27 @@ impl HandshakeSerialize for ClientInitAck { } impl HandshakeDeserialize for ClientInitAck { - fn parse(b: &[u8]) -> (usize, Self) { - let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b); - - return (len, Self { - core_features: 0x00008000, - core_configured: match_variant!(values, Variant::bool, "Configured"), - storage_backends: match_variant!(values, Variant::VariantList, "StorageBackends"), - authenticators: match_variant!(values, Variant::VariantList, "Authenticators"), - feature_list: match_variant!(values, Variant::StringList, "FeatureList") - }); + fn parse(b: &[u8]) -> Result<(usize, Self), ErrorKind> { + let (len, values): (usize, VariantMap) = HandshakeDeserialize::parse(b)?; + + let msgtypev = &values["MsgType"]; + let msgtype; + match msgtypev { + Variant::String(x) => msgtype = x, + Variant::StringUTF8(x) => msgtype = x, + _ => return Err(ErrorKind::WrongVariant) + }; + + if msgtype == "ClientInitAck" { + return Ok((len, Self { + core_features: 0x00008000, + core_configured: match_variant!(values, Variant::bool, "Configured"), + storage_backends: match_variant!(values, Variant::VariantList, "StorageBackends"), + authenticators: match_variant!(values, Variant::VariantList, "Authenticators"), + feature_list: match_variant!(values, Variant::StringList, "FeatureList") + })); + } else { + return Err(ErrorKind::WrongMsgType); + } } } |
