Skip to content

Commit 7effadc

Browse files
authoredDec 20, 2021
exec_command segfault partial fix (#280)
* Mostly undo 73a2be9 to fix _other_ excec_command segfaults * Replace % with format where touched * Update test_exec_command to test repeated calls to ensure the object remains usable
1 parent 6b053b5 commit 7effadc

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed
 
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Improved ``channel.exec_command`` to always use a newly created ``ssh_channel`` to avoid
2+
segfaults on repeated calls -- by :user:`Qalthos`

‎src/pylibsshext/channel.pyx

+25-12
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,10 @@ cdef class Channel:
8989
def poll(self, timeout=-1, stderr=0):
9090
if timeout < 0:
9191
rc = libssh.ssh_channel_poll(self._libssh_channel, stderr)
92-
if rc == libssh.SSH_ERROR:
93-
raise LibsshChannelException("Failed to poll channel: [%d]" % rc)
9492
else:
9593
rc = libssh.ssh_channel_poll_timeout(self._libssh_channel, timeout, stderr)
96-
if rc == libssh.SSH_ERROR:
97-
raise LibsshChannelException("Failed to poll channel: [%d]" % rc)
94+
if rc == libssh.SSH_ERROR:
95+
raise LibsshChannelException("Failed to poll channel: [{0}]".format(rc))
9896
return rc
9997

10098
def read_nonblocking(self, size=1024, stderr=0):
@@ -113,7 +111,10 @@ cdef class Channel:
113111
return self.read_nonblocking(size=size, stderr=stderr)
114112

115113
def write(self, data):
116-
return libssh.ssh_channel_write(self._libssh_channel, PyBytes_AS_STRING(data), len(data))
114+
written = libssh.ssh_channel_write(self._libssh_channel, PyBytes_AS_STRING(data), len(data))
115+
if written == libssh.SSH_ERROR:
116+
raise LibsshChannelException("Failed to write to ssh channel")
117+
return written
117118

118119
def sendall(self, data):
119120
return self.write(data)
@@ -139,23 +140,35 @@ cdef class Channel:
139140
return response
140141

141142
def exec_command(self, command):
142-
rc = libssh.ssh_channel_request_exec(self._libssh_channel, command.encode("utf-8"))
143+
# request_exec requires a fresh channel each run, so do not use the existing channel
144+
cdef libssh.ssh_channel channel = libssh.ssh_channel_new(self._libssh_session)
145+
if channel is NULL:
146+
raise MemoryError
147+
148+
rc = libssh.ssh_channel_open_session(channel)
149+
if rc != libssh.SSH_OK:
150+
libssh.ssh_channel_free(channel)
151+
raise LibsshChannelException("Failed to open_session: [{0}]".format(rc))
143152

153+
rc = libssh.ssh_channel_request_exec(channel, command.encode("utf-8"))
144154
if rc != libssh.SSH_OK:
145-
self.close()
146-
raise CalledProcessError()
155+
libssh.ssh_channel_close(channel)
156+
libssh.ssh_channel_free(channel)
157+
raise LibsshChannelException("Failed to execute command [{0}]: [{1}]".format(command, rc))
147158
result = CompletedProcess(args=command, returncode=-1, stdout=b'', stderr=b'')
148159

149160
cdef callbacks.ssh_channel_callbacks_struct cb
150161
memset(&cb, 0, sizeof(cb))
151162
cb.channel_data_function = <callbacks.ssh_channel_data_callback>&_process_outputs
152163
cb.userdata = <void *>result
153164
callbacks.ssh_callbacks_init(&cb)
154-
callbacks.ssh_set_channel_callbacks(self._libssh_channel, &cb)
155-
156-
libssh.ssh_channel_send_eof(self._libssh_channel)
165+
callbacks.ssh_set_channel_callbacks(channel, &cb)
157166

158-
result.returncode = self.get_channel_exit_status()
167+
libssh.ssh_channel_send_eof(channel)
168+
result.returncode = libssh.ssh_channel_get_exit_status(channel)
169+
if channel is not NULL:
170+
libssh.ssh_channel_close(channel)
171+
libssh.ssh_channel_free(channel)
159172

160173
return result
161174

‎src/pylibsshext/includes/libssh.pxd

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ from libc.stdint cimport uint32_t
2121

2222
cdef extern from "libssh/libssh.h" nogil:
2323

24-
cpdef const char * libssh_version "SSH_STRINGIFY(LIBSSH_VERSION)"
24+
cdef const char * libssh_version "SSH_STRINGIFY(LIBSSH_VERSION)"
2525

2626
cdef struct ssh_session_struct:
2727
pass

‎tests/unit/channel_test.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@ def ssh_channel(ssh_client_session):
3232
'Ref: https://github.com/ansible/pylibssh/issues/57', # noqa: WPS326
3333
strict=False,
3434
)
35-
@pytest.mark.forked() # noqa: PT023 -- it's unclear if braces are needed here
35+
@pytest.mark.forked
3636
def test_exec_command(ssh_channel):
3737
"""Test getting the output of a remotely executed command."""
3838
u_cmd_out = ssh_channel.exec_command('echo -n Hello World').stdout.decode()
3939
assert u_cmd_out == u'Hello World' # noqa: WPS302
40+
# Test that repeated calls to exec_command do not segfault.
41+
u_cmd_out = ssh_channel.exec_command('echo -n Hello Again').stdout.decode()
42+
assert u_cmd_out == u'Hello Again' # noqa: WPS302
4043

4144

4245
def test_double_close(ssh_channel):

0 commit comments

Comments
 (0)
Please sign in to comment.