diff --git a/storm/parsers/ssh_config_parser.py b/storm/parsers/ssh_config_parser.py index 20e8400..c9c2f10 100644 --- a/storm/parsers/ssh_config_parser.py +++ b/storm/parsers/ssh_config_parser.py @@ -13,6 +13,10 @@ class StormConfig(SSHConfig): + def __init__(self, *a, **kw): + super(StormConfig, self).__init__(*a, **kw) + self.lower_to_original = {} + def parse(self, file_obj): """ Read an OpenSSH config from the given file object. @@ -49,10 +53,10 @@ def parse(self, file_obj): if line.lower().strip().startswith('proxycommand'): proxy_re = re.compile(r"^(proxycommand)\s*=*\s*(.*)", re.I) match = proxy_re.match(line) - key, value = match.group(1).lower(), match.group(2) + key, value = match.group(1), match.group(2) else: key, value = line.split('=', 1) - key = key.strip().lower() + key = key.strip() else: # find first whitespace, and split there i = 0 @@ -60,8 +64,10 @@ def parse(self, file_obj): i += 1 if i == len(line): raise Exception('Unparsable line: %r' % line) - key = line[:i].lower() + key = line[:i] value = line[i:].lstrip() + self.lower_to_original[key.lower()] = key + key = key.lower() if key == 'host': self._config.append(host) value = value.split() @@ -92,6 +98,7 @@ def __init__(self, ssh_config_file=None): ssh_config_file = self.get_default_ssh_config_file() self.defaults = {} + self.lower_to_original = {} self.ssh_config_file = ssh_config_file @@ -111,6 +118,7 @@ def load(self): with open(self.ssh_config_file) as fd: config.parse(fd) + self.lower_to_original = config.lower_to_original for entry in config.__dict__.get("_config"): if entry.get("host") == ["*"]: @@ -218,12 +226,12 @@ def dump(self): sub_content = "" for value_ in value: sub_content += " {0} {1}\n".format( - key, value_ + self.lower_to_original.get(key) or key, value_ ) host_item_content += sub_content else: host_item_content += " {0} {1}\n".format( - key, value + self.lower_to_original.get(key) or key, value ) file_content += host_item_content diff --git a/tests.py b/tests.py index bfee360..00ad209 100644 --- a/tests.py +++ b/tests.py @@ -170,7 +170,7 @@ def test_advanced_add(self): with open(self.config_file) as f: # check that property is really flushed out to the config? content = f.read().encode('ascii') - self.assertIn(b'identityfile "/tmp/idfilecheck.rsa"', content) + self.assertIn(b'IdentityFile "/tmp/idfilecheck.rsa"', content) self.assertIn(b"stricthostkeychecking yes", content) self.assertIn(b"userknownhostsfile /dev/advanced_test", content) @@ -184,7 +184,7 @@ def test_add_with_idfile(self): with open(self.config_file) as f: content = f.read().encode('ascii') - self.assertIn(b'identityfile "/tmp/idfileonlycheck.rsa"', content) + self.assertIn(b'IdentityFile "/tmp/idfileonlycheck.rsa"', content) def test_basic_edit(self): out, err, rc = self.run_cmd('edit aws.apache basic_edit_check@10.20.30.40 {0}'.format(self.config_arg)) @@ -215,8 +215,8 @@ def test_update(self): with open(self.config_file) as f: content = f.read().encode('ascii') - self.assertIn(b"user daghan", content) # see daghan: http://instagram.com/p/lfPMW_qVja - self.assertIn(b"port 42000", content) + self.assertIn(b"User daghan", content) # see daghan: http://instagram.com/p/lfPMW_qVja + self.assertIn(b"Port 42000", content) def test_update_regex(self):