Enabled the following mypy options: check_untyped_defs, disallow_untyped_defs, disallow_untyped_calls, disallow_incomplete_defs, disallow_untyped_decorators, disallow_untyped_decorators, strict_equality, and strict.

This commit is contained in:
Joe Testa 2020-07-01 13:00:44 -04:00
parent cabbe717d3
commit 30f2b7690a
2 changed files with 68 additions and 45 deletions

View File

@ -84,7 +84,7 @@ def usage(err: Optional[str] = None) -> None:
# Validates policy files and performs policy testing # Validates policy files and performs policy testing
class Policy: class Policy:
def __init__(self, policy_file: str = None, policy_data: str = None) -> None: def __init__(self, policy_file: Optional[str] = None, policy_data: Optional[str] = None) -> None:
self._name = None # type: Optional[str] self._name = None # type: Optional[str]
self._version = None # type: Optional[str] self._version = None # type: Optional[str]
self._banner = None # type: Optional[str] self._banner = None # type: Optional[str]
@ -274,7 +274,7 @@ macs = %s
return policy_data return policy_data
def evaluate(self, banner: Optional['SSH.Banner'], header: List[str], kex: Optional['SSH2.Kex']) -> Tuple[bool, List[str]]: def evaluate(self, banner: Optional['SSH.Banner'], header: Optional[List[str]], kex: Optional['SSH2.Kex']) -> Tuple[bool, List[str]]:
'''Evaluates a server configuration against this policy. Returns a tuple of a boolean (True if server adheres to policy) and an array of strings that holds error messages.''' '''Evaluates a server configuration against this policy. Returns a tuple of a boolean (True if server adheres to policy) and an array of strings that holds error messages.'''
ret = True ret = True
@ -285,9 +285,12 @@ macs = %s
ret = False ret = False
errors.append('Banner did not match. Expected: [%s]; Actual: [%s]' % (self._banner, banner_str)) errors.append('Banner did not match. Expected: [%s]; Actual: [%s]' % (self._banner, banner_str))
if (self._header is not None) and (header != self._header): actual_header = None
if header is not None:
actual_header = "\n".join(header)
if (self._header is not None) and (actual_header is not None) and (actual_header != self._header):
ret = False ret = False
errors.append('Header did not match. Expected: [%s]; Actual: [%s]' % (self._header, header)) errors.append('Header did not match. Expected: [%s]; Actual: [%s]' % (self._header, actual_header))
# All subsequent tests require a valid kex, so end here if we don't have one. # All subsequent tests require a valid kex, so end here if we don't have one.
if kex is None: if kex is None:
@ -597,7 +600,7 @@ class Output:
return lambda x: print(u'{}'.format(x)) return lambda x: print(u'{}'.format(x))
class OutputBuffer(list): class OutputBuffer(List[str]):
def __enter__(self) -> 'OutputBuffer': def __enter__(self) -> 'OutputBuffer':
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
self.__buf = io.StringIO() self.__buf = io.StringIO()
@ -608,12 +611,12 @@ class OutputBuffer(list):
def flush(self, sort_lines: bool = False) -> None: def flush(self, sort_lines: bool = False) -> None:
# Lines must be sorted in some cases to ensure consistent testing. # Lines must be sorted in some cases to ensure consistent testing.
if sort_lines: if sort_lines:
self.sort() self.sort() # pylint: disable=no-member
for line in self: for line in self: # pylint: disable=not-an-iterable
print(line) print(line)
def __exit__(self, *args: Any) -> None: def __exit__(self, *args: Any) -> None:
self.extend(self.__buf.getvalue().splitlines()) self.extend(self.__buf.getvalue().splitlines()) # pylint: disable=no-member
sys.stdout = self.__stdout sys.stdout = self.__stdout
@ -1664,7 +1667,7 @@ class SSH: # pylint: disable=too-few-public-methods
return re.sub(r'^[-_\.]+', '', patch) or None return re.sub(r'^[-_\.]+', '', patch) or None
@staticmethod @staticmethod
def _fix_date(d: str) -> Optional[str]: def _fix_date(d: Optional[str]) -> Optional[str]:
if d is not None and len(d) == 8: if d is not None and len(d) == 8:
return '{}-{}-{}'.format(d[:4], d[4:6], d[6:8]) return '{}-{}-{}'.format(d[:4], d[4:6], d[6:8])
else: else:
@ -2353,31 +2356,28 @@ class SSH: # pylint: disable=too-few-public-methods
def get_banner(self, sshv: int = 2) -> Tuple[Optional['SSH.Banner'], List[str], Optional[str]]: def get_banner(self, sshv: int = 2) -> Tuple[Optional['SSH.Banner'], List[str], Optional[str]]:
if self.__sock is None: if self.__sock is None:
return self.__banner, self.__header, 'not connected' return self.__banner, self.__header, 'not connected'
if self.__banner is not None:
return self.__banner, self.__header, None
banner = SSH_HEADER.format('1.5' if sshv == 1 else '2.0') banner = SSH_HEADER.format('1.5' if sshv == 1 else '2.0')
if self.__state < self.SM_BANNER_SENT: if self.__state < self.SM_BANNER_SENT:
self.send_banner(banner) self.send_banner(banner)
# rto = self.__sock.gettimeout()
# self.__sock.settimeout(0.7) s = 0
s, e = self.recv()
# self.__sock.settimeout(rto)
if s < 0:
return self.__banner, self.__header, e
e = None e = None
while self.__banner is None: while s >= 0:
if not s > 0:
s, e = self.recv() s, e = self.recv()
if s < 0: if s < 0:
break continue
while self.__banner is None and self.unread_len > 0: while self.unread_len > 0:
line = self.read_line() line = self.read_line()
if len(line.strip()) == 0: if len(line.strip()) == 0:
continue continue
if self.__banner is None:
self.__banner = SSH.Banner.parse(line) self.__banner = SSH.Banner.parse(line)
if self.__banner is not None: if self.__banner is not None:
continue return self.__banner, self.__header, None
self.__header.append(line) self.__header.append(line)
s = 0
return self.__banner, self.__header, e return self.__banner, self.__header, e
def recv(self, size: int = 2048) -> Tuple[int, Optional[str]]: def recv(self, size: int = 2048) -> Tuple[int, Optional[str]]:
@ -2408,7 +2408,6 @@ class SSH: # pylint: disable=too-few-public-methods
return 0, None return 0, None
except socket.error as e: except socket.error as e:
return -1, str(e.args[-1]) return -1, str(e.args[-1])
self.__sock.send(data)
def send_banner(self, banner: str) -> None: def send_banner(self, banner: str) -> None:
self.send(banner.encode() + b'\r\n') self.send(banner.encode() + b'\r\n')
@ -2547,7 +2546,7 @@ class KexDH: # pragma: nocover
# Parse a KEXDH_REPLY or KEXDH_GEX_REPLY message from the server. This # Parse a KEXDH_REPLY or KEXDH_GEX_REPLY message from the server. This
# contains the host key, among other things. Function returns the host # contains the host key, among other things. Function returns the host
# key blob (from which the fingerprint can be calculated). # key blob (from which the fingerprint can be calculated).
def recv_reply(self, s, parse_host_key_size=True): def recv_reply(self, s: 'SSH.Socket', parse_host_key_size: bool = True) -> Optional[bytes]:
packet_type, payload = s.read_packet(2) packet_type, payload = s.read_packet(2)
# Skip any & all MSG_DEBUG messages. # Skip any & all MSG_DEBUG messages.
@ -3188,7 +3187,7 @@ def output(aconf: AuditConf, banner: Optional[SSH.Banner], header: List[str], cl
out.warn("\n\n!!! WARNING: unknown algorithm(s) found!: %s. Please email the full output above to the maintainer (jtesta@positronsecurity.com), or create a Github issue at <https://github.com/jtesta/ssh-audit/issues>.\n" % ','.join(unknown_algorithms)) out.warn("\n\n!!! WARNING: unknown algorithm(s) found!: %s. Please email the full output above to the maintainer (jtesta@positronsecurity.com), or create a Github issue at <https://github.com/jtesta/ssh-audit/issues>.\n" % ','.join(unknown_algorithms))
def evaluate_policy(aconf: AuditConf, banner: Optional['SSH.Banner'], header: List[str], kex: Optional['SSH2.Kex'] = None) -> bool: def evaluate_policy(aconf: AuditConf, banner: Optional['SSH.Banner'], header: Optional[List[str]], kex: Optional['SSH2.Kex'] = None) -> bool:
if aconf.policy is None: if aconf.policy is None:
raise RuntimeError('Internal error: cannot evaluate against null Policy!') raise RuntimeError('Internal error: cannot evaluate against null Policy!')
@ -3320,15 +3319,26 @@ class Utils:
return -1.0 return -1.0
def build_struct(banner, kex=None, pkm=None, client_host=None): def build_struct(banner: Optional['SSH.Banner'], kex: Optional['SSH2.Kex'] = None, pkm: Optional['SSH1.PublicKeyMessage'] = None, client_host: Optional[str] = None) -> Any:
banner_str = ''
banner_protocol = None
banner_software = None
banner_comments = None
if banner is not None:
banner_str = str(banner)
banner_protocol = banner.protocol
banner_software = banner.software
banner_comments = banner.comments
res = { res = {
"banner": { "banner": {
"raw": str(banner), "raw": banner_str,
"protocol": banner.protocol, "protocol": banner_protocol,
"software": banner.software, "software": banner_software,
"comments": banner.comments, "comments": banner_comments,
}, },
} } # type: Any
if client_host is not None: if client_host is not None:
res['client_ip'] = client_host res['client_ip'] = client_host
if kex is not None: if kex is not None:
@ -3339,8 +3349,8 @@ def build_struct(banner, kex=None, pkm=None, client_host=None):
for algorithm in kex.kex_algorithms: for algorithm in kex.kex_algorithms:
entry = { entry = {
'algorithm': algorithm, 'algorithm': algorithm,
} } # type: Any
if (alg_sizes is not None) and (algorithm in alg_sizes): if algorithm in alg_sizes:
hostkey_size, ca_size = alg_sizes[algorithm] hostkey_size, ca_size = alg_sizes[algorithm]
entry['keysize'] = hostkey_size entry['keysize'] = hostkey_size
if ca_size > 0: if ca_size > 0:
@ -3353,7 +3363,7 @@ def build_struct(banner, kex=None, pkm=None, client_host=None):
entry = { entry = {
'algorithm': algorithm, 'algorithm': algorithm,
} }
if (alg_sizes is not None) and (algorithm in alg_sizes): if algorithm in alg_sizes:
hostkey_size, ca_size = alg_sizes[algorithm] hostkey_size, ca_size = alg_sizes[algorithm]
entry['keysize'] = hostkey_size entry['keysize'] = hostkey_size
if ca_size > 0: if ca_size > 0:
@ -3387,12 +3397,20 @@ def build_struct(banner, kex=None, pkm=None, client_host=None):
} }
res['fingerprints'].append(entry) res['fingerprints'].append(entry)
else: else:
pkm_supported_ciphers = None
pkm_supported_authentications = None
pkm_fp = None
if pkm is not None:
pkm_supported_ciphers = pkm.supported_ciphers
pkm_supported_authentications = pkm.supported_authentications
pkm_fp = SSH.Fingerprint(pkm.host_key_fingerprint_data).sha256
res['key'] = ['ssh-rsa1'] res['key'] = ['ssh-rsa1']
res['enc'] = pkm.supported_ciphers res['enc'] = pkm_supported_ciphers
res['aut'] = pkm.supported_authentications res['aut'] = pkm_supported_authentications
res['fingerprints'] = [{ res['fingerprints'] = [{
'type': 'ssh-rsa1', 'type': 'ssh-rsa1',
'fp': SSH.Fingerprint(pkm.host_key_fingerprint_data).sha256, 'fp': pkm_fp,
}] }]
return res return res
@ -3421,7 +3439,7 @@ def audit(aconf: AuditConf, sshv: Optional[int] = None) -> int:
packet_type, payload = s.read_packet(sshv) packet_type, payload = s.read_packet(sshv)
if packet_type < 0: if packet_type < 0:
try: try:
if payload is not None and len(payload) > 0: if len(payload) > 0:
payload_txt = payload.decode('utf-8') payload_txt = payload.decode('utf-8')
else: else:
payload_txt = u'empty' payload_txt = u'empty'

11
tox.ini
View File

@ -90,14 +90,19 @@ commands =
[mypy] [mypy]
ignore_missing_imports = False ignore_missing_imports = False
follow_imports = normal follow_imports = normal
; disallow_untyped_calls = True disallow_incomplete_defs = True
; disallow_untyped_defs = True disallow_untyped_calls = True
; check_untyped_defs = True disallow_untyped_decorators = True
disallow_untyped_defs = True
check_untyped_defs = True
disallow_subclassing_any = True disallow_subclassing_any = True
warn_redundant_casts = True warn_redundant_casts = True
warn_return_any = True warn_return_any = True
warn_unreachable = True
warn_unused_ignores = True warn_unused_ignores = True
strict_optional = True strict_optional = True
strict_equality = True
strict = True
[pylint] [pylint]
reports = no reports = no