mirror of
				https://github.com/jtesta/ssh-audit.git
				synced 2025-10-30 21:15:27 +01:00 
			
		
		
		
	Implement new options (-4/--ipv4, -6/--ipv6, -p/--port <port>).
By default both IPv4 and IPv6 is supported and order of precedence depends on OS. By using -46, IPv4 is prefered, but by using -64, IPv6 is preferd. For now the old way how to specify port (host:port) has been kept intact.
This commit is contained in:
		
							
								
								
									
										163
									
								
								ssh-audit.py
									
									
									
									
									
								
							
							
						
						
									
										163
									
								
								ssh-audit.py
									
									
									
									
									
								
							| @@ -28,21 +28,22 @@ import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 | |||||||
|  |  | ||||||
| VERSION = 'v1.6.1.dev' | VERSION = 'v1.6.1.dev' | ||||||
|  |  | ||||||
| if sys.version_info >= (3,): | if sys.version_info >= (3,):  # pragma: nocover | ||||||
| 	StringIO, BytesIO = io.StringIO, io.BytesIO | 	StringIO, BytesIO = io.StringIO, io.BytesIO | ||||||
| 	text_type = str | 	text_type = str | ||||||
| 	binary_type = bytes | 	binary_type = bytes | ||||||
| else: | else:  # pragma: nocover | ||||||
| 	import StringIO as _StringIO  # pylint: disable=import-error | 	import StringIO as _StringIO  # pylint: disable=import-error | ||||||
| 	StringIO = BytesIO = _StringIO.StringIO | 	StringIO = BytesIO = _StringIO.StringIO | ||||||
| 	text_type = unicode  # pylint: disable=undefined-variable | 	text_type = unicode  # pylint: disable=undefined-variable | ||||||
| 	binary_type = str | 	binary_type = str | ||||||
| try: | try:  # pragma: nocover | ||||||
| 	# pylint: disable=unused-import | 	# pylint: disable=unused-import | ||||||
| 	from typing import List, Tuple, Optional, Callable, Union, Any | 	from typing import List, Set, Sequence, Tuple, Iterable | ||||||
|  | 	from typing import Callable, Optional, Union, Any | ||||||
| except ImportError:  # pragma: nocover | except ImportError:  # pragma: nocover | ||||||
| 	pass | 	pass | ||||||
| try: | try:  # pragma: nocover | ||||||
| 	from colorama import init as colorama_init | 	from colorama import init as colorama_init | ||||||
| 	colorama_init()  # pragma: nocover | 	colorama_init()  # pragma: nocover | ||||||
| except ImportError:  # pragma: nocover | except ImportError:  # pragma: nocover | ||||||
| @@ -53,13 +54,16 @@ def usage(err=None): | |||||||
| 	# type: (Optional[str]) -> None | 	# type: (Optional[str]) -> None | ||||||
| 	uout = Output() | 	uout = Output() | ||||||
| 	p = os.path.basename(sys.argv[0]) | 	p = os.path.basename(sys.argv[0]) | ||||||
| 	uout.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION)) | 	uout.head('# {0} {1}, moo@arthepsy.eu\n'.format(p, VERSION)) | ||||||
| 	if err is not None: | 	if err is not None: | ||||||
| 		uout.fail('\n' + err) | 		uout.fail('\n' + err) | ||||||
| 	uout.info('\nusage: {0} [-12bnv] [-l <level>] <host[:port]>\n'.format(p)) | 	uout.info('usage: {0} [-1246pbnvl] <host>\n'.format(p)) | ||||||
| 	uout.info('   -h,  --help             print this help') | 	uout.info('   -h,  --help             print this help') | ||||||
| 	uout.info('   -1,  --ssh1             force ssh version 1 only') | 	uout.info('   -1,  --ssh1             force ssh version 1 only') | ||||||
| 	uout.info('   -2,  --ssh2             force ssh version 2 only') | 	uout.info('   -2,  --ssh2             force ssh version 2 only') | ||||||
|  | 	uout.info('   -4,  --ipv4             enable IPv4 (order of precedence)') | ||||||
|  | 	uout.info('   -6,  --ipv6             enable IPv6 (order of precedence)') | ||||||
|  | 	uout.info('   -p,  --port=<port>      port to connect') | ||||||
| 	uout.info('   -b,  --batch            batch output') | 	uout.info('   -b,  --batch            batch output') | ||||||
| 	uout.info('   -n,  --no-colors        disable colors') | 	uout.info('   -n,  --no-colors        disable colors') | ||||||
| 	uout.info('   -v,  --verbose          verbose output') | 	uout.info('   -v,  --verbose          verbose output') | ||||||
| @@ -69,6 +73,7 @@ def usage(err=None): | |||||||
|  |  | ||||||
|  |  | ||||||
| class AuditConf(object): | class AuditConf(object): | ||||||
|  | 	# pylint: disable=too-many-instance-attributes | ||||||
| 	def __init__(self, host=None, port=22): | 	def __init__(self, host=None, port=22): | ||||||
| 		# type: (Optional[str], int) -> None | 		# type: (Optional[str], int) -> None | ||||||
| 		self.host = host | 		self.host = host | ||||||
| @@ -79,12 +84,35 @@ class AuditConf(object): | |||||||
| 		self.colors = True | 		self.colors = True | ||||||
| 		self.verbose = False | 		self.verbose = False | ||||||
| 		self.minlevel = 'info' | 		self.minlevel = 'info' | ||||||
|  | 		self.ipvo = ()  # type: Sequence[int] | ||||||
|  | 		self.ipv4 = False | ||||||
|  | 		self.ipv6 = False | ||||||
| 	 | 	 | ||||||
| 	def __setattr__(self, name, value): | 	def __setattr__(self, name, value): | ||||||
| 		# type: (str, Union[str, int, bool]) -> None | 		# type: (str, Union[str, int, bool, Sequence[int]]) -> None | ||||||
| 		valid = False | 		valid = False | ||||||
| 		if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: | 		if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']: | ||||||
| 			valid, value = True, True if value else False | 			valid, value = True, True if value else False | ||||||
|  | 		elif name in ['ipv4', 'ipv6']: | ||||||
|  | 			valid = False | ||||||
|  | 			value = True if value else False | ||||||
|  | 			ipv = 4 if name == 'ipv4' else 6 | ||||||
|  | 			if value: | ||||||
|  | 				value = tuple(list(self.ipvo) + [ipv]) | ||||||
|  | 			else: | ||||||
|  | 				if len(self.ipvo) == 0: | ||||||
|  | 					value = (6,) if ipv == 4 else (4,) | ||||||
|  | 				else: | ||||||
|  | 					value = tuple(filter(lambda x: x != ipv, self.ipvo)) | ||||||
|  | 			self.__setattr__('ipvo', value) | ||||||
|  | 		elif name == 'ipvo': | ||||||
|  | 			if isinstance(value, (tuple, list)): | ||||||
|  | 				uniq_value = utils.unique_seq(value) | ||||||
|  | 				value = tuple(filter(lambda x: x in (4, 6), uniq_value)) | ||||||
|  | 				valid = True | ||||||
|  | 				ipv_both = len(value) == 0 | ||||||
|  | 				object.__setattr__(self, 'ipv4', ipv_both or 4 in value) | ||||||
|  | 				object.__setattr__(self, 'ipv6', ipv_both or 6 in value) | ||||||
| 		elif name == 'port': | 		elif name == 'port': | ||||||
| 			valid, port = True, utils.parse_int(value) | 			valid, port = True, utils.parse_int(value) | ||||||
| 			if port < 1 or port > 65535: | 			if port < 1 or port > 65535: | ||||||
| @@ -105,13 +133,14 @@ class AuditConf(object): | |||||||
| 		# pylint: disable=too-many-branches | 		# pylint: disable=too-many-branches | ||||||
| 		aconf = cls() | 		aconf = cls() | ||||||
| 		try: | 		try: | ||||||
| 			sopts = 'h12bnvl:' | 			sopts = 'h1246p:bnvl:' | ||||||
| 			lopts = ['help', 'ssh1', 'ssh2', 'batch', | 			lopts = ['help', 'ssh1', 'ssh2', 'ipv4', 'ipv6', 'port', | ||||||
| 			         'no-colors', 'verbose', 'level='] | 			         'batch', 'no-colors', 'verbose', 'level='] | ||||||
| 			opts, args = getopt.getopt(args, sopts, lopts) | 			opts, args = getopt.getopt(args, sopts, lopts) | ||||||
| 		except getopt.GetoptError as err: | 		except getopt.GetoptError as err: | ||||||
| 			usage_cb(str(err)) | 			usage_cb(str(err)) | ||||||
| 		aconf.ssh1, aconf.ssh2 = False, False | 		aconf.ssh1, aconf.ssh2 = False, False | ||||||
|  | 		oport = None | ||||||
| 		for o, a in opts: | 		for o, a in opts: | ||||||
| 			if o in ('-h', '--help'): | 			if o in ('-h', '--help'): | ||||||
| 				usage_cb() | 				usage_cb() | ||||||
| @@ -119,6 +148,12 @@ class AuditConf(object): | |||||||
| 				aconf.ssh1 = True | 				aconf.ssh1 = True | ||||||
| 			elif o in ('-2', '--ssh2'): | 			elif o in ('-2', '--ssh2'): | ||||||
| 				aconf.ssh2 = True | 				aconf.ssh2 = True | ||||||
|  | 			elif o in ('-4', '--ipv4'): | ||||||
|  | 				aconf.ipv4 = True | ||||||
|  | 			elif o in ('-6', '--ipv6'): | ||||||
|  | 				aconf.ipv6 = True | ||||||
|  | 			elif o in ('-p', '--port'): | ||||||
|  | 				oport = a | ||||||
| 			elif o in ('-b', '--batch'): | 			elif o in ('-b', '--batch'): | ||||||
| 				aconf.batch = True | 				aconf.batch = True | ||||||
| 				aconf.verbose = True | 				aconf.verbose = True | ||||||
| @@ -132,14 +167,20 @@ class AuditConf(object): | |||||||
| 				aconf.minlevel = a | 				aconf.minlevel = a | ||||||
| 		if len(args) == 0: | 		if len(args) == 0: | ||||||
| 			usage_cb() | 			usage_cb() | ||||||
|  | 		if oport is not None: | ||||||
|  | 			host = args[0] | ||||||
|  | 			port = utils.parse_int(oport) | ||||||
|  | 		else: | ||||||
| 			s = args[0].split(':') | 			s = args[0].split(':') | ||||||
| 		host, port = s[0].strip(), 22 | 			host = s[0].strip() | ||||||
| 		if len(s) > 1: | 			if len(s) == 2: | ||||||
| 			port = utils.parse_int(s[1]) | 				oport, port = s[1], utils.parse_int(s[1]) | ||||||
|  | 			else: | ||||||
|  | 				oport, port = '22', 22 | ||||||
| 		if not host: | 		if not host: | ||||||
| 			usage_cb('host is empty') | 			usage_cb('host is empty') | ||||||
| 		if port <= 0 or port > 65535: | 		if port <= 0 or port > 65535: | ||||||
| 			usage_cb('port {0} is not valid'.format(s[1])) | 			usage_cb('port {0} is not valid'.format(oport)) | ||||||
| 		aconf.host = host | 		aconf.host = host | ||||||
| 		aconf.port = port | 		aconf.port = port | ||||||
| 		if not (aconf.ssh1 or aconf.ssh2): | 		if not (aconf.ssh1 or aconf.ssh2): | ||||||
| @@ -1038,24 +1079,67 @@ class SSH(object):  # pylint: disable=too-few-public-methods | |||||||
| 		 | 		 | ||||||
| 		SM_BANNER_SENT = 1 | 		SM_BANNER_SENT = 1 | ||||||
| 		 | 		 | ||||||
| 		def __init__(self, host, port, cto=3.0, rto=5.0): | 		def __init__(self, host, port): | ||||||
| 			# type: (str, int, float, float) -> None | 			# type: (str, int) -> None | ||||||
|  | 			super(SSH.Socket, self).__init__() | ||||||
| 			self.__block_size = 8 | 			self.__block_size = 8 | ||||||
| 			self.__state = 0 | 			self.__state = 0 | ||||||
| 			self.__header = []  # type: List[text_type] | 			self.__header = []  # type: List[text_type] | ||||||
| 			self.__banner = None  # type: Optional[SSH.Banner] | 			self.__banner = None  # type: Optional[SSH.Banner] | ||||||
| 			super(SSH.Socket, self).__init__() | 			self.__host = host | ||||||
| 			try: | 			self.__port = port | ||||||
| 				self.__sock = socket.create_connection((host, port), cto) | 			self.__sock = None  # type: socket.socket | ||||||
| 				self.__sock.settimeout(rto) |  | ||||||
| 			except Exception as e:  # pylint: disable=broad-except |  | ||||||
| 				out.fail('[fail] {0}'.format(e)) |  | ||||||
| 				sys.exit(1) |  | ||||||
| 		 | 		 | ||||||
| 		def __enter__(self): | 		def __enter__(self): | ||||||
| 			# type: () -> SSH.Socket | 			# type: () -> SSH.Socket | ||||||
| 			return self | 			return self | ||||||
| 		 | 		 | ||||||
|  | 		def _resolve(self, ipvo): | ||||||
|  | 			# type: (Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]] | ||||||
|  | 			ipvo = tuple(filter(lambda x: x in (4, 6), utils.unique_seq(ipvo))) | ||||||
|  | 			ipvo_len = len(ipvo) | ||||||
|  | 			prefer_ipvo = ipvo_len > 0 | ||||||
|  | 			prefer_ipv4 = prefer_ipvo and ipvo[0] == 4 | ||||||
|  | 			if len(ipvo) == 1: | ||||||
|  | 				family = {4: socket.AF_INET, 6: socket.AF_INET6}.get(ipvo[0]) | ||||||
|  | 			else: | ||||||
|  | 				family = socket.AF_UNSPEC | ||||||
|  | 			try: | ||||||
|  | 				stype = socket.SOCK_STREAM | ||||||
|  | 				r = socket.getaddrinfo(self.__host, self.__port, family, stype) | ||||||
|  | 				if prefer_ipvo: | ||||||
|  | 					r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4) | ||||||
|  | 				check = any(stype == rline[2] for rline in r) | ||||||
|  | 				for (af, socktype, proto, canonname, addr) in r: | ||||||
|  | 					if not check or socktype == socket.SOCK_STREAM: | ||||||
|  | 						yield (af, addr) | ||||||
|  | 			except socket.error as e: | ||||||
|  | 				out.fail('[exception] {0}'.format(e)) | ||||||
|  | 				sys.exit(1) | ||||||
|  | 		 | ||||||
|  | 		def connect(self, ipvo=(), cto=3.0, rto=5.0): | ||||||
|  | 			# type: (Sequence[int], float, float) -> None | ||||||
|  | 			err = None | ||||||
|  | 			for (af, addr) in self._resolve(ipvo): | ||||||
|  | 				s = None | ||||||
|  | 				try: | ||||||
|  | 					s = socket.socket(af, socket.SOCK_STREAM) | ||||||
|  | 					s.settimeout(cto) | ||||||
|  | 					s.connect(addr) | ||||||
|  | 					s.settimeout(rto) | ||||||
|  | 					self.__sock = s | ||||||
|  | 					return | ||||||
|  | 				except socket.error as e: | ||||||
|  | 					err = e | ||||||
|  | 					self._close_socket(s) | ||||||
|  | 			if err is None: | ||||||
|  | 				errm = 'host {0} has no DNS records'.format(self.__host) | ||||||
|  | 			else: | ||||||
|  | 				errt = (self.__host, self.__port, err) | ||||||
|  | 				errm = 'cannot connect to {0} port {1}: {2}'.format(*errt) | ||||||
|  | 			out.fail('[exception] {0}'.format(errm)) | ||||||
|  | 			sys.exit(1) | ||||||
|  | 		 | ||||||
| 		def get_banner(self, sshv=2): | 		def get_banner(self, sshv=2): | ||||||
| 			# type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] | 			# type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]] | ||||||
| 			banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') | 			banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0') | ||||||
| @@ -1188,6 +1272,15 @@ class SSH(object):  # pylint: disable=too-few-public-methods | |||||||
| 			data = struct.pack('>Ib', plen, padding) + payload + pad_bytes | 			data = struct.pack('>Ib', plen, padding) + payload + pad_bytes | ||||||
| 			return self.send(data) | 			return self.send(data) | ||||||
| 		 | 		 | ||||||
|  | 		def _close_socket(self, s): | ||||||
|  | 			# type: (Optional[socket.socket]) -> None | ||||||
|  | 			try: | ||||||
|  | 				if s is not None: | ||||||
|  | 					s.shutdown(socket.SHUT_RDWR) | ||||||
|  | 					s.close() | ||||||
|  | 			except:  # pylint: disable=bare-except | ||||||
|  | 				pass | ||||||
|  | 		 | ||||||
| 		def __del__(self): | 		def __del__(self): | ||||||
| 			# type: () -> None | 			# type: () -> None | ||||||
| 			self.__cleanup() | 			self.__cleanup() | ||||||
| @@ -1198,11 +1291,7 @@ class SSH(object):  # pylint: disable=too-few-public-methods | |||||||
| 		 | 		 | ||||||
| 		def __cleanup(self): | 		def __cleanup(self): | ||||||
| 			# type: () -> None | 			# type: () -> None | ||||||
| 			try: | 			self._close_socket(self.__sock) | ||||||
| 				self.__sock.shutdown(socket.SHUT_RDWR) |  | ||||||
| 				self.__sock.close() |  | ||||||
| 			except:  # pylint: disable=bare-except |  | ||||||
| 				pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class KexDH(object): | class KexDH(object): | ||||||
| @@ -1847,6 +1936,21 @@ class Utils(object): | |||||||
| 			return cls.to_ntext(v.encode('ascii', errors)) | 			return cls.to_ntext(v.encode('ascii', errors)) | ||||||
| 		raise cls._type_err(v, 'ascii') | 		raise cls._type_err(v, 'ascii') | ||||||
| 	 | 	 | ||||||
|  | 	@classmethod | ||||||
|  | 	def unique_seq(cls, seq): | ||||||
|  | 		# type: (Sequence[Any]) -> Sequence[Any] | ||||||
|  | 		seen = set()  # type: Set[Any] | ||||||
|  | 		 | ||||||
|  | 		def _seen_add(x): | ||||||
|  | 			# type: (Any) -> bool | ||||||
|  | 			seen.add(x) | ||||||
|  | 			return False | ||||||
|  | 		 | ||||||
|  | 		if isinstance(seq, tuple): | ||||||
|  | 			return tuple(x for x in seq if x not in seen and not _seen_add(x)) | ||||||
|  | 		else: | ||||||
|  | 			return [x for x in seq if x not in seen and not _seen_add(x)] | ||||||
|  | 		 | ||||||
| 	@staticmethod | 	@staticmethod | ||||||
| 	def parse_int(v): | 	def parse_int(v): | ||||||
| 		# type: (Any) -> int | 		# type: (Any) -> int | ||||||
| @@ -1863,6 +1967,7 @@ def audit(aconf, sshv=None): | |||||||
| 	out.verbose = aconf.verbose | 	out.verbose = aconf.verbose | ||||||
| 	out.minlevel = aconf.minlevel | 	out.minlevel = aconf.minlevel | ||||||
| 	s = SSH.Socket(aconf.host, aconf.port) | 	s = SSH.Socket(aconf.host, aconf.port) | ||||||
|  | 	s.connect(aconf.ipvo) | ||||||
| 	if sshv is None: | 	if sshv is None: | ||||||
| 		sshv = 2 if aconf.ssh2 else 1 | 		sshv = 2 if aconf.ssh2 else 1 | ||||||
| 	err = None | 	err = None | ||||||
|   | |||||||
| @@ -1,10 +1,14 @@ | |||||||
| #!/usr/bin/env python | #!/usr/bin/env python | ||||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||||
| import pytest, os, sys, io, socket | import os | ||||||
|  | import io | ||||||
|  | import sys | ||||||
|  | import socket | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  |  | ||||||
| if sys.version_info[0] == 2: | if sys.version_info[0] == 2: | ||||||
| 	import StringIO | 	import StringIO  # pylint: disable=import-error | ||||||
| 	StringIO = StringIO.StringIO | 	StringIO = StringIO.StringIO | ||||||
| else: | else: | ||||||
| 	StringIO = io.StringIO | 	StringIO = io.StringIO | ||||||
| @@ -17,6 +21,7 @@ def ssh_audit(): | |||||||
| 	return __import__('ssh-audit') | 	return __import__('ssh-audit') | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # pylint: disable=attribute-defined-outside-init | ||||||
| class _OutputSpy(list): | class _OutputSpy(list): | ||||||
| 	def begin(self): | 	def begin(self): | ||||||
| 		self.__out = StringIO() | 		self.__out = StringIO() | ||||||
| @@ -50,11 +55,14 @@ class _VirtualSocket(object): | |||||||
| 		if method_error: | 		if method_error: | ||||||
| 			raise method_error | 			raise method_error | ||||||
| 	 | 	 | ||||||
| 	def _connect(self, address): | 	def connect(self, address): | ||||||
|  | 		return self._connect(address, False) | ||||||
|  | 	 | ||||||
|  | 	def _connect(self, address, ret=True): | ||||||
| 		self.peer_address = address | 		self.peer_address = address | ||||||
| 		self._connected = True | 		self._connected = True | ||||||
| 		self._check_err('connect') | 		self._check_err('connect') | ||||||
| 		return self | 		return self if ret else None | ||||||
| 	 | 	 | ||||||
| 	def settimeout(self, timeout): | 	def settimeout(self, timeout): | ||||||
| 		self.timeout = timeout | 		self.timeout = timeout | ||||||
| @@ -77,6 +85,7 @@ class _VirtualSocket(object): | |||||||
| 		pass | 		pass | ||||||
| 	 | 	 | ||||||
| 	def accept(self): | 	def accept(self): | ||||||
|  | 		# pylint: disable=protected-access | ||||||
| 		conn = _VirtualSocket() | 		conn = _VirtualSocket() | ||||||
| 		conn.sock_address = self.sock_address | 		conn.sock_address = self.sock_address | ||||||
| 		conn.peer_address = ('127.0.0.1', 0) | 		conn.peer_address = ('127.0.0.1', 0) | ||||||
| @@ -84,6 +93,7 @@ class _VirtualSocket(object): | |||||||
| 		return conn, conn.peer_address | 		return conn, conn.peer_address | ||||||
| 	 | 	 | ||||||
| 	def recv(self, bufsize, flags=0): | 	def recv(self, bufsize, flags=0): | ||||||
|  | 		# pylint: disable=unused-argument | ||||||
| 		if not self._connected: | 		if not self._connected: | ||||||
| 			raise socket.error(54, 'Connection reset by peer') | 			raise socket.error(54, 'Connection reset by peer') | ||||||
| 		if not len(self.rdata) > 0: | 		if not len(self.rdata) > 0: | ||||||
| @@ -103,10 +113,18 @@ class _VirtualSocket(object): | |||||||
| @pytest.fixture() | @pytest.fixture() | ||||||
| def virtual_socket(monkeypatch): | def virtual_socket(monkeypatch): | ||||||
| 	vsocket = _VirtualSocket() | 	vsocket = _VirtualSocket() | ||||||
| 	def _c(address): | 	 | ||||||
| 		return vsocket._connect(address) | 	# pylint: disable=unused-argument | ||||||
| 	def _cc(address, timeout=0, source_address=None): | 	def _socket(family=socket.AF_INET, | ||||||
| 		return vsocket._connect(address) | 	            socktype=socket.SOCK_STREAM, | ||||||
| 	monkeypatch.setattr(socket, 'create_connection', _cc) | 	            proto=0, | ||||||
| 	monkeypatch.setattr(socket.socket, 'connect', _c) | 	            fileno=None): | ||||||
|  | 		return vsocket | ||||||
|  | 	 | ||||||
|  | 	def _cc(address, timeout=0, source_address=None): | ||||||
|  | 		# pylint: disable=protected-access | ||||||
|  | 		return vsocket._connect(address, True) | ||||||
|  | 	 | ||||||
|  | 	monkeypatch.setattr(socket, 'create_connection', _cc) | ||||||
|  | 	monkeypatch.setattr(socket, 'socket', _socket) | ||||||
| 	return vsocket | 	return vsocket | ||||||
|   | |||||||
| @@ -20,7 +20,10 @@ class TestAuditConf(object): | |||||||
| 			'batch': False, | 			'batch': False, | ||||||
| 			'colors': True, | 			'colors': True, | ||||||
| 			'verbose': False, | 			'verbose': False, | ||||||
| 			'minlevel': 'info' | 			'minlevel': 'info', | ||||||
|  | 			'ipv4': True, | ||||||
|  | 			'ipv6': True, | ||||||
|  | 			'ipvo': () | ||||||
| 		} | 		} | ||||||
| 		for k, v in kwargs.items(): | 		for k, v in kwargs.items(): | ||||||
| 			options[k] = v | 			options[k] = v | ||||||
| @@ -32,6 +35,9 @@ class TestAuditConf(object): | |||||||
| 		assert conf.colors is options['colors'] | 		assert conf.colors is options['colors'] | ||||||
| 		assert conf.verbose is options['verbose'] | 		assert conf.verbose is options['verbose'] | ||||||
| 		assert conf.minlevel == options['minlevel'] | 		assert conf.minlevel == options['minlevel'] | ||||||
|  | 		assert conf.ipv4 == options['ipv4'] | ||||||
|  | 		assert conf.ipv6 == options['ipv6'] | ||||||
|  | 		assert conf.ipvo == options['ipvo'] | ||||||
| 	 | 	 | ||||||
| 	def test_audit_conf_defaults(self): | 	def test_audit_conf_defaults(self): | ||||||
| 		conf = self.AuditConf() | 		conf = self.AuditConf() | ||||||
| @@ -57,6 +63,58 @@ class TestAuditConf(object): | |||||||
| 				conf.port = port | 				conf.port = port | ||||||
| 			excinfo.match(r'.*invalid port.*') | 			excinfo.match(r'.*invalid port.*') | ||||||
| 	 | 	 | ||||||
|  | 	def test_audit_conf_ipvo(self): | ||||||
|  | 		# ipv4-only | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipv4 = True | ||||||
|  | 		assert conf.ipv4 is True | ||||||
|  | 		assert conf.ipv6 is False | ||||||
|  | 		assert conf.ipvo == (4,) | ||||||
|  | 		# ipv6-only | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipv6 = True | ||||||
|  | 		assert conf.ipv4 is False | ||||||
|  | 		assert conf.ipv6 is True | ||||||
|  | 		assert conf.ipvo == (6,) | ||||||
|  | 		# ipv4-only (by removing ipv6) | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipv6 = False | ||||||
|  | 		assert conf.ipv4 is True | ||||||
|  | 		assert conf.ipv6 is False | ||||||
|  | 		assert conf.ipvo == (4, ) | ||||||
|  | 		# ipv6-only (by removing ipv4) | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipv4 = False | ||||||
|  | 		assert conf.ipv4 is False | ||||||
|  | 		assert conf.ipv6 is True | ||||||
|  | 		assert conf.ipvo == (6, ) | ||||||
|  | 		# ipv4-preferred | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipv4 = True | ||||||
|  | 		conf.ipv6 = True | ||||||
|  | 		assert conf.ipv4 is True | ||||||
|  | 		assert conf.ipv6 is True | ||||||
|  | 		assert conf.ipvo == (4, 6) | ||||||
|  | 		# ipv6-preferred | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipv6 = True | ||||||
|  | 		conf.ipv4 = True | ||||||
|  | 		assert conf.ipv4 is True | ||||||
|  | 		assert conf.ipv6 is True | ||||||
|  | 		assert conf.ipvo == (6, 4) | ||||||
|  | 		# ipvo empty | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipvo = () | ||||||
|  | 		assert conf.ipv4 is True | ||||||
|  | 		assert conf.ipv6 is True | ||||||
|  | 		assert conf.ipvo == () | ||||||
|  | 		# ipvo validation | ||||||
|  | 		conf = self.AuditConf() | ||||||
|  | 		conf.ipvo = (1, 2, 3, 4, 5, 6) | ||||||
|  | 		assert conf.ipvo == (4, 6) | ||||||
|  | 		conf.ipvo = (4, 4, 4, 6, 6) | ||||||
|  | 		assert conf.ipvo == (4, 6) | ||||||
|  | 	 | ||||||
| 	def test_audit_conf_minlevel(self): | 	def test_audit_conf_minlevel(self): | ||||||
| 		conf = self.AuditConf() | 		conf = self.AuditConf() | ||||||
| 		for level in ['info', 'warn', 'fail']: | 		for level in ['info', 'warn', 'fail']: | ||||||
| @@ -68,6 +126,7 @@ class TestAuditConf(object): | |||||||
| 			excinfo.match(r'.*invalid level.*') | 			excinfo.match(r'.*invalid level.*') | ||||||
| 	 | 	 | ||||||
| 	def test_audit_conf_cmdline(self): | 	def test_audit_conf_cmdline(self): | ||||||
|  | 		# pylint: disable=too-many-statements | ||||||
| 		c = lambda x: self.AuditConf.from_cmdline(x.split(), self.usage)  # noqa | 		c = lambda x: self.AuditConf.from_cmdline(x.split(), self.usage)  # noqa | ||||||
| 		with pytest.raises(SystemExit): | 		with pytest.raises(SystemExit): | ||||||
| 			conf = c('') | 			conf = c('') | ||||||
| @@ -87,20 +146,36 @@ class TestAuditConf(object): | |||||||
| 		self._test_conf(conf, host='github.com') | 		self._test_conf(conf, host='github.com') | ||||||
| 		conf = c('localhost:2222') | 		conf = c('localhost:2222') | ||||||
| 		self._test_conf(conf, host='localhost', port=2222) | 		self._test_conf(conf, host='localhost', port=2222) | ||||||
|  | 		conf = c('-p 2222 localhost') | ||||||
|  | 		self._test_conf(conf, host='localhost', port=2222) | ||||||
| 		with pytest.raises(SystemExit): | 		with pytest.raises(SystemExit): | ||||||
| 			conf = c('localhost:') | 			conf = c('localhost:') | ||||||
| 		with pytest.raises(SystemExit): | 		with pytest.raises(SystemExit): | ||||||
| 			conf = c('localhost:abc') | 			conf = c('localhost:abc') | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			conf = c('-p abc localhost') | ||||||
| 		with pytest.raises(SystemExit): | 		with pytest.raises(SystemExit): | ||||||
| 			conf = c('localhost:-22') | 			conf = c('localhost:-22') | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			conf = c('-p -22 localhost') | ||||||
| 		with pytest.raises(SystemExit): | 		with pytest.raises(SystemExit): | ||||||
| 			conf = c('localhost:99999') | 			conf = c('localhost:99999') | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			conf = c('-p 99999 localhost') | ||||||
| 		conf = c('-1 localhost') | 		conf = c('-1 localhost') | ||||||
| 		self._test_conf(conf, host='localhost', ssh1=True, ssh2=False) | 		self._test_conf(conf, host='localhost', ssh1=True, ssh2=False) | ||||||
| 		conf = c('-2 localhost') | 		conf = c('-2 localhost') | ||||||
| 		self._test_conf(conf, host='localhost', ssh1=False, ssh2=True) | 		self._test_conf(conf, host='localhost', ssh1=False, ssh2=True) | ||||||
| 		conf = c('-12 localhost') | 		conf = c('-12 localhost') | ||||||
| 		self._test_conf(conf, host='localhost', ssh1=True, ssh2=True) | 		self._test_conf(conf, host='localhost', ssh1=True, ssh2=True) | ||||||
|  | 		conf = c('-4 localhost') | ||||||
|  | 		self._test_conf(conf, host='localhost', ipv4=True, ipv6=False, ipvo=(4,)) | ||||||
|  | 		conf = c('-6 localhost') | ||||||
|  | 		self._test_conf(conf, host='localhost', ipv4=False, ipv6=True, ipvo=(6,)) | ||||||
|  | 		conf = c('-46 localhost') | ||||||
|  | 		self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(4, 6)) | ||||||
|  | 		conf = c('-64 localhost') | ||||||
|  | 		self._test_conf(conf, host='localhost', ipv4=True, ipv6=True, ipvo=(6, 4)) | ||||||
| 		conf = c('-b localhost') | 		conf = c('-b localhost') | ||||||
| 		self._test_conf(conf, host='localhost', batch=True, verbose=True) | 		self._test_conf(conf, host='localhost', batch=True, verbose=True) | ||||||
| 		conf = c('-n localhost') | 		conf = c('-n localhost') | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Andris Raugulis
					Andris Raugulis