diff --git a/storm/__main__.py b/storm/__main__.py index 7f0b323..905f1de 100644 --- a/storm/__main__.py +++ b/storm/__main__.py @@ -180,8 +180,84 @@ def delete(name, config=None): except ValueError as error: print(get_formatted_message(error, 'error'), file=sys.stderr) +def format_host(host,defaults,with_tags=False,compact_tags=False): + result = "" + result_stack = "" + if host.get("type") == 'entry': + if not host.get("host") == "*": + result += " {0} -> {1}@{2}:{3}".format( + colored(host["host"], 'green', attrs=["bold", ]), + host.get("options").get( + "user", get_default("user", defaults) + ), + host.get("options").get( + "hostname", "[hostname_not_specified]" + ), + host.get("options").get( + "port", get_default("port", defaults) + ) + ) + extra = False + for key, value in six.iteritems(host.get("options")): + + if not key in ["user", "hostname", "port"]: + if not extra: + custom_options = colored( + '\n\t[custom options] ', 'white' + ) + result += " {0}".format(custom_options) + extra = True + + if isinstance(value, collections.Sequence): + if isinstance(value, builtins.list): + value = ",".join(value) + + result += "{0}={1} ".format(key, value) + if extra: + result = result[0:-1] + + if with_tags: + if len(host.get('tags')) > 0: + if compact_tags: + result += " {0}".format('') + value = " ".join( + map(lambda tag: colored(tag,'green', attrs=["bold", ]), + host.get('tags') + ) + ) + result += "{0} ".format(value) + else: + tags = colored( + '\n\t[tags] ', 'white' + ) + result += " {0}".format(tags) + value = ", ".join(host.get('tags')) + result += "{0} ".format(value) + + result += "\n\n" + elif host.get("options") != {}: + result_stack = colored( + " (*) General options: \n", "green", attrs=["bold",] + ) + for key, value in six.iteritems(host.get("options")): + if isinstance(value, type([])): + result_stack += "\t {0}: ".format( + colored(key, "magenta") + ) + result_stack += ', '.join(value) + result_stack += "\n" + else: + result_stack += "\t {0}: {1}\n".format( + colored(key, "magenta"), + value, + ) + result_stack = result_stack[0:-1] + "\n" + result += result_stack + return result + @command('list') -def list(config=None): +@arg('with_tags', '-t', action='store_true', dest='with_tags', help='displays tags for each host') +def list(with_tags=False,config=None): """ Lists all hosts from ssh config. """ @@ -189,64 +265,42 @@ def list(config=None): try: result = colored('Listing entries:', 'white', attrs=["bold", ]) + "\n\n" - result_stack = "" for host in storm_.list_entries(True): + result += format_host(host,storm_.defaults,with_tags) + if len(result) != 0: + print(get_formatted_message(result, "")) + except Exception as error: + print(get_formatted_message(str(error), 'error'), file=sys.stderr) - if host.get("type") == 'entry': - if not host.get("host") == "*": - result += " {0} -> {1}@{2}:{3}".format( - colored(host["host"], 'green', attrs=["bold", ]), - host.get("options").get( - "user", get_default("user", storm_.defaults) - ), - host.get("options").get( - "hostname", "[hostname_not_specified]" - ), - host.get("options").get( - "port", get_default("port", storm_.defaults) - ) - ) - - extra = False - for key, value in six.iteritems(host.get("options")): - - if not key in ["user", "hostname", "port"]: - if not extra: - custom_options = colored( - '\n\t[custom options] ', 'white' - ) - result += " {0}".format(custom_options) - extra = True - - if isinstance(value, collections.Sequence): - if isinstance(value, builtins.list): - value = ",".join(value) - - result += "{0}={1} ".format(key, value) - if extra: - result = result[0:-1] - - result += "\n\n" - else: - result_stack = colored( - " (*) General options: \n", "green", attrs=["bold",] - ) - for key, value in six.iteritems(host.get("options")): - if isinstance(value, type([])): - result_stack += "\t {0}: ".format( - colored(key, "magenta") - ) - result_stack += ', '.join(value) - result_stack += "\n" - else: - result_stack += "\t {0}: {1}\n".format( - colored(key, "magenta"), - value, - ) - result_stack = result_stack[0:-1] + "\n" +@command('list-tag') +@arg('tags', nargs='*', default=[], type=str, help='a tag name to filter displayed hosts, if no tag is supplied all tags are displayed') +def list(tags,config=None): + """ + Lists hosts from ssh config with a specific TAG or all tags. + """ + storm_ = get_storm_instance(config) - result += result_stack - print(get_formatted_message(result, "")) + try: + result = "" + input_tags = ["@" + tag for tag in tags] + input_tags = set(input_tags) # remove duplicates + all_tags = storm_.ssh_config.hosts_per_tag.keys() + + if len(input_tags) == 0: # if tags given, display all + found_tags = all_tags + else: + found_tags = set() + for tag in input_tags: + found_tags = found_tags | set(filter(lambda existing_tag: tag in existing_tag, all_tags)) + + + + for tag in found_tags: + result += colored('Listing entries for tag', 'white', attrs=["bold", ]) + " {0}".format(tag) + "\n\n" + for host in storm_.ssh_config.hosts_per_tag[tag]: + result += format_host(host,storm_.defaults,with_tags=True,compact_tags=True) + if len(result) != 0: + print(get_formatted_message(result, "")) except Exception as error: print(get_formatted_message(str(error), 'error'), file=sys.stderr) diff --git a/storm/parsers/ssh_config_parser.py b/storm/parsers/ssh_config_parser.py index 20e8400..4b4414e 100644 --- a/storm/parsers/ssh_config_parser.py +++ b/storm/parsers/ssh_config_parser.py @@ -21,7 +21,7 @@ def parse(self, file_obj): @type file_obj: file """ order = 1 - host = {"host": ['*'], "config": {}, } + host = {"host": ['*'], "config": {}, "tags" : [] } for line in file_obj: line = line.rstrip('\n').lstrip() if line == '': @@ -34,7 +34,7 @@ def parse(self, file_obj): order += 1 continue - if line.startswith('#'): + if line.startswith('#') and not line.startswith('#@'): self._config.append({ 'type': 'comment', 'value': line, @@ -53,6 +53,11 @@ def parse(self, file_obj): else: key, value = line.split('=', 1) key = key.strip().lower() + elif line.lower().strip().startswith('#@'): + line_with_tags=line[1:] + tags_re = re.compile(r"@\S+") + match = tags_re.findall(line_with_tags) + key, value = "tags", match else: # find first whitespace, and split there i = 0 @@ -62,6 +67,7 @@ def parse(self, file_obj): raise Exception('Unparsable line: %r' % line) key = line[:i].lower() value = line[i:].lstrip() + if key == 'host': self._config.append(host) value = value.split() @@ -69,7 +75,8 @@ def parse(self, file_obj): key: value, 'config': {}, 'type': 'entry', - 'order': order + 'order': order, + 'tags': [], } order += 1 elif key in ['identityfile', 'localforward', 'remoteforward']: @@ -77,11 +84,12 @@ def parse(self, file_obj): host['config'][key].append(value) else: host['config'][key] = [value] + elif key == 'tags': + host['tags'] += value elif key not in host['config']: host['config'].update({key: value}) self._config.append(host) - class ConfigParser(object): """ Config parser for ~/.ssh/config files. @@ -102,6 +110,7 @@ def __init__(self, ssh_config_file=None): chmod(self.ssh_config_file, 0o600) self.config_data = [] + self.hosts_per_tag = {} def get_default_ssh_config_file(self): return expanduser("~/.ssh/config") @@ -125,6 +134,7 @@ def load(self): 'options': entry.get("config"), 'type': 'entry', 'order': entry.get("order", 0), + 'tags': entry.get("tags") } if len(entry["host"]) > 1: @@ -132,6 +142,12 @@ def load(self): 'host': " ".join(entry["host"]), }) + for tag in host_item['tags']: + if not tag in self.hosts_per_tag: + self.hosts_per_tag[tag] = [] + if not host_item in self.hosts_per_tag[tag]: + self.hosts_per_tag[tag].append(host_item) + # minor bug in paramiko.SSHConfig that duplicates #"Host *" entries. if entry.get("config") and len(entry.get("config")) > 0: diff --git a/tests.py b/tests.py index bfee360..59a8d30 100644 --- a/tests.py +++ b/tests.py @@ -34,6 +34,7 @@ ## override as per host ## Host server1 + #@private HostName server1.cyberciti.biz User nixcraft Port 4242 @@ -42,6 +43,7 @@ ## Home nas server ## Host nas01 + #@private HostName 192.168.1.100 User root IdentityFile ~/.ssh/nas01.key @@ -55,6 +57,7 @@ ## Login to internal lan server at 192.168.0.251 via our public uk office ssh based gateway using ## ## $ ssh uk.gw.lan ## Host uk.gw.lan uk.lan + #@uk @gw HostName 192.168.0.251 User nixcraft ProxyCommand ssh nixcraft@gateway.uk.cyberciti.biz nc %h %p 2> /dev/null @@ -94,11 +97,11 @@ def run_cmd(self, cmd): return out, err, rc def test_list_command(self): - out, err, rc = self.run_cmd('list {0}'.format(self.config_arg)) + out, err, rc = self.run_cmd('list {0} {1}'.format(self.config_arg, '-t')) self.assertTrue(out.startswith(b" Listing entries:\n\n")) - hosts, custom_options = [ + hosts, custom_options, tags = [ "aws.apache -> wwwdata@1.2.3.4:22", "nas01 -> root@192.168.1.100:22", "proxyus -> breakfree@vps1.cyberciti.biz:22", @@ -111,6 +114,13 @@ def test_list_command(self): "localforward=3128 127.0.0.1:3128", "[custom options] identityfile=/nfs/shared/users/nixcraft/keys/server1/id_rsa,/tmp/x.rsa", "[custom options] proxycommand=ssh nixcraft@gateway.uk.cyberciti.biz nc %h %p 2> /dev/null", + ], [ + "", + "[tags] @private", + "", + "", + "[tags] @private", + "[tags] @uk, @gw", ] general_options = { @@ -131,6 +141,9 @@ def test_list_command(self): for custom_option in custom_options: self.assertIn(custom_option.encode('ascii'), out) + for tag in tags: + self.assertIn(tag.encode('ascii'), out) + for general_option, value in six.iteritems(general_options): self.assertIn("{0}: {1}".format(general_option, value).encode('ascii'), out) @@ -471,4 +484,4 @@ def tearDown(self): os.unlink('/tmp/ssh_config') if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file