import os
import io
import sys
import socket
import pytest


@pytest.fixture(scope='module')
def ssh_audit():
    """Import and return the module named ``ssh-audit``.

    The parent directory of this file is appended to ``sys.path`` first so
    the import can be resolved from there. ``__import__`` is used because
    the module name contains a hyphen and cannot appear in an ``import``
    statement.
    """
    __rdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
    sys.path.append(os.path.abspath(__rdir))
    return __import__('ssh-audit')


# pylint: disable=attribute-defined-outside-init
class _OutputSpy(list):
    """Capture everything written to ``sys.stdout`` between begin() and flush()."""

    def begin(self):
        """Redirect ``sys.stdout`` into a fresh in-memory buffer."""
        self.__out = io.StringIO()
        self.__old_stdout = sys.stdout
        sys.stdout = self.__out

    def flush(self):
        """Restore the previous ``sys.stdout`` and return the captured lines.

        Must be preceded by a call to begin(); the capture buffer is
        discarded afterwards.
        """
        lines = self.__out.getvalue().splitlines()
        sys.stdout = self.__old_stdout
        self.__out = None
        return lines


@pytest.fixture(scope='module')
def output_spy():
    """Provide a module-scoped ``_OutputSpy`` instance."""
    return _OutputSpy()


class _VirtualGlobalSocket:
    """Stand-ins for the module-level ``socket`` functions.

    Every connection/socket request is answered with the single wrapped
    ``_VirtualSocket``. ``addrinfodata`` maps ``'<host>#<port>'`` keys to
    canned getaddrinfo() results (or to an exception instance to raise).
    """

    def __init__(self, vsocket):
        self.vsocket = vsocket
        self.addrinfodata = {}

    # pylint: disable=unused-argument
    def create_connection(self, address, timeout=0, source_address=None):
        """Connect the wrapped virtual socket to *address* and return it."""
        # pylint: disable=protected-access
        return self.vsocket._connect(address, True)

    # pylint: disable=unused-argument
    def socket(self, family=socket.AF_INET, socktype=socket.SOCK_STREAM,
               proto=0, fileno=None):
        """Return the wrapped virtual socket; all arguments are ignored."""
        return self.vsocket

    def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
        """Resolve *host*/*port* without touching the network.

        A matching entry in ``addrinfodata`` takes precedence: it is raised
        if it is an exception instance, otherwise returned as-is. Failing
        that, ``'localhost'`` resolves to the IPv4 and/or IPv6 loopback
        address depending on *family* (0 selects both). Any other host
        yields an empty list.
        """
        key = '{}#{}'.format(host, port)
        if key in self.addrinfodata:
            data = self.addrinfodata[key]
            if isinstance(data, Exception):
                raise data
            return data
        if host == 'localhost':
            # NOTE(review): the literals 1 and 6 presumably stand for
            # SOCK_STREAM and IPPROTO_TCP -- confirm before replacing them
            # with the named constants.
            r = []
            if family in (0, socket.AF_INET):
                r.append((socket.AF_INET, 1, 6, '', ('127.0.0.1', port)))
            if family in (0, socket.AF_INET6):
                r.append((socket.AF_INET6, 1, 6, '', ('::1', port)))
            return r
        return []


class _VirtualSocket:
    """In-memory socket double.

    Attributes tests may manipulate directly:
      rdata  -- queue of items handed out by recv(), oldest first; an
                exception instance in the queue is raised instead.
      sdata  -- list of every payload passed to send().
      errors -- maps a method name ('connect', 'send') to an exception
                that the method raises when invoked.
    """

    def __init__(self):
        self.sock_address = ('127.0.0.1', 0)
        self.peer_address = None
        self._connected = False
        self.timeout = -1.0
        self.rdata = []
        self.sdata = []
        self.errors = {}
        self.gsock = _VirtualGlobalSocket(self)

    def _check_err(self, method):
        """Raise the exception registered for *method* in ``errors``, if any."""
        method_error = self.errors.get(method)
        if method_error:
            raise method_error

    def connect(self, address):
        """Connect to *address*; returns None."""
        return self._connect(address, False)

    def _connect(self, address, ret=True):
        """Record *address* as the peer and mark the socket connected.

        Returns ``self`` when *ret* is true, otherwise None.
        """
        # NOTE(review): the socket is flagged as connected before an
        # injected 'connect' error is raised; this ordering is kept as-is
        # because tests outside this file may depend on it.
        self.peer_address = address
        self._connected = True
        self._check_err('connect')
        return self if ret else None

    def settimeout(self, timeout):
        """Store *timeout*; it has no further effect on this double."""
        self.timeout = timeout

    def gettimeout(self):
        """Return the value last stored by settimeout() (initially -1.0)."""
        return self.timeout

    def getpeername(self):
        """Return the peer address, or raise OSError(57) if not connected."""
        if self.peer_address is None or not self._connected:
            raise OSError(57, 'Socket is not connected')
        return self.peer_address

    def getsockname(self):
        """Return the local address of this socket."""
        return self.sock_address

    def bind(self, address):
        """Use *address* as the local address."""
        self.sock_address = address

    def listen(self, backlog):
        """Accepted for API compatibility; does nothing."""

    def accept(self):
        """Return a new, already-connected virtual socket and its peer address."""
        # pylint: disable=protected-access
        conn = _VirtualSocket()
        conn.sock_address = self.sock_address
        conn.peer_address = ('127.0.0.1', 0)
        conn._connected = True
        return conn, conn.peer_address

    def recv(self, bufsize, flags=0):
        """Return the next queued item from ``rdata``.

        *bufsize* and *flags* are ignored: each queued item is delivered
        whole. An empty queue yields ``b''``. Raises OSError(54) when the
        socket is not connected, or the queued item itself when it is an
        exception instance.
        """
        # pylint: disable=unused-argument
        if not self._connected:
            raise OSError(54, 'Connection reset by peer')
        if not self.rdata:
            return b''
        data = self.rdata.pop(0)
        if isinstance(data, Exception):
            raise data
        return data

    def send(self, data):
        """Append *data* to ``sdata`` and return the number of bytes accepted.

        Raises OSError(32) when the socket is not connected, or the
        exception registered under ``errors['send']``.
        """
        if self.peer_address is None or not self._connected:
            raise OSError(32, 'Broken pipe')
        self._check_err('send')
        self.sdata.append(data)
        # Mirror socket.send(), which reports how many bytes were sent.
        return len(data)


@pytest.fixture()
def virtual_socket(monkeypatch):
    """Patch the ``socket`` module to use a fresh ``_VirtualSocket``.

    ``socket.create_connection``, ``socket.socket`` and
    ``socket.getaddrinfo`` are replaced for the duration of the test; the
    virtual socket is returned so the test can queue data and inspect
    what was sent.
    """
    vsocket = _VirtualSocket()
    gsock = vsocket.gsock
    monkeypatch.setattr(socket, 'create_connection', gsock.create_connection)
    monkeypatch.setattr(socket, 'socket', gsock.socket)
    monkeypatch.setattr(socket, 'getaddrinfo', gsock.getaddrinfo)
    return vsocket