Merge branch 'timeout_arg' into all_my_patches

This commit is contained in:
Joe Testa 2017-10-31 16:36:20 -04:00
commit a9f6b93391
1 changed files with 24 additions and 8 deletions

View File

@ -68,6 +68,7 @@ def usage(err=None):
uout.info(' -n, --no-colors disable colors') uout.info(' -n, --no-colors disable colors')
uout.info(' -v, --verbose verbose output') uout.info(' -v, --verbose verbose output')
uout.info(' -l, --level=<level> minimum output level (info|warn|fail)') uout.info(' -l, --level=<level> minimum output level (info|warn|fail)')
uout.info(' -t, --timeout=<secs> timeout (in seconds) for connection and reading\n (default: 5)')
uout.sep() uout.sep()
sys.exit(1) sys.exit(1)
@ -87,6 +88,7 @@ class AuditConf(object):
self.ipvo = () # type: Sequence[int] self.ipvo = () # type: Sequence[int]
self.ipv4 = False self.ipv4 = False
self.ipv6 = False self.ipv6 = False
self.timeout = 5.0
def __setattr__(self, name, value): def __setattr__(self, name, value):
# type: (str, Union[str, int, bool, Sequence[int]]) -> None # type: (str, Union[str, int, bool, Sequence[int]]) -> None
@ -124,6 +126,11 @@ class AuditConf(object):
valid = True valid = True
elif name == 'host': elif name == 'host':
valid = True valid = True
elif name == 'timeout':
value = utils.parse_float(value)
if value == -1.0:
raise ValueError('invalid timeout: {0}'.format(value))
valid = True
if valid: if valid:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
@ -133,9 +140,9 @@ class AuditConf(object):
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
aconf = cls() aconf = cls()
try: try:
sopts = 'h1246p:bnvl:' sopts = 'h1246p:bnvl:t:'
lopts = ['help', 'ssh1', 'ssh2', 'ipv4', 'ipv6', 'port', lopts = ['help', 'ssh1', 'ssh2', 'ipv4', 'ipv6', 'port',
'batch', 'no-colors', 'verbose', 'level='] 'batch', 'no-colors', 'verbose', 'level=', 'timeout=']
opts, args = getopt.getopt(args, sopts, lopts) opts, args = getopt.getopt(args, sopts, lopts)
except getopt.GetoptError as err: except getopt.GetoptError as err:
usage_cb(str(err)) usage_cb(str(err))
@ -165,6 +172,8 @@ class AuditConf(object):
if a not in ('info', 'warn', 'fail'): if a not in ('info', 'warn', 'fail'):
usage_cb('level {0} is not valid'.format(a)) usage_cb('level {0} is not valid'.format(a))
aconf.level = a aconf.level = a
elif o in ('-t', '--timeout'):
aconf.timeout = float(a)
if len(args) == 0: if len(args) == 0:
usage_cb() usage_cb()
if oport is not None: if oport is not None:
@ -1600,16 +1609,15 @@ class SSH(object): # pylint: disable=too-few-public-methods
out.fail('[exception] {0}'.format(e)) out.fail('[exception] {0}'.format(e))
sys.exit(1) sys.exit(1)
def connect(self, ipvo=(), cto=3.0, rto=5.0): def connect(self, ipvo, timeout):
# type: (Sequence[int], float, float) -> None # type: (Sequence[int], float) -> None
err = None err = None
for af, addr in self._resolve(ipvo): for af, addr in self._resolve(ipvo):
s = None s = None
try: try:
s = socket.socket(af, socket.SOCK_STREAM) s = socket.socket(af, socket.SOCK_STREAM)
s.settimeout(cto) s.settimeout(timeout)
s.connect(addr) s.connect(addr)
s.settimeout(rto)
self.__sock = s self.__sock = s
return return
except socket.error as e: except socket.error as e:
@ -2174,6 +2182,14 @@ class Utils(object):
except: # pylint: disable=bare-except except: # pylint: disable=bare-except
return 0 return 0
@staticmethod
def parse_float(v):
# type: (Any) -> float
try:
return float(v)
except: # pylint: disable=bare-except
return -1.0
def audit(aconf, sshv=None): def audit(aconf, sshv=None):
# type: (AuditConf, Optional[int]) -> None # type: (AuditConf, Optional[int]) -> None
@ -2182,7 +2198,7 @@ def audit(aconf, sshv=None):
out.level = aconf.level out.level = aconf.level
out.use_colors = aconf.colors out.use_colors = aconf.colors
s = SSH.Socket(aconf.host, aconf.port) s = SSH.Socket(aconf.host, aconf.port)
s.connect(aconf.ipvo) s.connect(aconf.ipvo, aconf.timeout)
if sshv is None: if sshv is None:
sshv = 2 if aconf.ssh2 else 1 sshv = 2 if aconf.ssh2 else 1
err = None err = None