mirror of
				https://github.com/jtesta/ssh-audit.git
				synced 2025-10-30 21:15:27 +01:00 
			
		
		
		
	Add static type checks via mypy (optional static type checker), Add relevant tests, which could trigger the issue.
This commit is contained in:
		
							
								
								
									
										20
									
								
								ssh-audit.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								ssh-audit.py
									
									
									
									
									
								
							| @@ -25,6 +25,10 @@ | |||||||
| """ | """ | ||||||
| from __future__ import print_function | from __future__ import print_function | ||||||
| import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 | import os, io, sys, socket, struct, random, errno, getopt, re, hashlib, base64 | ||||||
|  | try: | ||||||
|  | 	from typing import List, Tuple, Text | ||||||
|  | except: | ||||||
|  | 	pass | ||||||
|  |  | ||||||
| VERSION = 'v1.6.1.dev' | VERSION = 'v1.6.1.dev' | ||||||
|  |  | ||||||
| @@ -940,14 +944,15 @@ class SSH(object): | |||||||
| 			return self.__banner, self.__header | 			return self.__banner, self.__header | ||||||
| 		 | 		 | ||||||
| 		def recv(self, size=2048): | 		def recv(self, size=2048): | ||||||
|  | 			# type: (int) -> Tuple[int, str] | ||||||
| 			try: | 			try: | ||||||
| 				data = self.__sock.recv(size) | 				data = self.__sock.recv(size) | ||||||
| 			except socket.timeout as e: | 			except socket.timeout: | ||||||
| 				r = 0 if e.strerror == 'timed out' else -1 | 				return (-1, 'timeout') | ||||||
| 				return (r, e) |  | ||||||
| 			except socket.error as e: | 			except socket.error as e: | ||||||
| 				r = 0 if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK) else -1 | 				if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK): | ||||||
| 				return (r, e) | 					return (0, 'retry') | ||||||
|  | 				return (-1, str(e.args[-1])) | ||||||
| 			if len(data) == 0: | 			if len(data) == 0: | ||||||
| 				return (-1, None) | 				return (-1, None) | ||||||
| 			pos = self._buf.tell() | 			pos = self._buf.tell() | ||||||
| @@ -977,6 +982,7 @@ class SSH(object): | |||||||
| 					raise SSH.Socket.InsufficientReadException(e) | 					raise SSH.Socket.InsufficientReadException(e) | ||||||
| 		 | 		 | ||||||
| 		def read_packet(self, sshv=2): | 		def read_packet(self, sshv=2): | ||||||
|  | 			# type: (int) -> Tuple[int, bytes] | ||||||
| 			try: | 			try: | ||||||
| 				header = WriteBuf() | 				header = WriteBuf() | ||||||
| 				self.ensure_read(4) | 				self.ensure_read(4) | ||||||
| @@ -1024,7 +1030,7 @@ class SSH(object): | |||||||
| 					header.write(self.read(self.unread_len)) | 					header.write(self.read(self.unread_len)) | ||||||
| 					e = header.write_flush().strip() | 					e = header.write_flush().strip() | ||||||
| 				else: | 				else: | ||||||
| 					e = ex.args[0] | 					e = ex.args[0].encode('utf-8') | ||||||
| 				return (-1, e) | 				return (-1, e) | ||||||
| 		 | 		 | ||||||
| 		def send_packet(self): | 		def send_packet(self): | ||||||
| @@ -1651,7 +1657,7 @@ def audit(conf, sshv=None): | |||||||
| 	if err is None: | 	if err is None: | ||||||
| 		packet_type, payload = s.read_packet(sshv) | 		packet_type, payload = s.read_packet(sshv) | ||||||
| 		if packet_type < 0: | 		if packet_type < 0: | ||||||
| 			payload = str(payload).decode('utf-8') | 			payload = payload.decode('utf-8') if payload else u'empty' | ||||||
| 			if payload == u'Protocol major versions differ.': | 			if payload == u'Protocol major versions differ.': | ||||||
| 				if sshv == 2 and conf.ssh1: | 				if sshv == 2 and conf.ssh1: | ||||||
| 					audit(conf, 1) | 					audit(conf, 1) | ||||||
|   | |||||||
							
								
								
									
										96
									
								
								test/test_errors.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										96
									
								
								test/test_errors.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,96 @@ | |||||||
