diff options
Diffstat (limited to '')
| -rw-r--r-- | src/protocol/frame/mod.rs | 405 | ||||
| -rw-r--r-- | src/protocol/message/handshake/types.rs | 23 | ||||
| -rw-r--r-- | src/protocol/mod.rs | 6 |
3 files changed, 422 insertions, 12 deletions
diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs new file mode 100644 index 0000000..b8296aa --- /dev/null +++ b/src/protocol/frame/mod.rs @@ -0,0 +1,405 @@ +use std::error::Error as StdError; +use std::io::{self, Cursor}; +use std::convert::TryInto; +use std::fmt; + +use bytes::{Buf, BufMut, BytesMut}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +use tokio_util::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite}; + +use flate2::Compress; +use flate2::Decompress; +use flate2::Compression; +use flate2::FlushCompress; +use flate2::FlushDecompress; + +#[derive(Debug, Clone, Copy)] +pub struct Builder { + // Maximum frame length + compression: bool, + compression_level: Compression, + + // Maximum frame length + max_frame_len: usize, +} + +// An error when the number of bytes read is more than max frame length. +pub struct QuasselCodecError { + _priv: (), +} + +#[derive(Debug)] +pub struct QuasselCodec { + builder: Builder, + state: DecodeState, + comp: Compress, + decomp: Decompress, +} + +#[derive(Debug, Clone, Copy)] +enum DecodeState { + Head, + Data(usize), +} + +impl QuasselCodec { + // Creates a new quassel codec with default values + pub fn new() -> Self { + Self { + builder: Builder::new(), + state: DecodeState::Head, + comp: Compress::new(Compression::default(), true), + decomp: Decompress::new(true), + } + } + + /// Creates a new quassel codec builder with default configuration + /// values. + pub fn builder() -> Builder { + Builder::new() + } + + pub fn compression(&self) -> bool { + self.builder.compression + } + + pub fn compression_level(&self) -> Compression { + self.builder.compression_level + } + + pub fn set_compression(&mut self, val: bool) { + self.builder.compression(val); + } + + pub fn set_compression_level(&mut self, val: Compression) { + self.builder.compression_level(val); + } + + fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> { +// let head_len = self.builder.num_head_bytes(); + let field_len = 4; + + if src.len() < field_len { + // Not enough data + return Ok(None); + } + + let n = { + let mut src = Cursor::new(&mut *src); + + let n = src.get_uint(field_len); + + if n > self.builder.max_frame_len as u64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + QuasselCodecError { _priv: () }, + )); + } + + // The check above ensures there is no overflow + let n = n as usize; + n + }; + + // Strip header + let _ = src.split_to(4); + + // Ensure that the buffer has enough space to read the incoming + // payload + src.reserve(n); + + Ok(Some(n)) + } + + fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> { + // At this point, the buffer has already had the required capacity + // reserved. All there is to do is read. + if src.len() < n { + return Ok(None); + } + + Ok(Some(src.split_to(n))) + } +} + + +impl Decoder for QuasselCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> { + let mut buf = vec![0; src.len() * 2]; + + println!("src: {:?}", &src[..]); + + if self.builder.compression == true { + let before_in = self.decomp.total_in(); + let before_out = self.decomp.total_out(); + self.decomp.decompress(&src, &mut buf, FlushDecompress::None)?; + let after_in = self.decomp.total_in(); + let after_out = self.decomp.total_out(); + + buf.truncate((after_out - before_out).try_into().unwrap()); + } else { + buf = src.to_vec(); + } + + let buf = &mut BytesMut::from(&buf[..]); + + let n = match self.state { + DecodeState::Head => match self.decode_head(buf)? { + Some(n) => { + self.state = DecodeState::Data(n); + n + } + None => return Ok(None), + }, + DecodeState::Data(n) => n, + }; + + match self.decode_data(n, buf)? { + Some(data) => { + // Update the decode state + self.state = DecodeState::Head; + + // Make sure the buffer has enough space to read the next head + buf.reserve(4); + + Ok(Some(data)) + } + None => Ok(None), + } + } +} + +impl Encoder for QuasselCodec { + type Item = Vec<u8>; + type Error = io::Error; + + fn encode(&mut self, data: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> { + let buf = &mut BytesMut::new(); + + let n = (&data).len(); + + if n > self.builder.max_frame_len { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + QuasselCodecError { _priv: () }, + )); + } + + // Reserve capacity in the destination buffer to fit the frame and + // length field (plus adjustment). + buf.reserve(4 + n); + + buf.put_uint(n as u64, 4); + + // Write the frame to the buffer + buf.extend_from_slice(&data[..]); + + if self.builder.compression { + let mut cbuf: Vec<u8> = vec![0; 4+n]; + let before_in = self.comp.total_in(); + let before_out = self.comp.total_out(); + self.comp.compress(buf, &mut cbuf, FlushCompress::Full)?; + let after_in = self.comp.total_in(); + let after_out = self.comp.total_out(); + + cbuf.truncate((after_out - before_out).try_into().unwrap()); + *dst = BytesMut::from(&cbuf[..]); + } else { + *dst = buf.clone(); + } + + Ok(()) + } +} + +impl Default for QuasselCodec { + fn default() -> Self { + Self::new() + } +} + + +// ===== impl Builder ===== + +impl Builder { + /// Creates a new length delimited codec builder with default configuration + /// values. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::QuasselCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// QuasselCodec::builder() + /// .length_field_offset(0) + /// .length_field_length(2) + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new() -> Builder { + Builder { + compression: false, + compression_level: Compression::default(), + max_frame_len: 64 * 1024 * 1024, + } + } + + pub fn compression(&mut self, val: bool) -> &mut Self { + self.compression = val; + self + } + + pub fn compression_level(&mut self, val: Compression) -> &mut Self { + self.compression_level = val; + self + } + + /// Sets the max frame length + /// + /// This configuration option applies to both encoding and decoding. The + /// default value is 8MB. + /// + /// When decoding, the length field read from the byte stream is checked + /// against this setting **before** any adjustments are applied. When + /// encoding, the length of the submitted payload is checked against this + /// setting. + /// + /// When frames exceed the max length, an `io::Error` with the custom value + /// of the `QuasselCodecError` type will be returned. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::QuasselCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// QuasselCodec::builder() + /// .max_frame_length(8 * 1024) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn max_frame_length(&mut self, val: usize) -> &mut Self { + self.max_frame_len = val; + self + } + + /// Create a configured length delimited `QuasselCodec` + /// + /// # Examples + /// + /// ``` + /// use tokio_util::codec::QuasselCodec; + /// # pub fn main() { + /// QuasselCodec::builder() + /// .length_field_offset(0) + /// .length_field_length(2) + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_codec(); + /// # } + /// ``` + pub fn new_codec(&self) -> QuasselCodec { + QuasselCodec { + builder: *self, + state: DecodeState::Head, + comp: Compress::new(self.compression_level, true), + decomp: Decompress::new(true), + } + } + + /// Create a configured length delimited `FramedRead` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::QuasselCodec; + /// + /// # fn bind_read<T: AsyncRead>(io: T) { + /// QuasselCodec::builder() + /// .length_field_offset(0) + /// .length_field_length(2) + /// .length_adjustment(0) + /// .num_skip(0) + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, QuasselCodec> + where + T: AsyncRead, + { + FramedRead::new(upstream, self.new_codec()) + } + + /// Create a configured length delimited `FramedWrite` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncWrite; + /// # use tokio_util::codec::QuasselCodec; + /// # fn write_frame<T: AsyncWrite>(io: T) { + /// QuasselCodec::builder() + /// .length_field_length(2) + /// .new_write(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, QuasselCodec> + where + T: AsyncWrite, + { + FramedWrite::new(inner, self.new_codec()) + } + + /// Create a configured length delimited `Framed` + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use tokio_util::codec::QuasselCodec; + /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) { + /// # let _ = + /// QuasselCodec::builder() + /// .length_field_length(2) + /// .new_framed(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn new_framed<T>(&self, inner: T) -> Framed<T, QuasselCodec> + where + T: AsyncRead + AsyncWrite, + { + Framed::new(inner, self.new_codec()) + } +} + +// ===== impl LengthDelimitedCodecError ===== + +impl fmt::Debug for QuasselCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("QuasselCodecError").finish() + } +} + +impl fmt::Display for QuasselCodecError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("frame size too big") + } +} + +impl StdError for QuasselCodecError {} diff --git a/src/protocol/message/handshake/types.rs b/src/protocol/message/handshake/types.rs index 643b376..0c70914 100644 --- a/src/protocol/message/handshake/types.rs +++ b/src/protocol/message/handshake/types.rs @@ -37,10 +37,8 @@ impl HandshakeSerialize for VariantMap { res.extend(v.serialize()?); } - util::insert_bytes(0, &mut res, &mut [0, 0, 0, 10]); - - let len: i32 = res.len().try_into().unwrap(); - util::insert_bytes(0, &mut res, &mut ((len).to_be_bytes())); + let len: i32 = (self.len() * 2).try_into().unwrap(); + util::insert_bytes(0, &mut res, &mut (len).to_be_bytes()); return Ok(res); } @@ -50,11 +48,10 @@ impl HandshakeDeserialize for VariantMap { fn parse(b: &[u8]) -> Result<(usize, Self), Error> { let (_, len) = i32::parse(&b[0..4])?; - let mut pos: usize = 8; + let mut pos: usize = 4; let mut map = VariantMap::new(); - let ulen: usize = len as usize; - loop { - if (pos) >= ulen { break; } + + for _ in 0..(len / 2) { let (nlen, name) = Variant::parse(&b[pos..])?; pos += nlen; @@ -76,10 +73,12 @@ impl HandshakeQRead for VariantMap { fn read<T: Read>(s: &mut T, b: &mut [u8]) -> Result<usize, Error> { s.read(&mut b[0..4])?; let (_, len) = i32::parse(&b[0..4])?; - let ulen = len as usize; - // Read the 00 00 00 0a VariantType bytes and discard - s.read(&mut b[4..(ulen + 4)])?; + let mut pos = 4; + for _ in 0..(len / 2) { + pos += Variant::read(s, &mut b[pos..])?; + pos += Variant::read(s, &mut b[pos..])?; + } // let mut pos = 8; // let len: usize = len as usize; @@ -89,6 +88,6 @@ impl HandshakeQRead for VariantMap { // pos += Variant::read(s, &mut b[pos..])?; // } - return Ok(ulen + 4); + return Ok(pos); } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8739fd8..3630fab 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,3 +1,9 @@ pub mod message; pub mod primitive; + +#[allow(dead_code)] pub mod error; + +#[allow(unused_variables, dead_code)] +#[cfg(feature = "framing")] +pub mod frame; |
