Refactor ssh connection within class for future improvements.

This commit is contained in:
Andris Raugulis 2016-04-01 17:56:06 +03:00
parent 8442dfac0e
commit 06992d7da6

View File

@ -27,8 +27,6 @@ from __future__ import print_function
import os, io, sys, socket, struct import os, io, sys, socket, struct
SSH_BANNER = 'SSH-2.0-OpenSSH_7.2' SSH_BANNER = 'SSH-2.0-OpenSSH_7.2'
SOCK_CONN_TIMEOUT = 3.0
SOCK_READ_TIMEOUT = 5.0
def usage(): def usage():
p = os.path.basename(sys.argv[0]) p = os.path.basename(sys.argv[0])
@ -95,6 +93,7 @@ class Kex(object):
class ReadBuf(object): class ReadBuf(object):
def __init__(self, data = None): def __init__(self, data = None):
super(ReadBuf, self).__init__()
self._buf = io.BytesIO(data) if data else io.BytesIO() self._buf = io.BytesIO(data) if data else io.BytesIO()
self._len = len(data) if data else 0 self._len = len(data) if data else 0
@ -118,19 +117,69 @@ class ReadBuf(object):
list_size = self.read_int() list_size = self.read_int()
return self.read(list_size).decode().split(',') return self.read(list_size).decode().split(',')
class SockBuf(ReadBuf): class SSH(object):
def __init__(self, s): MSG_KEXINIT = 20
super(SockBuf, self).__init__() MSG_NEWKEYS = 21
self.__sock = s MSG_KEXDH_INIT = 30
MSG_KEXDH_REPLY = 32
def recv(self, size = 2048): class Socket(ReadBuf):
data = self.__sock.recv(size) def __init__(self, host, port, cto = 3.0, rto = 5.0):
pos = self._buf.tell() super(SSH.Socket, self).__init__()
self._buf.seek(0, 2) try:
self._buf.write(data) self.__sock = socket.create_connection((host, port), cto)
self._len += len(data) self.__sock.settimeout(rto)
self._buf.seek(pos, 0) except Exception as e:
out.fail('[fail] {}'.format(e))
sys.exit(1)
def __enter__(self):
return self
def recv(self, size = 2048):
data = self.__sock.recv(size)
pos = self._buf.tell()
self._buf.seek(0, 2)
self._buf.write(data)
self._len += len(data)
self._buf.seek(pos, 0)
def send(self, data):
self.__sock.send(data)
def read_packet(self):
block_size = 8
if self.unread_len < block_size:
self.recv()
header = self.read(block_size)
packet_size = struct.unpack('>I', header[:4])[0]
rest = header[4:]
lrest = len(rest)
padding = ord(rest[0:1])
packet_type = ord(rest[1:2])
if (packet_size - lrest) % block_size != 0:
out.fail('[exception] invalid ssh packet (block size)')
sys.exit(1)
rlen = packet_size - lrest
if self.unread_len < rlen:
self.recv()
buf = self.read(rlen)
packet = rest[2:] + buf[0:packet_size - lrest]
payload = packet[0:packet_size - padding]
return packet_type, payload
def __del__(self):
self.__cleanup()
def __exit__(self, ex_type, ex_value, tb):
self.__cleanup()
def __cleanup(self):
try:
self.__sock.shutdown(socket.SHUT_RDWR)
self.__sock.close()
except:
pass
def get_ssh_ver(versions): def get_ssh_ver(versions):
tv = [] tv = []
@ -309,26 +358,6 @@ def process_kex(kex):
process_algorithms('mac', kex.server.mac, maxlen) process_algorithms('mac', kex.server.mac, maxlen)
out.sep() out.sep()
def read_ssh_packet(sbuf):
block_size = 8
if sbuf.unread_len < block_size:
sbuf.recv()
header = sbuf.read(block_size)
packet_size = struct.unpack('>I', header[:4])[0]
rest = header[4:]
lrest = len(rest)
padding = ord(rest[0:1])
packet_type = ord(rest[1:2])
if (packet_size - lrest) % block_size != 0:
out.fail('[exception] invalid ssh packet (block size)')
sys.exit(1)
rlen = packet_size - lrest
if sbuf.unread_len < rlen:
sbuf.recv()
buf = sbuf.read(rlen)
packet = rest[2:] + buf[0:packet_size - lrest]
payload = packet[0:packet_size - padding]
return packet_type, payload
def parse_int(v): def parse_int(v):
try: try:
@ -355,30 +384,20 @@ def parse_args():
def main(): def main():
host, port = parse_args() host, port = parse_args()
s = None s = SSH.Socket(host, port)
try: s.send(SSH_BANNER.encode() + b'\r\n')
s = socket.create_connection((host, port), SOCK_CONN_TIMEOUT) s.recv()
s.settimeout(SOCK_READ_TIMEOUT) banner = s.read_line()
sbuf = SockBuf(s) out.head('# general')
s.send(SSH_BANNER.encode() + b'\r\n') out.good('[info] banner: ' + banner)
sbuf.recv() if banner.startswith('SSH-1.99-'):
banner = sbuf.read_line() out.fail('[fail] protocol SSH1 enabled')
out.head('# general') packet_type, payload = s.read_packet()
out.good('[info] banner: ' + banner) if packet_type != SSH.MSG_KEXINIT:
if banner.startswith('SSH-1.99-'): out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type))
out.fail('[fail] protocol SSH1 enabled')
packet_type, payload = read_ssh_packet(sbuf)
if packet_type != 20:
out.fail('[exception] did not receive MSG_KEXINIT (20), instead received unknown message ({0})'.format(packet_type))
sys.exit(1)
kex = Kex.parse(payload)
process_kex(kex)
except Exception as e:
out.fail('[fail] {}'.format(e))
sys.exit(1) sys.exit(1)
finally: kex = Kex.parse(payload)
if s: process_kex(kex)
s.close()
if __name__ == '__main__': if __name__ == '__main__':
out = Output() out = Output()