|  | #!/usr/bin/env python | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | import pytest, socket | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestErrors(object): | ||||||
|  | 	@pytest.fixture(autouse=True) | ||||||
|  | 	def init(self, ssh_audit): | ||||||
|  | 		self.AuditConf = ssh_audit.AuditConf | ||||||
|  | 		self.audit = ssh_audit.audit | ||||||
|  | 	 | ||||||
|  | 	def _conf(self): | ||||||
|  | 		conf = self.AuditConf('localhost', 22) | ||||||
|  | 		conf.batch = True | ||||||
|  | 		return conf | ||||||
|  | 	 | ||||||
|  | 	def test_connection_refused(self, output_spy, virtual_socket): | ||||||
|  | 		vsocket = virtual_socket | ||||||
|  | 		vsocket.errors['connect'] = socket.error(61, 'Connection refused') | ||||||
|  | 		output_spy.begin() | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			self.audit(self._conf()) | ||||||
|  | 		lines = output_spy.flush() | ||||||
|  | 		assert len(lines) == 1 | ||||||
|  | 		assert 'Connection refused' in lines[-1] | ||||||
|  | 	 | ||||||
|  | 	def test_connection_closed_before_banner(self, output_spy, virtual_socket): | ||||||
|  | 		vsocket = virtual_socket | ||||||
|  | 		vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) | ||||||
|  | 		output_spy.begin() | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			self.audit(self._conf()) | ||||||
|  | 		lines = output_spy.flush() | ||||||
|  | 		assert len(lines) == 1 | ||||||
|  | 		assert 'did not receive banner' in lines[-1] | ||||||
|  | 	 | ||||||
|  | 	def test_connection_closed_after_header(self, output_spy, virtual_socket): | ||||||
|  | 		vsocket = virtual_socket | ||||||
|  | 		vsocket.rdata.append(b'header line 1\n') | ||||||
|  | 		vsocket.rdata.append(b'header line 2\n') | ||||||
|  | 		vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) | ||||||
|  | 		output_spy.begin() | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			self.audit(self._conf()) | ||||||
|  | 		lines = output_spy.flush() | ||||||
|  | 		assert len(lines) == 3 | ||||||
|  | 		assert 'did not receive banner' in lines[-1] | ||||||
|  | 	 | ||||||
|  | 	def test_connection_closed_after_banner(self, output_spy, virtual_socket): | ||||||
|  | 		vsocket = virtual_socket | ||||||
|  | 		vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') | ||||||
|  | 		vsocket.rdata.append(socket.error(54, 'Connection reset by peer')) | ||||||
|  | 		output_spy.begin() | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			self.audit(self._conf()) | ||||||
|  | 		lines = output_spy.flush() | ||||||
|  | 		assert len(lines) == 2 | ||||||
|  | 		assert 'error reading packet' in lines[-1] | ||||||
|  | 		assert 'reset by peer' in lines[-1] | ||||||
|  | 	 | ||||||
|  | 	def test_empty_data_after_banner(self, output_spy, virtual_socket): | ||||||
|  | 		vsocket = virtual_socket | ||||||
|  | 		vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') | ||||||
|  | 		output_spy.begin() | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			self.audit(self._conf()) | ||||||
|  | 		lines = output_spy.flush() | ||||||
|  | 		assert len(lines) == 2 | ||||||
|  | 		assert 'error reading packet' in lines[-1] | ||||||
|  | 		assert 'empty' in lines[-1] | ||||||
|  | 	 | ||||||
|  | 	def test_wrong_data_after_banner(self, output_spy, virtual_socket): | ||||||
|  | 		vsocket = virtual_socket | ||||||
|  | 		vsocket.rdata.append(b'SSH-2.0-ssh-audit-test\r\n') | ||||||
|  | 		vsocket.rdata.append(b'xxx\n') | ||||||
|  | 		output_spy.begin() | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			self.audit(self._conf()) | ||||||
|  | 		lines = output_spy.flush() | ||||||
|  | 		assert len(lines) == 2 | ||||||
|  | 		assert 'error reading packet' in lines[-1] | ||||||
|  | 		assert 'xxx' in lines[-1] | ||||||
|  | 	 | ||||||
|  | 	def test_protocol_mismatch_by_conf(self, output_spy, virtual_socket): | ||||||
|  | 		vsocket = virtual_socket | ||||||
|  | 		vsocket.rdata.append(b'SSH-1.3-ssh-audit-test\r\n') | ||||||
|  | 		vsocket.rdata.append(b'Protocol major versions differ.\n') | ||||||
|  | 		output_spy.begin() | ||||||
|  | 		with pytest.raises(SystemExit): | ||||||
|  | 			conf = self._conf() | ||||||
|  | 			conf.ssh1, conf.ssh2 = True, False | ||||||
|  | 			self.audit(conf) | ||||||
|  | 		lines = output_spy.flush() | ||||||
|  | 		assert len(lines) == 3 | ||||||
|  | 		assert 'error reading packet' in lines[-1] | ||||||
|  | 		assert 'major versions differ' in lines[-1] | ||||||
		Reference in New Issue
	
	Block a user
	 Andris Raugulis
					Andris Raugulis