mirror of https://github.com/jtesta/ssh-audit.git
Add resolve tests.
This commit is contained in:
parent
6c4b9fcadf
commit
6fde896d77
|
@ -40,6 +40,41 @@ def output_spy():
|
||||||
return _OutputSpy()
|
return _OutputSpy()
|
||||||
|
|
||||||
|
|
||||||
|
class _VirtualGlobalSocket(object):
|
||||||
|
def __init__(self, vsocket):
|
||||||
|
self.vsocket = vsocket
|
||||||
|
self.addrinfodata = {}
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def create_connection(self, address, timeout=0, source_address=None):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
return self.vsocket._connect(address, True)
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def socket(self,
|
||||||
|
family=socket.AF_INET,
|
||||||
|
socktype=socket.SOCK_STREAM,
|
||||||
|
proto=0,
|
||||||
|
fileno=None):
|
||||||
|
return self.vsocket
|
||||||
|
|
||||||
|
def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
|
||||||
|
key = '{0}#{1}'.format(host, port)
|
||||||
|
if key in self.addrinfodata:
|
||||||
|
data = self.addrinfodata[key]
|
||||||
|
if isinstance(data, Exception):
|
||||||
|
raise data
|
||||||
|
return data
|
||||||
|
if host == 'localhost':
|
||||||
|
r = []
|
||||||
|
if family in (0, socket.AF_INET):
|
||||||
|
r.append((socket.AF_INET, 1, 6, '', ('127.0.0.1', port)))
|
||||||
|
if family in (0, socket.AF_INET6):
|
||||||
|
r.append((socket.AF_INET6, 1, 6, '', ('::1', port)))
|
||||||
|
return r
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class _VirtualSocket(object):
|
class _VirtualSocket(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sock_address = ('127.0.0.1', 0)
|
self.sock_address = ('127.0.0.1', 0)
|
||||||
|
@ -49,6 +84,7 @@ class _VirtualSocket(object):
|
||||||
self.rdata = []
|
self.rdata = []
|
||||||
self.sdata = []
|
self.sdata = []
|
||||||
self.errors = {}
|
self.errors = {}
|
||||||
|
self.gsock = _VirtualGlobalSocket(self)
|
||||||
|
|
||||||
def _check_err(self, method):
|
def _check_err(self, method):
|
||||||
method_error = self.errors.get(method)
|
method_error = self.errors.get(method)
|
||||||
|
@ -113,18 +149,8 @@ class _VirtualSocket(object):
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def virtual_socket(monkeypatch):
|
def virtual_socket(monkeypatch):
|
||||||
vsocket = _VirtualSocket()
|
vsocket = _VirtualSocket()
|
||||||
|
gsock = vsocket.gsock
|
||||||
# pylint: disable=unused-argument
|
monkeypatch.setattr(socket, 'create_connection', gsock.create_connection)
|
||||||
def _socket(family=socket.AF_INET,
|
monkeypatch.setattr(socket, 'socket', gsock.socket)
|
||||||
socktype=socket.SOCK_STREAM,
|
monkeypatch.setattr(socket, 'getaddrinfo', gsock.getaddrinfo)
|
||||||
proto=0,
|
|
||||||
fileno=None):
|
|
||||||
return vsocket
|
|
||||||
|
|
||||||
def _cc(address, timeout=0, source_address=None):
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
return vsocket._connect(address, True)
|
|
||||||
|
|
||||||
monkeypatch.setattr(socket, 'create_connection', _cc)
|
|
||||||
monkeypatch.setattr(socket, 'socket', _socket)
|
|
||||||
return vsocket
|
return vsocket
|
||||||
|
|
|
@ -30,6 +30,13 @@ class TestErrors(object):
|
||||||
lines = spy.flush()
|
lines = spy.flush()
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
def test_connection_unresolved(self, output_spy, virtual_socket):
|
||||||
|
vsocket = virtual_socket
|
||||||
|
vsocket.gsock.addrinfodata['localhost#22'] = []
|
||||||
|
lines = self._audit(output_spy)
|
||||||
|
assert len(lines) == 1
|
||||||
|
assert 'has no DNS records' in lines[-1]
|
||||||
|
|
||||||
def test_connection_refused(self, output_spy, virtual_socket):
|
def test_connection_refused(self, output_spy, virtual_socket):
|
||||||
vsocket = virtual_socket
|
vsocket = virtual_socket
|
||||||
vsocket.errors['connect'] = socket.error(errno.ECONNREFUSED, 'Connection refused')
|
vsocket.errors['connect'] = socket.error(errno.ECONNREFUSED, 'Connection refused')
|
||||||
|
@ -91,6 +98,7 @@ class TestErrors(object):
|
||||||
def test_connection_closed_after_header(self, output_spy, virtual_socket):
|
def test_connection_closed_after_header(self, output_spy, virtual_socket):
|
||||||
vsocket = virtual_socket
|
vsocket = virtual_socket
|
||||||
vsocket.rdata.append(b'header line 1\n')
|
vsocket.rdata.append(b'header line 1\n')
|
||||||
|
vsocket.rdata.append(b'\n')
|
||||||
vsocket.rdata.append(b'header line 2\n')
|
vsocket.rdata.append(b'header line 2\n')
|
||||||
vsocket.rdata.append(socket.error(errno.ECONNRESET, 'Connection reset by peer'))
|
vsocket.rdata.append(socket.error(errno.ECONNRESET, 'Connection reset by peer'))
|
||||||
lines = self._audit(output_spy)
|
lines = self._audit(output_spy)
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import socket
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=attribute-defined-outside-init,protected-access
|
||||||
|
class TestResolve(object):
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def init(self, ssh_audit):
|
||||||
|
self.AuditConf = ssh_audit.AuditConf
|
||||||
|
self.audit = ssh_audit.audit
|
||||||
|
self.ssh = ssh_audit.SSH
|
||||||
|
|
||||||
|
def _conf(self):
|
||||||
|
conf = self.AuditConf('localhost', 22)
|
||||||
|
conf.colors = False
|
||||||
|
conf.batch = True
|
||||||
|
return conf
|
||||||
|
|
||||||
|
def test_resolve_error(self, output_spy, virtual_socket):
|
||||||
|
vsocket = virtual_socket
|
||||||
|
vsocket.gsock.addrinfodata['localhost#22'] = socket.gaierror(8, 'hostname nor servname provided, or not known')
|
||||||
|
s = self.ssh.Socket('localhost', 22)
|
||||||
|
conf = self._conf()
|
||||||
|
output_spy.begin()
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
r = list(s._resolve(conf.ipvo))
|
||||||
|
lines = output_spy.flush()
|
||||||
|
assert len(lines) == 1
|
||||||
|
assert 'hostname nor servname provided' in lines[-1]
|
||||||
|
|
||||||
|
def test_resolve_hostname_without_records(self, output_spy, virtual_socket):
|
||||||
|
vsocket = virtual_socket
|
||||||
|
vsocket.gsock.addrinfodata['localhost#22'] = []
|
||||||
|
s = self.ssh.Socket('localhost', 22)
|
||||||
|
conf = self._conf()
|
||||||
|
output_spy.begin()
|
||||||
|
r = list(s._resolve(conf.ipvo))
|
||||||
|
assert len(r) == 0
|
||||||
|
|
||||||
|
def test_resolve_ipv4(self, virtual_socket):
|
||||||
|
vsocket = virtual_socket
|
||||||
|
conf = self._conf()
|
||||||
|
conf.ipv4 = True
|
||||||
|
s = self.ssh.Socket('localhost', 22)
|
||||||
|
r = list(s._resolve(conf.ipvo))
|
||||||
|
assert len(r) == 1
|
||||||
|
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
|
||||||
|
|
||||||
|
def test_resolve_ipv6(self, virtual_socket):
|
||||||
|
vsocket = virtual_socket
|
||||||
|
s = self.ssh.Socket('localhost', 22)
|
||||||
|
conf = self._conf()
|
||||||
|
conf.ipv6 = True
|
||||||
|
r = list(s._resolve(conf.ipvo))
|
||||||
|
assert len(r) == 1
|
||||||
|
assert r[0] == (socket.AF_INET6, ('::1', 22))
|
||||||
|
|
||||||
|
def test_resolve_ipv46_both(self, virtual_socket):
|
||||||
|
vsocket = virtual_socket
|
||||||
|
s = self.ssh.Socket('localhost', 22)
|
||||||
|
conf = self._conf()
|
||||||
|
r = list(s._resolve(conf.ipvo))
|
||||||
|
assert len(r) == 2
|
||||||
|
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
|
||||||
|
assert r[1] == (socket.AF_INET6, ('::1', 22))
|
||||||
|
|
||||||
|
def test_resolve_ipv46_order(self, virtual_socket):
|
||||||
|
vsocket = virtual_socket
|
||||||
|
s = self.ssh.Socket('localhost', 22)
|
||||||
|
conf = self._conf()
|
||||||
|
conf.ipv4 = True
|
||||||
|
conf.ipv6 = True
|
||||||
|
r = list(s._resolve(conf.ipvo))
|
||||||
|
assert len(r) == 2
|
||||||
|
assert r[0] == (socket.AF_INET, ('127.0.0.1', 22))
|
||||||
|
assert r[1] == (socket.AF_INET6, ('::1', 22))
|
||||||
|
conf = self._conf()
|
||||||
|
conf.ipv6 = True
|
||||||
|
conf.ipv4 = True
|
||||||
|
r = list(s._resolve(conf.ipvo))
|
||||||
|
assert len(r) == 2
|
||||||
|
assert r[0] == (socket.AF_INET6, ('::1', 22))
|
||||||
|
assert r[1] == (socket.AF_INET, ('127.0.0.1', 22))
|
Loading…
Reference in New Issue