mirror of https://github.com/jtesta/ssh-audit.git
Create virtual socket fixture (socket mocking).
This commit is contained in:
parent
63a9c479a7
commit
f065118959
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import pytest, os, sys, io
|
import pytest, os, sys, io, socket
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
|
@ -33,3 +33,80 @@ class _OutputSpy(list):
|
||||||
@pytest.fixture(scope='module')
|
@pytest.fixture(scope='module')
|
||||||
def output_spy():
|
def output_spy():
|
||||||
return _OutputSpy()
|
return _OutputSpy()
|
||||||
|
|
||||||
|
|
||||||
|
class _VirtualSocket(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.sock_address = ('127.0.0.1', 0)
|
||||||
|
self.peer_address = None
|
||||||
|
self._connected = False
|
||||||
|
self.timeout = -1.0
|
||||||
|
self.rdata = []
|
||||||
|
self.sdata = []
|
||||||
|
self.errors = {}
|
||||||
|
|
||||||
|
def _check_err(self, method):
|
||||||
|
method_error = self.errors.get(method)
|
||||||
|
if method_error:
|
||||||
|
raise method_error
|
||||||
|
|
||||||
|
def _connect(self, address):
|
||||||
|
self.peer_address = address
|
||||||
|
self._connected = True
|
||||||
|
self._check_err('connect')
|
||||||
|
return self
|
||||||
|
|
||||||
|
def settimeout(self, timeout):
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def gettimeout(self):
|
||||||
|
return self.timeout
|
||||||
|
|
||||||
|
def getpeername(self):
|
||||||
|
if self.peer_address is None or not self._connected:
|
||||||
|
raise socket.error(57, 'Socket is not connected')
|
||||||
|
return self.peer_address
|
||||||
|
|
||||||
|
def getsockname(self):
|
||||||
|
return self.sock_address
|
||||||
|
|
||||||
|
def bind(self, address):
|
||||||
|
self.sock_address = address
|
||||||
|
|
||||||
|
def listen(self, backlog):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def accept(self):
|
||||||
|
conn = _VirtualSocket()
|
||||||
|
conn.sock_address = self.sock_address
|
||||||
|
conn.peer_address = ('127.0.0.1', 0)
|
||||||
|
conn._connected = True
|
||||||
|
return conn, conn.peer_address
|
||||||
|
|
||||||
|
def recv(self, bufsize, flags=0):
|
||||||
|
if not self._connected:
|
||||||
|
raise socket.error(54, 'Connection reset by peer')
|
||||||
|
if not len(self.rdata) > 0:
|
||||||
|
return b''
|
||||||
|
data = self.rdata.pop(0)
|
||||||
|
if isinstance(data, Exception):
|
||||||
|
raise data
|
||||||
|
return data
|
||||||
|
|
||||||
|
def send(self, data):
|
||||||
|
if self.peer_address is None or not self._connected:
|
||||||
|
raise socket.error(32, 'Broken pipe')
|
||||||
|
self._check_err('send')
|
||||||
|
self.sdata.append(data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def virtual_socket(monkeypatch):
|
||||||
|
vsocket = _VirtualSocket()
|
||||||
|
def _c(address):
|
||||||
|
return vsocket._connect(address)
|
||||||
|
def _cc(address, timeout=0, source_address=None):
|
||||||
|
return vsocket._connect(address)
|
||||||
|
monkeypatch.setattr(socket, 'create_connection', _cc)
|
||||||
|
monkeypatch.setattr(socket.socket, 'connect', _c)
|
||||||
|
return vsocket
|
||||||
|
|
Loading…
Reference in New Issue