Multiple style fixes (protector, veryhigh).

This commit is contained in:
Andris Raugulis 2016-09-02 17:22:00 +03:00
parent c759e53779
commit fba6397721
1 changed files with 44 additions and 16 deletions

View File

@ -29,7 +29,8 @@ import os, io, sys, socket, struct, random, errno, getopt
VERSION = 'v1.0.20160902' VERSION = 'v1.0.20160902'
SSH_BANNER = 'SSH-2.0-OpenSSH_7.3' SSH_BANNER = 'SSH-2.0-OpenSSH_7.3'
def usage(err = None):
def usage(err=None):
p = os.path.basename(sys.argv[0]) p = os.path.basename(sys.argv[0])
out.batch = False out.batch = False
out.minlevel = 'info' out.minlevel = 'info'
@ -45,6 +46,7 @@ def usage(err = None):
out.sep() out.sep()
sys.exit(1) sys.exit(1)
class Output(object): class Output(object):
LEVELS = ['info', 'warn', 'fail'] LEVELS = ['info', 'warn', 'fail']
COLORS = {'head': 36, 'good': 32, 'warn': 33, 'fail': 31} COLORS = {'head': 36, 'good': 32, 'warn': 33, 'fail': 31}
@ -58,20 +60,24 @@ class Output(object):
@property @property
def minlevel(self): def minlevel(self):
return self.__minlevel return self.__minlevel
@minlevel.setter @minlevel.setter
def minlevel(self, name): def minlevel(self, name):
self.__minlevel = self.getlevel(name) self.__minlevel = self.getlevel(name)
def getlevel(self, name): def getlevel(self, name):
cname = 'info' if name == 'good' else name cname = 'info' if name == 'good' else name
if not cname in self.LEVELS: if cname not in self.LEVELS:
return sys.maxsize return sys.maxsize
return self.LEVELS.index(cname) return self.LEVELS.index(cname)
def sep(self): def sep(self):
if not self.batch: if not self.batch:
print() print()
def _colorized(self, color): def _colorized(self, color):
return lambda x: print(u'{0}{1}\033[0m'.format(color, x)) return lambda x: print(u'{0}{1}\033[0m'.format(color, x))
def __getattr__(self, name): def __getattr__(self, name):
if name == 'head' and self.batch: if name == 'head' and self.batch:
return lambda x: None return lambda x: None
@ -83,25 +89,30 @@ class Output(object):
else: else:
return lambda x: print(u'{0}'.format(x)) return lambda x: print(u'{0}'.format(x))
class OutputBuffer(list): class OutputBuffer(list):
def __enter__(self): def __enter__(self):
self.__buf = io.StringIO() self.__buf = io.StringIO()
self.__stdout = sys.stdout self.__stdout = sys.stdout
sys.stdout = self.__buf sys.stdout = self.__buf
return self return self
def flush(self): def flush(self):
for line in self: for line in self:
print(line) print(line)
def __exit__(self, *args): def __exit__(self, *args):
self.extend(self.__buf.getvalue().splitlines()) self.extend(self.__buf.getvalue().splitlines())
sys.stdout = self.__stdout sys.stdout = self.__stdout
class KexParty(object): class KexParty(object):
encryption = [] encryption = []
mac = [] mac = []
compression = [] compression = []
languages = [] languages = []
class Kex(object): class Kex(object):
cookie = None cookie = None
kex_algorithms = [] kex_algorithms = []
@ -130,8 +141,9 @@ class Kex(object):
kex.unused = buf.read_int() kex.unused = buf.read_int()
return kex return kex
class ReadBuf(object): class ReadBuf(object):
def __init__(self, data = None): def __init__(self, data=None):
super(ReadBuf, self).__init__() super(ReadBuf, self).__init__()
self._buf = io.BytesIO(data) if data else io.BytesIO() self._buf = io.BytesIO(data) if data else io.BytesIO()
self._len = len(data) if data else 0 self._len = len(data) if data else 0
@ -156,8 +168,9 @@ class ReadBuf(object):
list_size = self.read_int() list_size = self.read_int()
return self.read(list_size).decode().split(',') return self.read(list_size).decode().split(',')
class WriteBuf(object): class WriteBuf(object):
def __init__(self, data = None): def __init__(self, data=None):
super(WriteBuf, self).__init__() super(WriteBuf, self).__init__()
self._wbuf = io.BytesIO(data) if data else io.BytesIO() self._wbuf = io.BytesIO(data) if data else io.BytesIO()
@ -185,10 +198,11 @@ class WriteBuf(object):
def write_mpint(self, v): def write_mpint(self, v):
length = v.bit_length() // 8 + 1 length = v.bit_length() // 8 + 1
data = [(v >> i*8) & 0xff for i in reversed(range(length))] data = [(v >> i * 8) & 0xff for i in reversed(range(length))]
data = bytes(bytearray(data)) data = bytes(bytearray(data))
self.write_string(data) self.write_string(data)
class SSH(object): class SSH(object):
MSG_KEXINIT = 20 MSG_KEXINIT = 20
MSG_NEWKEYS = 21 MSG_NEWKEYS = 21
@ -198,7 +212,7 @@ class SSH(object):
class Socket(ReadBuf, WriteBuf): class Socket(ReadBuf, WriteBuf):
SM_BANNER_SENT = 1 SM_BANNER_SENT = 1
def __init__(self, host, port, cto = 3.0, rto = 5.0): def __init__(self, host, port, cto=3.0, rto=5.0):
self.__block_size = 8 self.__block_size = 8
self.__state = 0 self.__state = 0
self.__header = [] self.__header = []
@ -231,7 +245,7 @@ class SSH(object):
self.__header.append(line) self.__header.append(line)
return self.__banner, self.__header return self.__banner, self.__header
def recv(self, size = 2048): def recv(self, size=2048):
try: try:
data = self.__sock.recv(size) data = self.__sock.recv(size)
except socket.timeout as e: except socket.timeout as e:
@ -257,7 +271,7 @@ class SSH(object):
return (-1, e) return (-1, e)
self.__sock.send(data) self.__sock.send(data)
def send_banner(self, banner = SSH_BANNER): def send_banner(self, banner=SSH_BANNER):
self.send(banner.encode() + b'\r\n') self.send(banner.encode() + b'\r\n')
if self.__state < self.SM_BANNER_SENT: if self.__state < self.SM_BANNER_SENT:
self.__state = self.SM_BANNER_SENT self.__state = self.SM_BANNER_SENT
@ -314,6 +328,7 @@ class SSH(object):
except: except:
pass pass
class KexDH(object): class KexDH(object):
def __init__(self, alg, g, p): def __init__(self, alg, g, p):
self.__alg = alg self.__alg = alg
@ -330,6 +345,7 @@ class KexDH(object):
s.write_mpint(self.__e) s.write_mpint(self.__e)
s.send_packet() s.send_packet()
class KexGroup1(KexDH): class KexGroup1(KexDH):
def __init__(self): def __init__(self):
# rfc2409: second oakley group # rfc2409: second oakley group
@ -340,6 +356,7 @@ class KexGroup1(KexDH):
'ffffffffffffffff', 16) 'ffffffffffffffff', 16)
super(KexGroup1, self).__init__('sha1', 2, p) super(KexGroup1, self).__init__('sha1', 2, p)
class KexGroup14(KexDH): class KexGroup14(KexDH):
def __init__(self): def __init__(self):
# rfc3526: 2048-bit modp group # rfc3526: 2048-bit modp group
@ -354,12 +371,14 @@ class KexGroup14(KexDH):
'15728e5a8aacaa68ffffffffffffffff', 16) '15728e5a8aacaa68ffffffffffffffff', 16)
super(KexGroup14, self).__init__('sha1', 2, p) super(KexGroup14, self).__init__('sha1', 2, p)
def get_ssh_version(version_desc): def get_ssh_version(version_desc):
if version_desc.startswith('d'): if version_desc.startswith('d'):
return ('Dropbear SSH', version_desc[1:]) return ('Dropbear SSH', version_desc[1:])
else: else:
return ('OpenSSH', version_desc) return ('OpenSSH', version_desc)
def get_alg_since_text(alg_desc): def get_alg_since_text(alg_desc):
tv = [] tv = []
versions = alg_desc[0] versions = alg_desc[0]
@ -368,7 +387,8 @@ def get_alg_since_text(alg_desc):
tv.append('{0} {1}'.format(ssh_prefix, ssh_version)) tv.append('{0} {1}'.format(ssh_prefix, ssh_version))
return 'available since ' + ', '.join(tv).rstrip(', ') return 'available since ' + ', '.join(tv).rstrip(', ')
def get_alg_timeframe(alg_desc, result = {}):
def get_alg_timeframe(alg_desc, result={}):
versions = alg_desc[0] versions = alg_desc[0]
vlen = len(versions) vlen = len(versions)
for i in range(3): for i in range(3):
@ -383,7 +403,7 @@ def get_alg_timeframe(alg_desc, result = {}):
continue continue
for v in cversions.split(','): for v in cversions.split(','):
ssh_prefix, ssh_version = get_ssh_version(v) ssh_prefix, ssh_version = get_ssh_version(v)
if not ssh_prefix in result: if ssh_prefix not in result:
result[ssh_prefix] = [None, None, None] result[ssh_prefix] = [None, None, None]
prev, push = result[ssh_prefix][i], False prev, push = result[ssh_prefix][i], False
if prev is None: if prev is None:
@ -396,6 +416,7 @@ def get_alg_timeframe(alg_desc, result = {}):
result[ssh_prefix][i] = ssh_version result[ssh_prefix][i] = ssh_version
return result return result
def get_ssh_timeframe(kex): def get_ssh_timeframe(kex):
alg_timeframe = {} alg_timeframe = {}
algs = {'kex': kex.kex_algorithms, algs = {'kex': kex.kex_algorithms,
@ -520,6 +541,7 @@ KEX_DB = {
} }
} }
def output_algorithms(title, alg_type, algorithms, maxlen=0): def output_algorithms(title, alg_type, algorithms, maxlen=0):
with OutputBuffer() as obuf: with OutputBuffer() as obuf:
for algorithm in algorithms: for algorithm in algorithms:
@ -529,6 +551,7 @@ def output_algorithms(title, alg_type, algorithms, maxlen=0):
obuf.flush() obuf.flush()
out.sep() out.sep()
def output_algorithm(alg_type, alg_name, alg_max_len=0): def output_algorithm(alg_type, alg_name, alg_max_len=0):
prefix = '(' + alg_type + ') ' prefix = '(' + alg_type + ') '
if alg_max_len == 0: if alg_max_len == 0:
@ -556,7 +579,7 @@ def output_algorithm(alg_type, alg_name, alg_max_len=0):
if first: if first:
if first and level == 'info': if first and level == 'info':
f = out.good f = out.good
f(prefix + alg_name + padding +' -- ' + text) f(prefix + alg_name + padding + ' -- ' + text)
first = False first = False
else: else:
if out.verbose: if out.verbose:
@ -564,6 +587,7 @@ def output_algorithm(alg_type, alg_name, alg_max_len=0):
else: else:
f(' ' * len(prefix + alg_name) + padding + ' `- ' + text) f(' ' * len(prefix + alg_name) + padding + ' `- ' + text)
def output_compatibility(kex, client=False): def output_compatibility(kex, client=False):
ssh_timeframe = get_ssh_timeframe(kex) ssh_timeframe = get_ssh_timeframe(kex)
cp = 2 if client else 1 cp = 2 if client else 1
@ -582,6 +606,7 @@ def output_compatibility(kex, client=False):
if len(comp_text) > 0: if len(comp_text) > 0:
out.good('(gen) compatibility: ' + ', '.join(comp_text)) out.good('(gen) compatibility: ' + ', '.join(comp_text))
def output(banner, header, kex): def output(banner, header, kex):
with OutputBuffer() as obuf: with OutputBuffer() as obuf:
if len(header) > 0: if len(header) > 0:
@ -625,6 +650,7 @@ def parse_int(v):
except: except:
return 0 return 0
def parse_args(): def parse_args():
host, port = None, 22 host, port = None, 22
try: try:
@ -657,6 +683,7 @@ def parse_args():
usage('port {0} is not valid'.format(port)) usage('port {0} is not valid'.format(port))
return host, port return host, port
def main(): def main():
host, port = parse_args() host, port = parse_args()
s = SSH.Socket(host, port) s = SSH.Socket(host, port)
@ -678,6 +705,7 @@ def main():
kex = Kex.parse(payload) kex = Kex.parse(payload)
output(banner, header, kex) output(banner, header, kex)
if __name__ == '__main__': if __name__ == '__main__':
out = Output() out = Output()
main() main()