aboutsummaryrefslogtreecommitdiff
path: root/src/protocol
diff options
context:
space:
mode:
Diffstat (limited to 'src/protocol')
-rw-r--r--src/protocol/frame/mod.rs405
-rw-r--r--src/protocol/message/handshake/types.rs23
-rw-r--r--src/protocol/mod.rs6
3 files changed, 422 insertions, 12 deletions
diff --git a/src/protocol/frame/mod.rs b/src/protocol/frame/mod.rs
new file mode 100644
index 0000000..b8296aa
--- /dev/null
+++ b/src/protocol/frame/mod.rs
@@ -0,0 +1,405 @@
+use std::error::Error as StdError;
+use std::io::{self, Cursor};
+use std::convert::TryInto;
+use std::fmt;
+
+use bytes::{Buf, BufMut, BytesMut};
+
+use tokio::io::{AsyncRead, AsyncWrite};
+
+use tokio_util::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite};
+
+use flate2::Compress;
+use flate2::Decompress;
+use flate2::Compression;
+use flate2::FlushCompress;
+use flate2::FlushDecompress;
+
+#[derive(Debug, Clone, Copy)]
+pub struct Builder {
+ // Maximum frame length
+ compression: bool,
+ compression_level: Compression,
+
+ // Maximum frame length
+ max_frame_len: usize,
+}
+
+// An error when the number of bytes read is more than max frame length.
+pub struct QuasselCodecError {
+ _priv: (),
+}
+
+#[derive(Debug)]
+pub struct QuasselCodec {
+ builder: Builder,
+ state: DecodeState,
+ comp: Compress,
+ decomp: Decompress,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum DecodeState {
+ Head,
+ Data(usize),
+}
+
+impl QuasselCodec {
+ // Creates a new quassel codec with default values
+ pub fn new() -> Self {
+ Self {
+ builder: Builder::new(),
+ state: DecodeState::Head,
+ comp: Compress::new(Compression::default(), true),
+ decomp: Decompress::new(true),
+ }
+ }
+
+ /// Creates a new quassel codec builder with default configuration
+ /// values.
+ pub fn builder() -> Builder {
+ Builder::new()
+ }
+
+ pub fn compression(&self) -> bool {
+ self.builder.compression
+ }
+
+ pub fn compression_level(&self) -> Compression {
+ self.builder.compression_level
+ }
+
+ pub fn set_compression(&mut self, val: bool) {
+ self.builder.compression(val);
+ }
+
+ pub fn set_compression_level(&mut self, val: Compression) {
+ self.builder.compression_level(val);
+ }
+
+ fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
+// let head_len = self.builder.num_head_bytes();
+ let field_len = 4;
+
+ if src.len() < field_len {
+ // Not enough data
+ return Ok(None);
+ }
+
+ let n = {
+ let mut src = Cursor::new(&mut *src);
+
+ let n = src.get_uint(field_len);
+
+ if n > self.builder.max_frame_len as u64 {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ QuasselCodecError { _priv: () },
+ ));
+ }
+
+ // The check above ensures there is no overflow
+ let n = n as usize;
+ n
+ };
+
+ // Strip header
+ let _ = src.split_to(4);
+
+ // Ensure that the buffer has enough space to read the incoming
+ // payload
+ src.reserve(n);
+
+ Ok(Some(n))
+ }
+
+ fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
+ // At this point, the buffer has already had the required capacity
+ // reserved. All there is to do is read.
+ if src.len() < n {
+ return Ok(None);
+ }
+
+ Ok(Some(src.split_to(n)))
+ }
+}
+
+
+impl Decoder for QuasselCodec {
+ type Item = BytesMut;
+ type Error = io::Error;
+
+ fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
+ let mut buf = vec![0; src.len() * 2];
+
+ println!("src: {:?}", &src[..]);
+
+ if self.builder.compression == true {
+ let before_in = self.decomp.total_in();
+ let before_out = self.decomp.total_out();
+ self.decomp.decompress(&src, &mut buf, FlushDecompress::None)?;
+ let after_in = self.decomp.total_in();
+ let after_out = self.decomp.total_out();
+
+ buf.truncate((after_out - before_out).try_into().unwrap());
+ } else {
+ buf = src.to_vec();
+ }
+
+ let buf = &mut BytesMut::from(&buf[..]);
+
+ let n = match self.state {
+ DecodeState::Head => match self.decode_head(buf)? {
+ Some(n) => {
+ self.state = DecodeState::Data(n);
+ n
+ }
+ None => return Ok(None),
+ },
+ DecodeState::Data(n) => n,
+ };
+
+ match self.decode_data(n, buf)? {
+ Some(data) => {
+ // Update the decode state
+ self.state = DecodeState::Head;
+
+ // Make sure the buffer has enough space to read the next head
+ buf.reserve(4);
+
+ Ok(Some(data))
+ }
+ None => Ok(None),
+ }
+ }
+}
+
+impl Encoder for QuasselCodec {
+ type Item = Vec<u8>;
+ type Error = io::Error;
+
+ fn encode(&mut self, data: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
+ let buf = &mut BytesMut::new();
+
+ let n = (&data).len();
+
+ if n > self.builder.max_frame_len {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ QuasselCodecError { _priv: () },
+ ));
+ }
+
+ // Reserve capacity in the destination buffer to fit the frame and
+ // length field (plus adjustment).
+ buf.reserve(4 + n);
+
+ buf.put_uint(n as u64, 4);
+
+ // Write the frame to the buffer
+ buf.extend_from_slice(&data[..]);
+
+ if self.builder.compression {
+ let mut cbuf: Vec<u8> = vec![0; 4+n];
+ let before_in = self.comp.total_in();
+ let before_out = self.comp.total_out();
+ self.comp.compress(buf, &mut cbuf, FlushCompress::Full)?;
+ let after_in = self.comp.total_in();
+ let after_out = self.comp.total_out();
+
+ cbuf.truncate((after_out - before_out).try_into().unwrap());
+ *dst = BytesMut::from(&cbuf[..]);
+ } else {
+ *dst = buf.clone();
+ }
+
+ Ok(())
+ }
+}
+
+impl Default for QuasselCodec {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+
+// ===== impl Builder =====
+
+impl Builder {
+ /// Creates a new length delimited codec builder with default configuration
+ /// values.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::QuasselCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// QuasselCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_length(2)
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new() -> Builder {
+ Builder {
+ compression: false,
+ compression_level: Compression::default(),
+ max_frame_len: 64 * 1024 * 1024,
+ }
+ }
+
+ pub fn compression(&mut self, val: bool) -> &mut Self {
+ self.compression = val;
+ self
+ }
+
+ pub fn compression_level(&mut self, val: Compression) -> &mut Self {
+ self.compression_level = val;
+ self
+ }
+
+ /// Sets the max frame length
+ ///
+ /// This configuration option applies to both encoding and decoding. The
+ /// default value is 8MB.
+ ///
+ /// When decoding, the length field read from the byte stream is checked
+ /// against this setting **before** any adjustments are applied. When
+ /// encoding, the length of the submitted payload is checked against this
+ /// setting.
+ ///
+ /// When frames exceed the max length, an `io::Error` with the custom value
+ /// of the `QuasselCodecError` type will be returned.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::QuasselCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// QuasselCodec::builder()
+ /// .max_frame_length(8 * 1024)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn max_frame_length(&mut self, val: usize) -> &mut Self {
+ self.max_frame_len = val;
+ self
+ }
+
+ /// Create a configured length delimited `QuasselCodec`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use tokio_util::codec::QuasselCodec;
+ /// # pub fn main() {
+ /// QuasselCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_length(2)
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_codec();
+ /// # }
+ /// ```
+ pub fn new_codec(&self) -> QuasselCodec {
+ QuasselCodec {
+ builder: *self,
+ state: DecodeState::Head,
+ comp: Compress::new(self.compression_level, true),
+ decomp: Decompress::new(true),
+ }
+ }
+
+ /// Create a configured length delimited `FramedRead`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncRead;
+ /// use tokio_util::codec::QuasselCodec;
+ ///
+ /// # fn bind_read<T: AsyncRead>(io: T) {
+ /// QuasselCodec::builder()
+ /// .length_field_offset(0)
+ /// .length_field_length(2)
+ /// .length_adjustment(0)
+ /// .num_skip(0)
+ /// .new_read(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_read<T>(&self, upstream: T) -> FramedRead<T, QuasselCodec>
+ where
+ T: AsyncRead,
+ {
+ FramedRead::new(upstream, self.new_codec())
+ }
+
+ /// Create a configured length delimited `FramedWrite`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::AsyncWrite;
+ /// # use tokio_util::codec::QuasselCodec;
+ /// # fn write_frame<T: AsyncWrite>(io: T) {
+ /// QuasselCodec::builder()
+ /// .length_field_length(2)
+ /// .new_write(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_write<T>(&self, inner: T) -> FramedWrite<T, QuasselCodec>
+ where
+ T: AsyncWrite,
+ {
+ FramedWrite::new(inner, self.new_codec())
+ }
+
+ /// Create a configured length delimited `Framed`
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use tokio::io::{AsyncRead, AsyncWrite};
+ /// # use tokio_util::codec::QuasselCodec;
+ /// # fn write_frame<T: AsyncRead + AsyncWrite>(io: T) {
+ /// # let _ =
+ /// QuasselCodec::builder()
+ /// .length_field_length(2)
+ /// .new_framed(io);
+ /// # }
+ /// # pub fn main() {}
+ /// ```
+ pub fn new_framed<T>(&self, inner: T) -> Framed<T, QuasselCodec>
+ where
+ T: AsyncRead + AsyncWrite,
+ {
+ Framed::new(inner, self.new_codec())
+ }
+}
+
+// ===== impl LengthDelimitedCodecError =====
+
+impl fmt::Debug for QuasselCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("QuasselCodecError").finish()
+ }
+}
+
+impl fmt::Display for QuasselCodecError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("frame size too big")
+ }
+}
+
+impl StdError for QuasselCodecError {}
diff --git a/src/protocol/message/handshake/types.rs b/src/protocol/message/handshake/types.rs
index 643b376..0c70914 100644
--- a/src/protocol/message/handshake/types.rs
+++ b/src/protocol/message/handshake/types.rs
@@ -37,10 +37,8 @@ impl HandshakeSerialize for VariantMap {
res.extend(v.serialize()?);
}
- util::insert_bytes(0, &mut res, &mut [0, 0, 0, 10]);
-
- let len: i32 = res.len().try_into().unwrap();
- util::insert_bytes(0, &mut res, &mut ((len).to_be_bytes()));
+ let len: i32 = (self.len() * 2).try_into().unwrap();
+ util::insert_bytes(0, &mut res, &mut (len).to_be_bytes());
return Ok(res);
}
@@ -50,11 +48,10 @@ impl HandshakeDeserialize for VariantMap {
fn parse(b: &[u8]) -> Result<(usize, Self), Error> {
let (_, len) = i32::parse(&b[0..4])?;
- let mut pos: usize = 8;
+ let mut pos: usize = 4;
let mut map = VariantMap::new();
- let ulen: usize = len as usize;
- loop {
- if (pos) >= ulen { break; }
+
+ for _ in 0..(len / 2) {
let (nlen, name) = Variant::parse(&b[pos..])?;
pos += nlen;
@@ -76,10 +73,12 @@ impl HandshakeQRead for VariantMap {
fn read<T: Read>(s: &mut T, b: &mut [u8]) -> Result<usize, Error> {
s.read(&mut b[0..4])?;
let (_, len) = i32::parse(&b[0..4])?;
- let ulen = len as usize;
- // Read the 00 00 00 0a VariantType bytes and discard
- s.read(&mut b[4..(ulen + 4)])?;
+ let mut pos = 4;
+ for _ in 0..(len / 2) {
+ pos += Variant::read(s, &mut b[pos..])?;
+ pos += Variant::read(s, &mut b[pos..])?;
+ }
// let mut pos = 8;
// let len: usize = len as usize;
@@ -89,6 +88,6 @@ impl HandshakeQRead for VariantMap {
// pos += Variant::read(s, &mut b[pos..])?;
// }
- return Ok(ulen + 4);
+ return Ok(pos);
}
}
diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs
index 8739fd8..3630fab 100644
--- a/src/protocol/mod.rs
+++ b/src/protocol/mod.rs
@@ -1,3 +1,9 @@
pub mod message;
pub mod primitive;
+
+#[allow(dead_code)]
pub mod error;
+
+#[allow(unused_variables, dead_code)]
+#[cfg(feature = "framing")]
+pub mod frame;