Fix and update write buffer. Add buffer tests.

This commit is contained in:
Andris Raugulis 2016-10-05 06:06:26 +03:00
parent 262c65b7be
commit 7959c7448a
3 changed files with 124 additions and 43 deletions

View File

@ -416,7 +416,7 @@ class WriteBuf(object):
return self.write(v) return self.write(v)
def write_list(self, v): def write_list(self, v):
self.write_string(u','.join(v)) return self.write_string(u','.join(v))
@classmethod @classmethod
def _bitlength(cls, n): def _bitlength(cls, n):
@ -454,6 +454,12 @@ class WriteBuf(object):
data = self._create_mpint(n) data = self._create_mpint(n)
return self.write_string(data) return self.write_string(data)
def write_line(self, v):
if not isinstance(v, bytes):
v = bytes(bytearray(v, 'utf-8'))
v += b'\r\n'
return self.write(v)
def write_flush(self): def write_flush(self):
payload = self._wbuf.getvalue() payload = self._wbuf.getvalue()
self._wbuf.truncate(0) self._wbuf.truncate(0)

117
test/test_buffer.py Normal file
View File

@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
import re
class TestBuffer(object):
@pytest.fixture(autouse=True)
def init(self, ssh_audit):
self.rbuf = ssh_audit.ReadBuf
self.wbuf = ssh_audit.WriteBuf
def _b(self, v):
v = re.sub(r'\s', '', v)
data = [int(v[i * 2:i * 2 + 2], 16) for i in range(len(v) // 2)]
return bytes(bytearray(data))
def test_unread(self):
w = self.wbuf().write_byte(1).write_int(2).write_flush()
r = self.rbuf(w)
assert r.unread_len == 5
r.read_byte()
assert r.unread_len == 4
r.read_int()
assert r.unread_len == 0
def test_byte(self):
w = lambda x: self.wbuf().write_byte(x).write_flush()
r = lambda x: self.rbuf(x).read_byte()
tc = [(0x00, '00'),
(0x01, '01'),
(0x10, '10'),
(0xff, 'ff')]
for p in tc:
assert w(p[0]) == self._b(p[1])
assert r(self._b(p[1])) == p[0]
def test_bool(self):
w = lambda x: self.wbuf().write_bool(x).write_flush()
r = lambda x: self.rbuf(x).read_bool()
tc = [(True, '01'),
(False, '00')]
for p in tc:
assert w(p[0]) == self._b(p[1])
assert r(self._b(p[1])) == p[0]
def test_int(self):
w = lambda x: self.wbuf().write_int(x).write_flush()
r = lambda x: self.rbuf(x).read_int()
tc = [(0x00, '00 00 00 00'),
(0x01, '00 00 00 01'),
(0xabcd, '00 00 ab cd'),
(0xffffffff, 'ff ff ff ff')]
for p in tc:
assert w(p[0]) == self._b(p[1])
assert r(self._b(p[1])) == p[0]
def test_string(self):
w = lambda x: self.wbuf().write_string(x).write_flush()
r = lambda x: self.rbuf(x).read_string()
tc = [(u'abc1', '00 00 00 04 61 62 63 31'),
(b'abc2', '00 00 00 04 61 62 63 32')]
for p in tc:
v = p[0]
assert w(v) == self._b(p[1])
if not isinstance(v, bytes):
v = bytes(bytearray(v, 'utf-8'))
assert r(self._b(p[1])) == v
def test_list(self):
w = lambda x: self.wbuf().write_list(x).write_flush()
r = lambda x: self.rbuf(x).read_list()
tc = [(['d', 'ef', 'ault'], '00 00 00 09 64 2c 65 66 2c 61 75 6c 74')]
for p in tc:
assert w(p[0]) == self._b(p[1])
assert r(self._b(p[1])) == p[0]
def test_line(self):
w = lambda x: self.wbuf().write_line(x).write_flush()
r = lambda x: self.rbuf(x).read_line()
tc = [(u'example line', '65 78 61 6d 70 6c 65 20 6c 69 6e 65 0d 0a')]
for p in tc:
assert w(p[0]) == self._b(p[1])
assert r(self._b(p[1])) == p[0]
def test_bitlen(self, monkeypatch):
class Py26Int(int):
def bit_length(self):
raise AttributeError
assert self.wbuf._bitlength(42) == 6
assert self.wbuf._bitlength(Py26Int(42)) == 6
def test_mpint1(self):
mpint1w = lambda x: self.wbuf().write_mpint1(x).write_flush()
mpint1r = lambda x: self.rbuf(x).read_mpint1()
tc = [(0x0, '00 00'),
(0x1234, '00 0d 12 34'),
(0x12345, '00 11 01 23 45'),
(0xdeadbeef, '00 20 de ad be ef')]
for p in tc:
assert mpint1w(p[0]) == self._b(p[1])
assert mpint1r(self._b(p[1])) == p[0]
def test_mpint2(self):
mpint2w = lambda x: self.wbuf().write_mpint2(x).write_flush()
mpint2r = lambda x: self.rbuf(x).read_mpint2()
tc = [(0x0, '00 00 00 00'),
(0x80, '00 00 00 02 00 80'),
(0x9a378f9b2e332a7, '00 00 00 08 09 a3 78 f9 b2 e3 32 a7'),
(-0x1234, '00 00 00 02 ed cc'),
(-0xdeadbeef, '00 00 00 05 ff 21 52 41 11'),
(-0x8000, '00 00 00 02 80 00'),
(-0x80, '00 00 00 01 80')]
for p in tc:
assert mpint2w(p[0]) == self._b(p[1])
assert mpint2r(self._b(p[1])) == p[0]
assert mpint2r(self._b('00 00 00 02 ff 80')) == -0x80

View File

@ -1,42 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
import re
class TestProtocol(object):
@pytest.fixture(autouse=True)
def init(self, ssh_audit):
self.rbuf = ssh_audit.ReadBuf
self.wbuf = ssh_audit.WriteBuf
def _b(self, v):
v = re.sub(r'\s', '', v)
data = [int(v[i * 2:i * 2 + 2], 16) for i in range(len(v) // 2)]
return bytes(bytearray(data))
def test_mpint1(self):
mpint1w = lambda x: self.wbuf().write_mpint1(x).write_flush()
mpint1r = lambda x: self.rbuf(x).read_mpint1()
tc = [(0x0, '00 00'),
(0x1234, '00 0d 12 34'),
(0x12345, '00 11 01 23 45'),
(0xdeadbeef, '00 20 de ad be ef')]
for p in tc:
assert mpint1w(p[0]) == self._b(p[1])
assert mpint1r(self._b(p[1])) == p[0]
def test_mpint2(self):
mpint2w = lambda x: self.wbuf().write_mpint2(x).write_flush()
mpint2r = lambda x: self.rbuf(x).read_mpint2()
tc = [(0x0, '00 00 00 00'),
(0x80, '00 00 00 02 00 80'),
(0x9a378f9b2e332a7, '00 00 00 08 09 a3 78 f9 b2 e3 32 a7'),
(-0x1234, '00 00 00 02 ed cc'),
(-0xdeadbeef, '00 00 00 05 ff 21 52 41 11'),
(-0x8000, '00 00 00 02 80 00'),
(-0x80, '00 00 00 01 80')]
for p in tc:
assert mpint2w(p[0]) == self._b(p[1])
assert mpint2r(self._b(p[1])) == p[0]
assert mpint2r(self._b('00 00 00 02 ff 80')) == -0x80