SSH_Socket's constructor now takes an OutputBuffer for verbose & debugging output.

This commit is contained in:
Joe Testa 2021-03-02 11:25:37 -05:00
parent 83bd049486
commit 8e9fe20fac
6 changed files with 37 additions and 33 deletions

View File

@ -43,12 +43,12 @@ class GEXTest:
if s.is_connected(): if s.is_connected():
return True return True
err = s.connect(out) err = s.connect()
if err is not None: if err is not None:
out.v(err, write_now=True) out.v(err, write_now=True)
return False return False
_, _, err = s.get_banner(out) _, _, err = s.get_banner()
if err is not None: if err is not None:
out.v(err, write_now=True) out.v(err, write_now=True)
s.close() s.close()
@ -56,7 +56,7 @@ class GEXTest:
# Send our KEX using the specified group-exchange and most of the # Send our KEX using the specified group-exchange and most of the
# server's own values. # server's own values.
s.send_kexinit(out, key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages) s.send_kexinit(key_exchanges=[gex_alg], hostkeys=kex.key_algorithms, ciphers=kex.server.encryption, macs=kex.server.mac, compressions=kex.server.compression, languages=kex.server.languages)
# Parse the server's KEX. # Parse the server's KEX.
_, payload = s.read_packet(2) _, payload = s.read_packet(2)

View File

@ -109,19 +109,19 @@ class HostKeyTest:
# If the connection is closed, re-open it and get the kex again. # If the connection is closed, re-open it and get the kex again.
if not s.is_connected(): if not s.is_connected():
err = s.connect(out) err = s.connect()
if err is not None: if err is not None:
out.v(err, write_now=True) out.v(err, write_now=True)
return return
_, _, err = s.get_banner(out) _, _, err = s.get_banner()
if err is not None: if err is not None:
out.v(err, write_now=True) out.v(err, write_now=True)
s.close() s.close()
return return
# Send our KEX using the specified group-exchange and most of the server's own values. # Send our KEX using the specified group-exchange and most of the server's own values.
s.send_kexinit(out, key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages) s.send_kexinit(key_exchanges=[kex_str], hostkeys=[host_key_type], ciphers=server_kex.server.encryption, macs=server_kex.server.mac, compressions=server_kex.server.compression, languages=server_kex.server.languages)
# Parse the server's KEX. # Parse the server's KEX.
_, payload = s.read_packet() _, payload = s.read_packet()

View File

@ -820,14 +820,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print
out.debug = aconf.debug out.debug = aconf.debug
out.level = aconf.level out.level = aconf.level
out.use_colors = aconf.colors out.use_colors = aconf.colors
s = SSH_Socket(aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set) s = SSH_Socket(out, aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set)
if aconf.client_audit: if aconf.client_audit:
out.v("Listening for client connection on port %d..." % aconf.port, write_now=True) out.v("Listening for client connection on port %d..." % aconf.port, write_now=True)
s.listen_and_accept() s.listen_and_accept()
else: else:
out.v("Starting audit of %s:%d..." % ('[%s]' % aconf.host if Utils.is_ipv6_address(aconf.host) else aconf.host, aconf.port), write_now=True) out.v("Starting audit of %s:%d..." % ('[%s]' % aconf.host if Utils.is_ipv6_address(aconf.host) else aconf.host, aconf.port), write_now=True)
err = s.connect(out) err = s.connect()
if err is not None: if err is not None:
out.fail(err) out.fail(err)
@ -842,14 +842,14 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print
if sshv is None: if sshv is None:
sshv = 2 if aconf.ssh2 else 1 sshv = 2 if aconf.ssh2 else 1
err = None err = None
banner, header, err = s.get_banner(out, sshv) banner, header, err = s.get_banner(sshv)
if banner is None: if banner is None:
if err is None: if err is None:
err = '[exception] did not receive banner.' err = '[exception] did not receive banner.'
else: else:
err = '[exception] did not receive banner: {}'.format(err) err = '[exception] did not receive banner: {}'.format(err)
if err is None: if err is None:
s.send_kexinit(out) # Send the algorithms we support (except we don't since this isn't a real SSH connection). s.send_kexinit() # Send the algorithms we support (except we don't since this isn't a real SSH connection).
packet_type, payload = s.read_packet(sshv) packet_type, payload = s.read_packet(sshv)
if packet_type < 0: if packet_type < 0:

