mirror of
				https://github.com/jtesta/ssh-audit.git
				synced 2025-10-30 21:15:27 +01:00 
			
		
		
		
	Refactored IPv4/6 preference logic to fix pylint warnings.
This commit is contained in:
		| @@ -43,7 +43,7 @@ class AuditConf: | |||||||
|         self.json = False |         self.json = False | ||||||
|         self.verbose = False |         self.verbose = False | ||||||
|         self.level = 'info' |         self.level = 'info' | ||||||
|         self.ipvo: Sequence[int] = () |         self.ip_version_preference: List[int] = []  # Holds only 5 possible values: [] (no preference), [4] (use IPv4 only), [6] (use IPv6 only), [46] (use both IPv4 and IPv6, but prioritize v4), and [64] (use both IPv4 and IPv6, but prioritize v6). | ||||||
|         self.ipv4 = False |         self.ipv4 = False | ||||||
|         self.ipv6 = False |         self.ipv6 = False | ||||||
|         self.make_policy = False  # When True, creates a policy file from an audit scan. |         self.make_policy = False  # When True, creates a policy file from an audit scan. | ||||||
| @@ -60,28 +60,14 @@ class AuditConf: | |||||||
|  |  | ||||||
|     def __setattr__(self, name: str, value: Union[str, int, float, bool, Sequence[int]]) -> None: |     def __setattr__(self, name: str, value: Union[str, int, float, bool, Sequence[int]]) -> None: | ||||||
|         valid = False |         valid = False | ||||||
|         if name in ['ssh1', 'ssh2', 'batch', 'client_audit', 'colors', 'verbose', 'timeout_set', 'json', 'make_policy', 'list_policies', 'manual']: |         if name in ['batch', 'client_audit', 'colors', 'json', 'list_policies', 'manual', 'make_policy', 'ssh1', 'ssh2', 'timeout_set', 'verbose']: | ||||||
|             valid, value = True, bool(value) |             valid, value = True, bool(value) | ||||||
|         elif name in ['ipv4', 'ipv6']: |         elif name in ['ipv4', 'ipv6']: | ||||||
|  |             valid, value = True, bool(value) | ||||||
|  |             if len(self.ip_version_preference) == 2:  # Being called more than twice is not valid. | ||||||
|                 valid = False |                 valid = False | ||||||
|             value = bool(value) |             elif value: | ||||||
|             ipv = 4 if name == 'ipv4' else 6 |                 self.ip_version_preference.append(4 if name == 'ipv4' else 6) | ||||||
|             if value: |  | ||||||
|                 value = tuple(list(self.ipvo) + [ipv]) |  | ||||||
|             else:  # pylint: disable=else-if-used |  | ||||||
|                 if len(self.ipvo) == 0: |  | ||||||
|                     value = (6,) if ipv == 4 else (4,) |  | ||||||
|                 else: |  | ||||||
|                     value = tuple([x for x in self.ipvo if x != ipv]) |  | ||||||
|             self.__setattr__('ipvo', value) |  | ||||||
|         elif name == 'ipvo': |  | ||||||
|             if isinstance(value, (tuple, list)): |  | ||||||
|                 uniq_value = Utils.unique_seq(value) |  | ||||||
|                 value = tuple([x for x in uniq_value if x in (4, 6)]) |  | ||||||
|                 valid = True |  | ||||||
|                 ipv_both = len(value) == 0 |  | ||||||
|                 object.__setattr__(self, 'ipv4', ipv_both or 4 in value) |  | ||||||
|                 object.__setattr__(self, 'ipv6', ipv_both or 6 in value) |  | ||||||
|         elif name == 'port': |         elif name == 'port': | ||||||
|             valid, port = True, Utils.parse_int(value) |             valid, port = True, Utils.parse_int(value) | ||||||
|             if port < 1 or port > 65535: |             if port < 1 or port > 65535: | ||||||
| @@ -98,7 +84,7 @@ class AuditConf: | |||||||
|             if value == -1.0: |             if value == -1.0: | ||||||
|                 raise ValueError('invalid timeout: {}'.format(value)) |                 raise ValueError('invalid timeout: {}'.format(value)) | ||||||
|             valid = True |             valid = True | ||||||
|         elif name in ['policy_file', 'policy', 'target_file', 'target_list', 'lookup']: |         elif name in ['ip_version_preference', 'lookup', 'policy_file', 'policy', 'target_file', 'target_list']: | ||||||
|             valid = True |             valid = True | ||||||
|         elif name == "threads": |         elif name == "threads": | ||||||
|             valid, num_threads = True, Utils.parse_int(value) |             valid, num_threads = True, Utils.parse_int(value) | ||||||
|   | |||||||
| @@ -815,7 +815,7 @@ def audit(out: OutputBuffer, aconf: AuditConf, sshv: Optional[int] = None, print | |||||||
|     out.verbose = aconf.verbose |     out.verbose = aconf.verbose | ||||||
|     out.level = aconf.level |     out.level = aconf.level | ||||||
|     out.use_colors = aconf.colors |     out.use_colors = aconf.colors | ||||||
|     s = SSH_Socket(aconf.host, aconf.port, aconf.ipvo, aconf.timeout, aconf.timeout_set) |     s = SSH_Socket(aconf.host, aconf.port, aconf.ip_version_preference, aconf.timeout, aconf.timeout_set) | ||||||
|     if aconf.client_audit: |     if aconf.client_audit: | ||||||
|         out.v("Listening for client connection on port %d..." % aconf.port, write_now=True) |         out.v("Listening for client connection on port %d..." % aconf.port, write_now=True) | ||||||
|         s.listen_and_accept() |         s.listen_and_accept() | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| """ | """ | ||||||
|    The MIT License (MIT) |    The MIT License (MIT) | ||||||
|  |  | ||||||
|    Copyright (C) 2017-2020 Joe Testa (jtesta@positronsecurity.com) |    Copyright (C) 2017-2021 Joe Testa (jtesta@positronsecurity.com) | ||||||
|    Copyright (C) 2017 Andris Raugulis (moo@arthepsy.eu) |    Copyright (C) 2017 Andris Raugulis (moo@arthepsy.eu) | ||||||
|  |  | ||||||
|    Permission is hereby granted, free of charge, to any person obtaining a copy |    Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
| @@ -52,7 +52,7 @@ class SSH_Socket(ReadBuf, WriteBuf): | |||||||
|  |  | ||||||
|     SM_BANNER_SENT = 1 |     SM_BANNER_SENT = 1 | ||||||
|  |  | ||||||
|     def __init__(self, host: Optional[str], port: int, ipvo: Optional[Sequence[int]] = None, timeout: Union[int, float] = 5, timeout_set: bool = False) -> None: |     def __init__(self, host: Optional[str], port: int, ip_version_preference: List[int] = [], timeout: Union[int, float] = 5, timeout_set: bool = False) -> None:  # pylint: disable=dangerous-default-value | ||||||
|         super(SSH_Socket, self).__init__() |         super(SSH_Socket, self).__init__() | ||||||
|         self.__sock: Optional[socket.socket] = None |         self.__sock: Optional[socket.socket] = None | ||||||
|         self.__sock_map: Dict[int, socket.socket] = {} |         self.__sock_map: Dict[int, socket.socket] = {} | ||||||
| @@ -67,32 +67,27 @@ class SSH_Socket(ReadBuf, WriteBuf): | |||||||
|             raise ValueError('invalid port: {}'.format(port)) |             raise ValueError('invalid port: {}'.format(port)) | ||||||
|         self.__host = host |         self.__host = host | ||||||
|         self.__port = nport |         self.__port = nport | ||||||
|         if ipvo is not None: |         self.__ip_version_preference = ip_version_preference  # Holds only 5 possible values: [] (no preference), [4] (use IPv4 only), [6] (use IPv6 only), [46] (use both IPv4 and IPv6, but prioritize v4), and [64] (use both IPv4 and IPv6, but prioritize v6). | ||||||
|             self.__ipvo = ipvo |  | ||||||
|         else: |  | ||||||
|             self.__ipvo = () |  | ||||||
|         self.__timeout = timeout |         self.__timeout = timeout | ||||||
|         self.__timeout_set = timeout_set |         self.__timeout_set = timeout_set | ||||||
|         self.client_host: Optional[str] = None |         self.client_host: Optional[str] = None | ||||||
|         self.client_port = None |         self.client_port = None | ||||||
|  |  | ||||||
|     def _resolve(self, ipvo: Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]]: |     def _resolve(self) -> Iterable[Tuple[int, Tuple[Any, ...]]]: | ||||||
|         ipvo = tuple([x for x in Utils.unique_seq(ipvo) if x in (4, 6)]) |         # If __ip_version_preference has only one entry, then it means that ONLY that IP version should be used. | ||||||
|         ipvo_len = len(ipvo) |         if len(self.__ip_version_preference) == 1: | ||||||
|         prefer_ipvo = ipvo_len > 0 |             family = socket.AF_INET if self.__ip_version_preference[0] == 4 else socket.AF_INET6 | ||||||
|         prefer_ipv4 = prefer_ipvo and ipvo[0] == 4 |  | ||||||
|         if ipvo_len == 1: |  | ||||||
|             family = socket.AF_INET if ipvo[0] == 4 else socket.AF_INET6 |  | ||||||
|         else: |         else: | ||||||
|             family = socket.AF_UNSPEC |             family = socket.AF_UNSPEC | ||||||
|         try: |         try: | ||||||
|             stype = socket.SOCK_STREAM |             stype = socket.SOCK_STREAM | ||||||
|             r = socket.getaddrinfo(self.__host, self.__port, family, stype) |             r = socket.getaddrinfo(self.__host, self.__port, family, stype) | ||||||
|             if prefer_ipvo: |  | ||||||
|                 r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4) |             # If the user has a preference for using IPv4 over IPv6 (or vice-versa), then sort the list returned by getaddrinfo() so that the preferred address type comes first. | ||||||
|             check = any(stype == rline[2] for rline in r) |             if len(self.__ip_version_preference) == 2: | ||||||
|  |                 r = sorted(r, key=lambda x: x[0], reverse=(self.__ip_version_preference[0] == 6)) | ||||||
|             for af, socktype, _proto, _canonname, addr in r: |             for af, socktype, _proto, _canonname, addr in r: | ||||||
|                 if not check or socktype == socket.SOCK_STREAM: |                 if socktype == socket.SOCK_STREAM: | ||||||
|                     yield af, addr |                     yield af, addr | ||||||
|         except socket.error as e: |         except socket.error as e: | ||||||
|             OutputBuffer().fail('[exception] {}'.format(e)).write() |             OutputBuffer().fail('[exception] {}'.format(e)).write() | ||||||
| @@ -156,7 +151,7 @@ class SSH_Socket(ReadBuf, WriteBuf): | |||||||
|     def connect(self) -> Optional[str]: |     def connect(self) -> Optional[str]: | ||||||
|         '''Returns None on success, or an error string.''' |         '''Returns None on success, or an error string.''' | ||||||
|         err = None |         err = None | ||||||
|         for af, addr in self._resolve(self.__ipvo): |         for af, addr in self._resolve(): | ||||||
|             s = None |             s = None | ||||||
|             try: |             try: | ||||||
|                 s = socket.socket(af, socket.SOCK_STREAM) |                 s = socket.socket(af, socket.SOCK_STREAM) | ||||||
|   | |||||||
| @@ -22,9 +22,8 @@ class TestAuditConf: | |||||||
|             'colors': True, |             'colors': True, | ||||||
|             'verbose': False, |             'verbose': False, | ||||||
|             'level': 'info', |             'level': 'info', | ||||||
|             'ipv4': True, |             'ipv4': False, | ||||||
|             'ipv6': True, |             'ipv6': False | ||||||
|             'ipvo': () |  | ||||||
|         } |         } | ||||||
|         for k, v in kwargs.items(): |         for k, v in kwargs.items(): | ||||||
|             options[k] = v |             options[k] = v | ||||||
| @@ -38,7 +37,6 @@ class TestAuditConf: | |||||||
|         assert conf.level == options['level'] |         assert conf.level == options['level'] | ||||||
|         assert conf.ipv4 == options['ipv4'] |         assert conf.ipv4 == options['ipv4'] | ||||||
|         assert conf.ipv6 == options['ipv6'] |         assert conf.ipv6 == options['ipv6'] | ||||||
|         assert conf.ipvo == options['ipvo'] |  | ||||||
|  |  | ||||||
|     def test_audit_conf_defaults(self): |     def test_audit_conf_defaults(self): | ||||||
|         conf = self.AuditConf() |         conf = self.AuditConf() | ||||||
| @@ -64,57 +62,38 @@ class TestAuditConf: | |||||||
|                 conf.port = port |                 conf.port = port | ||||||
|             excinfo.match(r'.*invalid port.*') |             excinfo.match(r'.*invalid port.*') | ||||||
|  |  | ||||||
|     def test_audit_conf_ipvo(self): |     def test_audit_conf_ip_version_preference(self): | ||||||
|         # ipv4-only |         # ipv4-only | ||||||
|         conf = self.AuditConf() |         conf = self.AuditConf() | ||||||
|         conf.ipv4 = True |         conf.ipv4 = True | ||||||
|         assert conf.ipv4 is True |         assert conf.ipv4 is True | ||||||
|         assert conf.ipv6 is False |         assert conf.ipv6 is False | ||||||
|         assert conf.ipvo == (4,) |         assert conf.ip_version_preference == [4] | ||||||
|         # ipv6-only |         # ipv6-only | ||||||
|         conf = self.AuditConf() |         conf = self.AuditConf() | ||||||
|         conf.ipv6 = True |         conf.ipv6 = True | ||||||
|         assert conf.ipv4 is False |         assert conf.ipv4 is False | ||||||
|         assert conf.ipv6 is True |         assert conf.ipv6 is True | ||||||
|         assert conf.ipvo == (6,) |         assert conf.ip_version_preference == [6] | ||||||
|         # ipv4-only (by removing ipv6) |  | ||||||
|         conf = self.AuditConf() |  | ||||||
|         conf.ipv6 = False |  | ||||||
|         assert conf.ipv4 is True |  | ||||||
|         assert conf.ipv6 is False |  | ||||||
|         assert conf.ipvo == (4, ) |  | ||||||
|         # ipv6-only (by removing ipv4) |  | ||||||
|         conf = self.AuditConf() |  | ||||||
|         conf.ipv4 = False |  | ||||||
|         assert conf.ipv4 is False |  | ||||||
|         assert conf.ipv6 is True |  | ||||||
|         assert conf.ipvo == (6, ) |  | ||||||
|         # ipv4-preferred |         # ipv4-preferred | ||||||
|         conf = self.AuditConf() |         conf = self.AuditConf() | ||||||
|         conf.ipv4 = True |         conf.ipv4 = True | ||||||
|         conf.ipv6 = True |         conf.ipv6 = True | ||||||
|         assert conf.ipv4 is True |         assert conf.ipv4 is True | ||||||
|         assert conf.ipv6 is True |         assert conf.ipv6 is True | ||||||
|         assert conf.ipvo == (4, 6) |         assert conf.ip_version_preference == [4, 6] | ||||||
|         # ipv6-preferred |         # ipv6-preferred | ||||||
|         conf = self.AuditConf() |         conf = self.AuditConf() | ||||||
|         conf.ipv6 = True |         conf.ipv6 = True | ||||||
|         conf.ipv4 = True |         conf.ipv4 = True | ||||||
|         assert conf.ipv4 is True |         assert conf.ipv4 is True | ||||||
|         assert conf.ipv6 is True |         assert conf.ipv6 is True | ||||||
|         assert conf.ipvo == (6, 4) |         assert conf.ip_version_preference == [6, 4] | ||||||
|         # ipvo empty |         # defaults | ||||||
|         conf = self.AuditConf() |         conf = self.AuditConf() | ||||||
|         conf.ipvo = () |         assert conf.ipv4 is False | ||||||
|         assert conf.ipv4 is True |         assert conf.ipv6 is False | ||||||
|         assert conf.ipv6 is True |         assert conf.ip_version_preference == [] | ||||||
|         assert conf.ipvo == () |  | ||||||
|         # ipvo validation |  | ||||||
|         conf = self.AuditConf() |  | ||||||
|         conf.ipvo = (1, 2, 3, 4, 5, 6) |  | ||||||
|         assert conf.ipvo == (4, 6) |  | ||||||
|         conf.ipvo = (4, 4, 4, 6, 6) |  | ||||||
|         assert conf.ipvo == (4, 6) |  | ||||||
|  |  | ||||||
|     def test_audit_conf_level(self): |     def test_audit_conf_level(self): | ||||||
|         conf = self.AuditConf() |         conf = self.AuditConf() | ||||||
|   | |||||||
| @@ -19,11 +19,11 @@ class TestResolve: | |||||||
|     def test_resolve_error(self, output_spy, virtual_socket): |     def test_resolve_error(self, output_spy, virtual_socket): | ||||||
|         vsocket = virtual_socket |         vsocket = virtual_socket | ||||||
|         vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known') |         vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known') | ||||||
|         s = self.ssh_socket('localhost', 22) |  | ||||||
|         conf = self._conf() |         conf = self._conf() | ||||||
|  |         s = self.ssh_socket('localhost', 22, conf.ip_version_preference) | ||||||
|         output_spy.begin() |         output_spy.begin() | ||||||
|         with pytest.raises(SystemExit): |         with pytest.raises(SystemExit): | ||||||
|             list(s._resolve(conf.ipvo)) |             list(s._resolve()) | ||||||
|         lines = output_spy.flush() |         lines = output_spy.flush() | ||||||
|         assert len(lines) == 1 |         assert len(lines) == 1 | ||||||
|         assert 'hostname nor servname provided' in lines[-1] |         assert 'hostname nor servname provided' in lines[-1] | ||||||
| @@ -31,49 +31,50 @@ class TestResolve: | |||||||
|     def test_resolve_hostname_without_records(self, output_spy, virtual_socket): |     def test_resolve_hostname_without_records(self, output_spy, virtual_socket): | ||||||
|         vsocket = virtual_socket |         vsocket = virtual_socket | ||||||
|         vsocket.gsock.addrinfodata['localhost#22'] = [] |         vsocket.gsock.addrinfodata['localhost#22'] = [] | ||||||
|         s = self.ssh_socket('localhost', 22) |  | ||||||
|         conf = self._conf() |         conf = self._conf() | ||||||
|  |         s = self.ssh_socket('localhost', 22, conf.ip_version_preference) | ||||||
|         output_spy.begin() |         output_spy.begin() | ||||||
|         r = list(s._resolve(conf.ipvo)) |         r = list(s._resolve()) | ||||||
|         assert len(r) == 0 |         assert len(r) == 0 | ||||||
|  |  | ||||||
|     def test_resolve_ipv4(self, virtual_socket): |     def test_resolve_ipv4(self, virtual_socket): | ||||||
|         conf = self._conf() |         conf = self._conf() | ||||||
|         conf.ipv4 = True |         conf.ipv4 = True | ||||||
|         s = self.ssh_socket('localhost', 22) |         s = self.ssh_socket('localhost', 22, conf.ip_version_preference) | ||||||
|         r = list(s._resolve(conf.ipvo)) |         r = list(s._resolve()) | ||||||
|         assert len(r) == 1 |         assert len(r) == 1 | ||||||
|         assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) |         assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) | ||||||
|  |  | ||||||
|     def test_resolve_ipv6(self, virtual_socket): |     def test_resolve_ipv6(self, virtual_socket): | ||||||
|         s = self.ssh_socket('localhost', 22) |  | ||||||
|         conf = self._conf() |         conf = self._conf() | ||||||
|         conf.ipv6 = True |         conf.ipv6 = True | ||||||
|         r = list(s._resolve(conf.ipvo)) |         s = self.ssh_socket('localhost', 22, conf.ip_version_preference) | ||||||
|  |         r = list(s._resolve()) | ||||||
|         assert len(r) == 1 |         assert len(r) == 1 | ||||||
|         assert r[0] == (socket.AF_INET6, ('::1', 22)) |         assert r[0] == (socket.AF_INET6, ('::1', 22)) | ||||||
|  |  | ||||||
|     def test_resolve_ipv46_both(self, virtual_socket): |     def test_resolve_ipv46_both(self, virtual_socket): | ||||||
|         s = self.ssh_socket('localhost', 22) |  | ||||||
|         conf = self._conf() |         conf = self._conf() | ||||||
|         r = list(s._resolve(conf.ipvo)) |         s = self.ssh_socket('localhost', 22, conf.ip_version_preference) | ||||||
|  |         r = list(s._resolve()) | ||||||
|         assert len(r) == 2 |         assert len(r) == 2 | ||||||
|         assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) |         assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) | ||||||
|         assert r[1] == (socket.AF_INET6, ('::1', 22)) |         assert r[1] == (socket.AF_INET6, ('::1', 22)) | ||||||
|  |  | ||||||
|     def test_resolve_ipv46_order(self, virtual_socket): |     def test_resolve_ipv46_order(self, virtual_socket): | ||||||
|         s = self.ssh_socket('localhost', 22) |  | ||||||
|         conf = self._conf() |         conf = self._conf() | ||||||
|         conf.ipv4 = True |         conf.ipv4 = True | ||||||
|         conf.ipv6 = True |         conf.ipv6 = True | ||||||
|         r = list(s._resolve(conf.ipvo)) |         s = self.ssh_socket('localhost', 22, conf.ip_version_preference) | ||||||
|  |         r = list(s._resolve()) | ||||||
|         assert len(r) == 2 |         assert len(r) == 2 | ||||||
|         assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) |         assert r[0] == (socket.AF_INET, ('127.0.0.1', 22)) | ||||||
|         assert r[1] == (socket.AF_INET6, ('::1', 22)) |         assert r[1] == (socket.AF_INET6, ('::1', 22)) | ||||||
|         conf = self._conf() |         conf = self._conf() | ||||||
|         conf.ipv6 = True |         conf.ipv6 = True | ||||||
|         conf.ipv4 = True |         conf.ipv4 = True | ||||||
|         r = list(s._resolve(conf.ipvo)) |         s = self.ssh_socket('localhost', 22, conf.ip_version_preference) | ||||||
|  |         r = list(s._resolve()) | ||||||
|         assert len(r) == 2 |         assert len(r) == 2 | ||||||
|         assert r[0] == (socket.AF_INET6, ('::1', 22)) |         assert r[0] == (socket.AF_INET6, ('::1', 22)) | ||||||
|         assert r[1] == (socket.AF_INET, ('127.0.0.1', 22)) |         assert r[1] == (socket.AF_INET, ('127.0.0.1', 22)) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Joe Testa
					Joe Testa