diff --git a/ssh-audit.py b/ssh-audit.py index c844d7d..da4bc7b 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -155,18 +155,37 @@ class ReadBuf(object): def read(self, size): return self._buf.read(size) - def read_line(self): - return self._buf.readline().rstrip().decode('utf-8') + def read_byte(self): + return struct.unpack('B', self.read(1))[0] + + def read_bool(self): + return self.read_byte() != 0 def read_int(self): return struct.unpack('>I', self.read(4))[0] - def read_bool(self): - return struct.unpack('b', self.read(1))[0] != 0 - def read_list(self): list_size = self.read_int() return self.read(list_size).decode().split(',') + + def read_string(self): + n = self.read_int() + return self.read(n) + + def read_mpint2(self): + # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt + r, v = 0, self.read_string() + if len(v) == 0: + return r + pad, sf = (b'\xff', '>i') if ord(v[0:1]) & 0x80 else (b'\x00', '>I') + if len(v) % 4: + v = pad * (4 - (len(v) % 4)) + v + for i in range(0, len(v), 4): + r = (r << 32) | struct.unpack(sf, v[i:i + 4])[0] + return r + + def read_line(self): + return self._buf.readline().rstrip().decode('utf-8') class WriteBuf(object): @@ -176,31 +195,40 @@ class WriteBuf(object): def write(self, data): self._wbuf.write(data) + return self def write_byte(self, v): - self.write(struct.pack('>B', v)) + return self.write(struct.pack('B', v)) def write_bool(self, v): - self.write_byte(1 if v else 0) + return self.write_byte(1 if v else 0) def write_int(self, v): - self.write(struct.pack('>I', v)) + return self.write(struct.pack('>I', v)) def write_string(self, v): if not isinstance(v, bytes): v = bytes(bytearray(v, 'utf-8')) - n = len(v) - self.write(struct.pack('>I', n)) - self.write(v) + self.write_int(len(v)) + return self.write(v) def write_list(self, v): - self.write_string(','.join(v)) + self.write_string(u','.join(v)) - def write_mpint(self, v): - length = v.bit_length() // 8 + 1 + def write_mpint2(self, v): + # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt + length = v.bit_length() // 8 + (1 if v != 0 else 0) data = [(v >> i * 8) & 0xff for i in reversed(range(length))] + if length > 1 and data[0] == 0xff and data[1] & 0x80: + data.pop(0) data = bytes(bytearray(data)) - self.write_string(data) + return self.write_string(data) + + def write_flush(self): + payload = self._wbuf.getvalue() + self._wbuf.truncate(0) + self._wbuf.seek(0) + return payload class SSH(object): @@ -559,9 +587,7 @@ class SSH(object): return packet_type, payload def send_packet(self): - payload = self._wbuf.getvalue() - self._wbuf.truncate(0) - self._wbuf.seek(0) + payload = self.write_flush() padding = -(len(payload) + 5) % 8 if padding < 4: padding += 8 @@ -597,7 +623,7 @@ class KexDH(object): self.__x = r.randrange(2, self.__q) self.__e = pow(self.__g, self.__x, self.__p) s.write_byte(SSH.Protocol.MSG_KEXDH_INIT) - s.write_mpint(self.__e) + s.write_mpint2(self.__e) s.send_packet() diff --git a/test/test_protocol.py b/test/test_protocol.py new file mode 100644 index 0000000..bcff2cc --- /dev/null +++ b/test/test_protocol.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import pytest +import re + + +class TestProtocol(object): + @pytest.fixture(autouse=True) + def init(self, ssh_audit): + self.rbuf = ssh_audit.ReadBuf + self.wbuf = ssh_audit.WriteBuf + + def _b(self, v): + v = re.sub(r'\s', '', v) + data = [int(v[i * 2:i * 2 + 2], 16) for i in range(len(v) // 2)] + return bytes(bytearray(data)) + + def test_mpint2_write(self): + wbuf, _b = self.wbuf(), self._b + mpint = lambda x: wbuf.write_mpint2(x).write_flush() + assert mpint(0x0) == _b('00 00 00 00') + assert mpint(0x80) == _b('00 00 00 02 00 80') + assert mpint(0x9a378f9b2e332a7) == _b('00 00 00 08 09 a3 78 f9 b2 e3 32 a7') + assert mpint(-0x1234) == _b('00 00 00 02 ed cc') + assert mpint(-0xdeadbeef) == _b('00 00 00 05 ff 21 52 41 11') + assert mpint(-0x80) == _b('00 00 00 01 80') + + def test_mpint2_read(self): + rbuf, _b = self.rbuf, self._b + mpint = lambda x: rbuf(x).read_mpint2() + assert mpint(_b('00 00 00 00')) == 0x00 + assert mpint(_b('00 00 00 02 00 80')) == 0x80 + assert mpint(_b('00 00 00 08 09 a3 78 f9 b2 e3 32 a7')) == 0x9a378f9b2e332a7 + assert mpint(_b('00 00 00 02 ed cc')) == -0x1234 + assert mpint(_b('00 00 00 05 ff 21 52 41 11')) == -0xdeadbeef + assert mpint(_b('00 00 00 01 80')) == -0x80 + assert mpint(_b('00 00 00 02 ff 80')) == -0x80