View File

@ -52,8 +52,9 @@ class SSH_Socket(ReadBuf, WriteBuf):
SM_BANNER_SENT = 1 SM_BANNER_SENT = 1
def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value def __init__(self, outputbuffer: 'OutputBuffer', host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: # pylint: disable=dangerous-default-value
super(SSH_Socket, self).__init__() super(SSH_Socket, self).__init__()
self.__outputbuffer = outputbuffer
self.__sock: Optional[socket.socket] = None self.__sock: Optional[socket.socket] = None
self.__sock_map: Dict[int, socket.socket] = {} self.__sock_map: Dict[int, socket.socket] = {}
self.__block_size = 8 self.__block_size = 8
@ -90,7 +91,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
if socktype == socket.SOCK_STREAM: if socktype == socket.SOCK_STREAM:
yield af, addr yield af, addr
except socket.error as e: except socket.error as e:
OutputBuffer().fail('[exception] {}'.format(e)).write() self.__outputbuffer.fail('[exception] {}'.format(e)).write()
sys.exit(exitcodes.CONNECTION_ERROR) sys.exit(exitcodes.CONNECTION_ERROR)
# Listens on a server socket and accepts one connection (used for # Listens on a server socket and accepts one connection (used for
@ -148,7 +149,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
c.settimeout(self.__timeout) c.settimeout(self.__timeout)
self.__sock = c self.__sock = c
def connect(self, out: 'OutputBuffer') -> Optional[str]: def connect(self) -> Optional[str]:
'''Returns None on success, or an error string.''' '''Returns None on success, or an error string.'''
err = None err = None
for af, addr in self._resolve(): for af, addr in self._resolve():
@ -156,7 +157,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
try: try:
s = socket.socket(af, socket.SOCK_STREAM) s = socket.socket(af, socket.SOCK_STREAM)
s.settimeout(self.__timeout) s.settimeout(self.__timeout)
out.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True) self.__outputbuffer.d(("Connecting to %s:%d..." % ('[%s]' % addr[0] if Utils.is_ipv6_address(addr[0]) else addr[0], addr[1])), write_now=True)
s.connect(addr) s.connect(addr)
self.__sock = s self.__sock = s
return None return None
@ -170,8 +171,8 @@ class SSH_Socket(ReadBuf, WriteBuf):
errm = 'cannot connect to {} port {}: {}'.format(*errt) errm = 'cannot connect to {} port {}: {}'.format(*errt)
return '[exception] {}'.format(errm) return '[exception] {}'.format(errm)
def get_banner(self, out: 'OutputBuffer', sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]: def get_banner(self, sshv: int = 2) -> Tuple[Optional['Banner'], List[str], Optional[str]]:
out.d('Getting banner...', write_now=True) self.__outputbuffer.d('Getting banner...', write_now=True)
if self.__sock is None: if self.__sock is None:
return self.__banner, self.__header, 'not connected' return self.__banner, self.__header, 'not connected'
@ -229,10 +230,10 @@ class SSH_Socket(ReadBuf, WriteBuf):
return -1, str(e.args[-1]) return -1, str(e.args[-1])
# Send a KEXINIT with the lists of key exchanges, hostkeys, ciphers, MACs, compressions, and languages that we "support". # Send a KEXINIT with the lists of key exchanges, hostkeys, ciphers, MACs, compressions, and languages that we "support".
def send_kexinit(self, out: 'OutputBuffer', key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None: # pylint: disable=dangerous-default-value def send_kexinit(self, key_exchanges: List[str] = ['curve25519-sha256', 'curve25519-sha256@libssh.org', 'ecdh-sha2-nistp256', 'ecdh-sha2-nistp384', 'ecdh-sha2-nistp521', 'diffie-hellman-group-exchange-sha256', 'diffie-hellman-group16-sha512', 'diffie-hellman-group18-sha512', 'diffie-hellman-group14-sha256'], hostkeys: List[str] = ['rsa-sha2-512', 'rsa-sha2-256', 'ssh-rsa', 'ecdsa-sha2-nistp256', 'ssh-ed25519'], ciphers: List[str] = ['chacha20-poly1305@openssh.com', 'aes128-ctr', 'aes192-ctr', 'aes256-ctr', 'aes128-gcm@openssh.com', 'aes256-gcm@openssh.com'], macs: List[str] = ['umac-64-etm@openssh.com', 'umac-128-etm@openssh.com', 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-512-etm@openssh.com', 'hmac-sha1-etm@openssh.com', 'umac-64@openssh.com', 'umac-128@openssh.com', 'hmac-sha2-256', 'hmac-sha2-512', 'hmac-sha1'], compressions: List[str] = ['none', 'zlib@openssh.com'], languages: List[str] = ['']) -> None: # pylint: disable=dangerous-default-value
'''Sends the list of supported host keys, key exchanges, ciphers, and MACs. Emulates OpenSSH v8.2.''' '''Sends the list of supported host keys, key exchanges, ciphers, and MACs. Emulates OpenSSH v8.2.'''
out.d('KEX initialisation...', write_now=True) self.__outputbuffer.d('KEX initialisation...', write_now=True)
kexparty = SSH2_KexParty(ciphers, macs, compressions, languages) kexparty = SSH2_KexParty(ciphers, macs, compressions, languages)
kex = SSH2_Kex(os.urandom(16), key_exchanges, hostkeys, kexparty, kexparty, False, 0) kex = SSH2_Kex(os.urandom(16), key_exchanges, hostkeys, kexparty, kexparty, False, 0)
@ -273,7 +274,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
payload_length = packet_length - padding_length - 1 payload_length = packet_length - padding_length - 1
check_size = 4 + 1 + payload_length + padding_length check_size = 4 + 1 + payload_length + padding_length
if check_size % self.__block_size != 0: if check_size % self.__block_size != 0:
OutputBuffer().fail('[exception] invalid ssh packet (block size)').write() self.__outputbuffer.fail('[exception] invalid ssh packet (block size)').write()
sys.exit(exitcodes.CONNECTION_ERROR) sys.exit(exitcodes.CONNECTION_ERROR)
self.ensure_read(payload_length) self.ensure_read(payload_length)
if sshv == 1: if sshv == 1:
@ -288,7 +289,7 @@ class SSH_Socket(ReadBuf, WriteBuf):
if sshv == 1: if sshv == 1:
rcrc = SSH1.crc32(padding + payload) rcrc = SSH1.crc32(padding + payload)
if crc != rcrc: if crc != rcrc:
OutputBuffer().fail('[exception] packet checksum CRC32 mismatch.').write() self.__outputbuffer.fail('[exception] packet checksum CRC32 mismatch.').write()
sys.exit(exitcodes.CONNECTION_ERROR) sys.exit(exitcodes.CONNECTION_ERROR)
else: else:
self.ensure_read(padding_length) self.ensure_read(padding_length)

View File

@ -8,6 +8,7 @@ class TestResolve:
def init(self, ssh_audit): def init(self, ssh_audit):
self.AuditConf = ssh_audit.AuditConf self.AuditConf = ssh_audit.AuditConf
self.audit = ssh_audit.audit self.audit = ssh_audit.audit
self.OutputBuffer = ssh_audit.OutputBuffer
self.ssh_socket = ssh_audit.SSH_Socket self.ssh_socket = ssh_audit.SSH_Socket
def _conf(self): def _conf(self):
@ -20,7 +21,7 @@ class TestResolve:
vsocket = virtual_socket vsocket = virtual_socket
vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known') vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known')
conf = self._conf() conf = self._conf()
s = self.ssh_socket('localhost', 22, conf.ip_version_preference) s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
output_spy.begin() output_spy.begin()
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
list(s._resolve()) list(s._resolve())
@ -32,7 +33,7 @@ class TestResolve:
vsocket = virtual_socket vsocket = virtual_socket
vsocket.gsock.addrinfodata['localhost#22'] = [] vsocket.gsock.addrinfodata['localhost#22'] = []
conf = self._conf() conf = self._conf()
s = self.ssh_socket('localhost', 22, conf.ip_version_preference) s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
output_spy.begin() output_spy.begin()
r = list(s._resolve()) r = list(s._resolve())
assert len(r) == 0 assert len(r) == 0
@ -40,7 +41,7 @@ class TestResolve:
def test_resolve_ipv4(self, virtual_socket): def test_resolve_ipv4(self, virtual_socket):
conf = self._conf() conf = self._conf()
conf.ipv4 = True conf.ipv4 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference) s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve()) r = list(s._resolve())
assert len(r) == 1 assert len(r) == 1
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
@ -48,14 +49,14 @@ class TestResolve:
def test_resolve_ipv6(self, virtual_socket): def test_resolve_ipv6(self, virtual_socket):
conf = self._conf() conf = self._conf()
conf.ipv6 = True conf.ipv6 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference) s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve()) r = list(s._resolve())
assert len(r) == 1 assert len(r) == 1
assert r[0] == (socket.AF_INET6, ('::1', 22)) assert r[0] == (socket.AF_INET6, ('::1', 22))
def test_resolve_ipv46_both(self, virtual_socket): def test_resolve_ipv46_both(self, virtual_socket):
conf = self._conf() conf = self._conf()
s = self.ssh_socket('localhost', 22, conf.ip_version_preference) s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve()) r = list(s._resolve())
assert len(r) == 2 assert len(r) == 2
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
@ -65,7 +66,7 @@ class TestResolve:
conf = self._conf() conf = self._conf()
conf.ipv4 = True conf.ipv4 = True
conf.ipv6 = True conf.ipv6 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference) s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve()) r = list(s._resolve())
assert len(r) == 2 assert len(r) == 2
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
@ -73,7 +74,7 @@ class TestResolve:
conf = self._conf() conf = self._conf()
conf.ipv6 = True conf.ipv6 = True
conf.ipv4 = True conf.ipv4 = True
s = self.ssh_socket('localhost', 22, conf.ip_version_preference) s = self.ssh_socket(self.OutputBuffer(), 'localhost', 22, conf.ip_version_preference)
r = list(s._resolve()) r = list(s._resolve())
assert len(r) == 2 assert len(r) == 2
assert r[0] == (socket.AF_INET6, ('::1', 22)) assert r[0] == (socket.AF_INET6, ('::1', 22))

