Add resolve tests.

This commit is contained in:
Andris Raugulis 2016-11-02 19:28:16 +02:00
parent 6c4b9fcadf
commit 6fde896d77
3 changed files with 133 additions and 14 deletions

View File

@ -40,6 +40,41 @@ def output_spy():
return _OutputSpy() return _OutputSpy()
class _VirtualGlobalSocket(object):
def __init__(self, vsocket):
self.vsocket = vsocket
self.addrinfodata = {}
# pylint: disable=unused-argument
def create_connection(self, address, timeout=0, source_address=None):
# pylint: disable=protected-access
return self.vsocket._connect(address, True)
# pylint: disable=unused-argument
def socket(self,
family=socket.AF_INET,
socktype=socket.SOCK_STREAM,
proto=0,
fileno=None):
return self.vsocket
def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
key = '{0}#{1}'.format(host, port)
if key in self.addrinfodata:
data = self.addrinfodata[key]
if isinstance(data, Exception):
raise data
return data
if host == 'localhost':
r = []
if family in (0, socket.AF_INET):
r.append((socket.AF_INET, 1, 6, '', ('127.0.0.1', port)))
if family in (0, socket.AF_INET6):
r.append((socket.AF_INET6, 1, 6, '', ('::1', port)))
return r
return []
class _VirtualSocket(object): class _VirtualSocket(object):
def __init__(self): def __init__(self):
self.sock_address = ('127.0.0.1', 0) self.sock_address = ('127.0.0.1', 0)
@ -49,6 +84,7 @@ class _VirtualSocket(object):
self.rdata = [] self.rdata = []
self.sdata = [] self.sdata = []
self.errors = {} self.errors = {}
self.gsock = _VirtualGlobalSocket(self)
def _check_err(self, method): def _check_err(self, method):
method_error = self.errors.get(method) method_error = self.errors.get(method)
@ -113,18 +149,8 @@ class _VirtualSocket(object):
@pytest.fixture() @pytest.fixture()
def virtual_socket(monkeypatch): def virtual_socket(monkeypatch):
vsocket = _VirtualSocket() vsocket = _VirtualSocket()
gsock = vsocket.gsock
# pylint: disable=unused-argument monkeypatch.setattr(socket, 'create_connection', gsock.create_connection)
def _socket(family=socket.AF_INET, monkeypatch.setattr(socket, 'socket', gsock.socket)
socktype=socket.SOCK_STREAM, monkeypatch.setattr(socket, 'getaddrinfo', gsock.getaddrinfo)
proto=0,
fileno=None):
return vsocket
def _cc(address, timeout=0, source_address=None):
# pylint: disable=protected-access
return vsocket._connect(address, True)
monkeypatch.setattr(socket, 'create_connection', _cc)
monkeypatch.setattr(socket, 'socket', _socket)
return vsocket return vsocket

View File

@ -30,6 +30,13 @@ class TestErrors(object):
lines = spy.flush() lines = spy.flush()
return lines return lines
def test_connection_unresolved(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.gsock.addrinfodata['localhost#22'] = []
lines = self._audit(output_spy)
assert len(lines) == 1
assert 'has no DNS records' in lines[-1]
def test_connection_refused(self, output_spy, virtual_socket): def test_connection_refused(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.errors['connect'] = socket.error(errno.ECONNREFUSED, 'Connection refused') vsocket.errors['connect'] = socket.error(errno.ECONNREFUSED, 'Connection refused')
@ -91,6 +98,7 @@ class TestErrors(object):
def test_connection_closed_after_header(self, output_spy, virtual_socket): def test_connection_closed_after_header(self, output_spy, virtual_socket):
vsocket = virtual_socket vsocket = virtual_socket
vsocket.rdata.append(b'header line 1\n') vsocket.rdata.append(b'header line 1\n')
vsocket.rdata.append(b'\n')
vsocket.rdata.append(b'header line 2\n') vsocket.rdata.append(b'header line 2\n')
vsocket.rdata.append(socket.error(errno.ECONNRESET, 'Connection reset by peer')) vsocket.rdata.append(socket.error(errno.ECONNRESET, 'Connection reset by peer'))
lines = self._audit(output_spy) lines = self._audit(output_spy)

85
test/test_resolve.py Normal file
View File

@ -0,0 +1,85 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import socket
import pytest
# pylint: disable=attribute-defined-outside-init,protected-access
class TestResolve(object):
@pytest.fixture(autouse=True)
def init(self, ssh_audit):
self.AuditConf = ssh_audit.AuditConf
self.audit = ssh_audit.audit
self.ssh = ssh_audit.SSH
def _conf(self):
conf = self.AuditConf('localhost', 22)
conf.colors = False
conf.batch = True
return conf
def test_resolve_error(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known')
s = self.ssh.Socket('localhost', 22)
conf = self._conf()
output_spy.begin()
with pytest.raises(SystemExit):
r = list(s._resolve(conf.ipvo))
lines = output_spy.flush()
assert len(lines) == 1
assert 'hostname nor servname provided' in lines[-1]
def test_resolve_hostname_without_records(self, output_spy, virtual_socket):
vsocket = virtual_socket
vsocket.gsock.addrinfodata['localhost#22'] = []
s = self.ssh.Socket('localhost', 22)
conf = self._conf()
output_spy.begin()
r = list(s._resolve(conf.ipvo))
assert len(r) == 0
def test_resolve_ipv4(self, virtual_socket):
vsocket = virtual_socket
conf = self._conf()
conf.ipv4 = True
s = self.ssh.Socket('localhost', 22)
r = list(s._resolve(conf.ipvo))
assert len(r) == 1
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
def test_resolve_ipv6(self, virtual_socket):
vsocket = virtual_socket
s = self.ssh.Socket('localhost', 22)
conf = self._conf()
conf.ipv6 = True
r = list(s._resolve(conf.ipvo))
assert len(r) == 1
assert r[0] == (socket.AF_INET6, ('::1', 22))
def test_resolve_ipv46_both(self, virtual_socket):
vsocket = virtual_socket
s = self.ssh.Socket('localhost', 22)
conf = self._conf()
r = list(s._resolve(conf.ipvo))
assert len(r) == 2
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
assert r[1] == (socket.AF_INET6, ('::1', 22))
def test_resolve_ipv46_order(self, virtual_socket):
vsocket = virtual_socket
s = self.ssh.Socket('localhost', 22)
conf = self._conf()
conf.ipv4 = True
conf.ipv6 = True
r = list(s._resolve(conf.ipvo))
assert len(r) == 2
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
assert r[1] == (socket.AF_INET6, ('::1', 22))
conf = self._conf()
conf.ipv6 = True
conf.ipv4 = True
r = list(s._resolve(conf.ipvo))
assert len(r) == 2
assert r[0] == (socket.AF_INET6, ('::1', 22))
assert r[1] == (socket.AF_INET, ('127.0.0.1', 22))