Parse pre-banner header. Handle sock read/write errors.

This commit is contained in:
Andris Raugulis 2016-08-12 16:20:32 +03:00
parent 07ca434061
commit d4d8c6a659
1 changed files with 63 additions and 22 deletions

View File

@ -24,7 +24,7 @@
THE SOFTWARE. THE SOFTWARE.
""" """
from __future__ import print_function from __future__ import print_function
import os, io, sys, socket, struct, random import os, io, sys, socket, struct, random, errno
SSH_BANNER = 'SSH-2.0-OpenSSH_7.3' SSH_BANNER = 'SSH-2.0-OpenSSH_7.3'
@ -162,6 +162,7 @@ class SSH(object):
def __init__(self, host, port, cto = 3.0, rto = 5.0): def __init__(self, host, port, cto = 3.0, rto = 5.0):
self.__block_size = 8 self.__block_size = 8
self.__state = 0 self.__state = 0
self.__header = []
self.__banner = None self.__banner = None
super(SSH.Socket, self).__init__() super(SSH.Socket, self).__init__()
try: try:
@ -177,20 +178,44 @@ class SSH(object):
def get_banner(self): def get_banner(self):
if self.__state < self.SM_BANNER_SENT: if self.__state < self.SM_BANNER_SENT:
self.send_banner() self.send_banner()
if self.__banner is None: while self.__banner is None:
self.recv() s, e = self.recv()
self.__banner = self.read_line() if s < 0:
return self.__banner break
while self.__banner is None and self.unread_len > 0:
line = self.read_line()
if len(line.strip()) == 0:
continue
if line.startswith('SSH-'):
self.__banner = line
else:
self.__header.append(line)
return self.__banner, self.__header
def recv(self, size = 2048): def recv(self, size = 2048):
data = self.__sock.recv(size) try:
data = self.__sock.recv(size)
except socket.timeout as e:
r = 0 if e.strerror == 'timed out' else -1
return (r, e)
except socket.error as e:
r = 0 if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK) else -1
return (r, e)
if len(data) == 0:
return (-1, None)
pos = self._buf.tell() pos = self._buf.tell()
self._buf.seek(0, 2) self._buf.seek(0, 2)
self._buf.write(data) self._buf.write(data)
self._len += len(data) self._len += len(data)
self._buf.seek(pos, 0) self._buf.seek(pos, 0)
return (len(data), None)
def send(self, data): def send(self, data):
try:
self.__sock.send(data)
return (0, None)
except socket.error as e:
return (-1, e)
self.__sock.send(data) self.__sock.send(data)
def send_banner(self, banner = SSH_BANNER): def send_banner(self, banner = SSH_BANNER):
@ -199,8 +224,10 @@ class SSH(object):
self.__state = self.SM_BANNER_SENT self.__state = self.SM_BANNER_SENT
def read_packet(self): def read_packet(self):
if self.unread_len < self.__block_size: while self.unread_len < self.__block_size:
self.recv() s, e = self.recv()
if s < 0:
return -1, e
header = self.read(self.__block_size) header = self.read(self.__block_size)
if len(header) == 0: if len(header) == 0:
out.fail('[exception] empty ssh packet (no data)') out.fail('[exception] empty ssh packet (no data)')
@ -214,8 +241,10 @@ class SSH(object):
out.fail('[exception] invalid ssh packet (block size)') out.fail('[exception] invalid ssh packet (block size)')
sys.exit(1) sys.exit(1)
rlen = packet_size - lrest rlen = packet_size - lrest
if self.unread_len < rlen: while self.unread_len < rlen:
self.recv() s, e = self.recv()
if s < 0:
return -1, e
buf = self.read(rlen) buf = self.read(rlen)
packet = rest[2:] + buf[0:packet_size - lrest] packet = rest[2:] + buf[0:packet_size - lrest]
payload = packet[0:packet_size - padding] payload = packet[0:packet_size - padding]
@ -231,7 +260,7 @@ class SSH(object):
plen = len(payload) + padding + 1 plen = len(payload) + padding + 1
pad_bytes = b'\x00' * padding pad_bytes = b'\x00' * padding
data = struct.pack('>Ib', plen, padding) + payload + pad_bytes data = struct.pack('>Ib', plen, padding) + payload + pad_bytes
self.send(data) return self.send(data)
def __del__(self): def __del__(self):
self.__cleanup() self.__cleanup()
@ -508,11 +537,15 @@ def output_compatibility(kex, client=False):
if len(comp_text) > 0: if len(comp_text) > 0:
out.good('[info] compatibility: ' + ', '.join(comp_text)) out.good('[info] compatibility: ' + ', '.join(comp_text))
def output(banner, kex): def output(banner, header, kex):
out.head('# general') if banner is not None or kex is not None:
out.good('[info] banner: ' + banner) out.head('# general')
if banner.startswith('SSH-1.99-'): if len(header) > 0:
out.fail('[fail] protocol SSH1 enabled') out.info('[info] header: ' + '\n'.join(header))
if banner is not None:
out.good('[info] banner: ' + banner)
if banner.startswith('SSH-1.99-'):
out.fail('[fail] protocol SSH1 enabled')
if kex is None: if kex is None:
return return
output_compatibility(kex) output_compatibility(kex)
@ -564,14 +597,22 @@ def parse_args():
def main(): def main():
host, port = parse_args() host, port = parse_args()
s = SSH.Socket(host, port) s = SSH.Socket(host, port)
banner = s.get_banner() err = None
packet_type, payload = s.read_packet() banner, header = s.get_banner()
if packet_type != SSH.MSG_KEXINIT: if banner is None:
output(banner, None) err = '[exception] did not receive banner.'
out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type)) if err is None:
packet_type, payload = s.read_packet()
if packet_type < 0:
err = '[exception] error reading packet ({0})'.format(payload)
elif packet_type != SSH.MSG_KEXINIT:
err = '[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type)
if err:
output(banner, header, None)
out.fail(err)
sys.exit(1) sys.exit(1)
kex = Kex.parse(payload) kex = Kex.parse(payload)
output(banner, kex) output(banner, header, kex)
if __name__ == '__main__': if __name__ == '__main__':
out = Output() out = Output()