View File

@ -1,5 +1,6 @@
import pytest import pytest
from ssh_audit.outputbuffer import OutputBuffer
from ssh_audit.ssh_socket import SSH_Socket from ssh_audit.ssh_socket import SSH_Socket
@ -7,24 +8,25 @@ from ssh_audit.ssh_socket import SSH_Socket
class TestSocket: class TestSocket:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def init(self, ssh_audit): def init(self, ssh_audit):
self.OutputBuffer = OutputBuffer
self.ssh_socket = SSH_Socket self.ssh_socket = SSH_Socket
def test_invalid_host(self, virtual_socket): def test_invalid_host(self, virtual_socket):
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.ssh_socket(None, 22) self.ssh_socket(self.OutputBuffer(), None, 22)
def test_invalid_port(self, virtual_socket): def test_invalid_port(self, virtual_socket):
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.ssh_socket('localhost', 'abc') self.ssh_socket(self.OutputBuffer(), 'localhost', 'abc')
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.ssh_socket('localhost', -1) self.ssh_socket(self.OutputBuffer(), 'localhost', -1)
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.ssh_socket('localhost', 0) self.ssh_socket(self.OutputBuffer(), 'localhost', 0)
with pytest.raises(ValueError): with pytest.raises(ValueError):
self.ssh_socket('localhost', 65536) self.ssh_socket(self.OutputBuffer(), 'localhost', 65536)
def test_not_connected_socket(self, virtual_socket): def test_not_connected_socket(self, virtual_socket):
sock = self.ssh_socket('localhost', 22) sock = self.ssh_socket(self.OutputBuffer(), 'localhost', 22)
banner, header, err = sock.get_banner() banner, header, err = sock.get_banner()
assert banner is None assert banner is None
assert len(header) == 0 assert len(header) == 0