diff --git a/ssh-audit.py b/ssh-audit.py index f81f2d2..cdd9b8c 100755 --- a/ssh-audit.py +++ b/ssh-audit.py @@ -172,7 +172,8 @@ class ReadBuf(object): n = self.read_int() return self.read(n) - def _parse_mpint(self, v, pad, sf): + @classmethod + def _parse_mpint(cls, v, pad, sf): r = 0 if len(v) % 4: v = pad * (4 - (len(v) % 4)) + v @@ -225,27 +226,38 @@ class WriteBuf(object): def write_list(self, v): self.write_string(u','.join(v)) - def _create_mpint(self, v): - length = v.bit_length() // 8 + (1 if v != 0 else 0) + @classmethod + def _bitlength(cls, n): + try: + return n.bit_length() + except AttributeError: + return len(bin(n)) - (2 if n > 0 else 3) + + @classmethod + def _create_mpint(cls, n, bits=None): + if bits is None: + bits = cls._bitlength(n) + length = bits // 8 + (1 if n != 0 else 0) ql = (length + 7) // 8 fmt, v2 = '>{0}Q'.format(ql), [b'\x00'] * ql for i in range(ql): - v2[ql - i - 1] = (v & 0xffffffffffffffff) - v >>= 64 + v2[ql - i - 1] = (n & 0xffffffffffffffff) + n >>= 64 data = bytes(struct.pack(fmt, *v2)[-length:]) if data.startswith(b'\xff\x80'): data = data[1:] return data - def write_mpint1(self, v): + def write_mpint1(self, n): # NOTE: Data Type Enc @ http://www.snailbook.com/docs/protocol-1.5.txt - data = self._create_mpint(v) - self.write(struct.pack('>H', v.bit_length())) + bits = self._bitlength(n) + data = self._create_mpint(n, bits) + self.write(struct.pack('>H', bits)) return self.write(data) - def write_mpint2(self, v): + def write_mpint2(self, n): # NOTE: Section 5 @ https://www.ietf.org/rfc/rfc4251.txt - data = self._create_mpint(v) + data = self._create_mpint(n) return self.write_string(data) def write_flush(self):