diff --git a/src/main/java/jnr/unixsocket/impl/AbstractNativeDatagramChannel.java b/src/main/java/jnr/unixsocket/impl/AbstractNativeDatagramChannel.java index 528e0f4..d4d04e0 100644 --- a/src/main/java/jnr/unixsocket/impl/AbstractNativeDatagramChannel.java +++ b/src/main/java/jnr/unixsocket/impl/AbstractNativeDatagramChannel.java @@ -18,6 +18,7 @@ package jnr.unixsocket.impl; +import jnr.constants.platform.Shutdown; import jnr.enxio.channels.Native; import jnr.enxio.channels.NativeSelectableChannel; import jnr.enxio.channels.NativeSelectorProvider; @@ -52,6 +53,8 @@ public final int getFD() { @Override protected void implCloseSelectableChannel() throws IOException { + // Shutdown to interrupt any potentially blocked threads. This is necessary on Linux. + Native.shutdown(getFD(), SHUT_RD); Native.close(common.getFD()); } @@ -80,4 +83,5 @@ public long write(ByteBuffer[] srcs, int offset, return common.write(srcs, offset, length); } + private static final int SHUT_RD = Shutdown.SHUT_RD.intValue(); } diff --git a/src/test/java/jnr/unixsocket/UnixDatagramChannelTest.java b/src/test/java/jnr/unixsocket/UnixDatagramChannelTest.java index f6c2864..7bf4cd1 100644 --- a/src/test/java/jnr/unixsocket/UnixDatagramChannelTest.java +++ b/src/test/java/jnr/unixsocket/UnixDatagramChannelTest.java @@ -1,7 +1,11 @@ package jnr.unixsocket; import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.file.Files; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Pattern; import org.junit.Test; @@ -66,4 +70,48 @@ public void testAbstractNamespace() throws Exception { assertEquals("local socket path", ABSTRACT, ch.getLocalSocketAddress().path()); } + @Test + public void testInterruptRead() throws Exception { + int readTimeoutInMilliseconds = 5000; + + UnixDatagramChannel ch = UnixDatagramChannel.open(); + ch.bind(null); + + CountDownLatch readStartLatch = new CountDownLatch(1); + AtomicReference thrownOnThread = new AtomicReference(); + + Runnable runnable = new Runnable() { + @Override + public void run() { + try { + readStartLatch.countDown(); + ByteBuffer buffer = ByteBuffer.allocate(1 << 16); + ch.receive(buffer); + } catch (IOException e) { + if (!e.getMessage().equals("Bad file descriptor")) { + thrownOnThread.set(e); + } + } + } + }; + + Thread readThread = new Thread(runnable); + + readThread.setDaemon(true); + + long startTime = System.nanoTime(); + readThread.start(); + readStartLatch.await(); + Thread.sleep(100); // Wait for the thread to call receive() + ch.close(); + readThread.join(); + long stopTime = System.nanoTime(); + + long duration = stopTime - startTime; + long durationInMilliseconds = duration / 1_000_000; + + assertTrue("read() was not interrupted by close() before read() timed out", durationInMilliseconds < readTimeoutInMilliseconds); + assertEquals("read() threw an exception", null, thrownOnThread.get()); + } + }