diff --git a/src/rawposix/src/net_calls.rs b/src/rawposix/src/net_calls.rs index fff818da3..09df1675a 100644 --- a/src/rawposix/src/net_calls.rs +++ b/src/rawposix/src/net_calls.rs @@ -948,10 +948,23 @@ pub extern "C" fn epoll_wait_syscall( // Get the underfd of type FDKIND_KERNEL to the vitual fd // Details see documentation on fdtables/epoll_get_underfd_hashmap.md - let epfd = *fdtables::epoll_get_underfd_hashmap(cageid, epfd_arg) - .unwrap() - .get(&FDKIND_KERNEL) - .unwrap(); + let fd_map = match fdtables::epoll_get_underfd_hashmap(cageid, epfd_arg) { + Ok(map) => map, + Err(_) => { + return syscall_error(Errno::EBADF, "epoll_wait_syscall", "invalid epoll fd"); + } + }; + + let epfd = match fd_map.get(&FDKIND_KERNEL) { + Some(fd) => *fd, + None => { + return syscall_error( + Errno::EBADF, + "epoll_wait_syscall", + "missing kernel epoll fd", + ); + } + }; // Convert arguments let maxevents = sc_convert_sysarg_to_i32(maxevents_arg, maxevents_cageid, cageid); @@ -1030,13 +1043,24 @@ pub extern "C" fn epoll_wait_syscall( // Loop over virtual epollfd to find corresponding mapping relationship between kernel fd and virtual fd for i in 0..ret as usize { let ret_kernelfd = kernel_events[i].u64; + let epollmapping = REAL_EPOLL_MAP.lock(); - let ret_virtualfd = epollmapping - .get(&(epfd)) - .and_then(|kernel_map| kernel_map.get(&(ret_kernelfd as i32)).copied()); - // Write back to user's buffer: store virtual fd in the u64 data field - events[i].u64 = ret_virtualfd.unwrap() as u64; + let ret_virtualfd = match epollmapping + .get(&epfd) + .and_then(|kernel_map| kernel_map.get(&(ret_kernelfd as i32)).copied()) + { + Some(vfd) => vfd, + None => { + return syscall_error( + Errno::EBADF, + "epoll", + "could not translate kernel fd to virtual fd", + ); + } + }; + + events[i].u64 = ret_virtualfd as u64; events[i].events = kernel_events[i].events; } return ret; diff --git a/tests/unit-tests/networking_tests/deterministic/epoll_badfd.c b/tests/unit-tests/networking_tests/deterministic/epoll_badfd.c new file mode 100644 index 000000000..034d41f4a --- /dev/null +++ b/tests/unit-tests/networking_tests/deterministic/epoll_badfd.c @@ -0,0 +1,22 @@ +#include +#include +#include +#include + +int main(void) { + struct epoll_event events[4]; + + /* create epoll instance */ + int epfd = epoll_create1(0); + assert(epfd != -1); + + /* close it so fd becomes invalid */ + assert(close(epfd) == 0); + + /* epoll_wait on closed fd must fail with EBADF */ + int ret = epoll_wait(epfd, events, 4, 1000); + assert(ret == -1); + assert(errno == EBADF); + + return 0; +}