11
11
import java .util .Collections ;
12
12
import java .util .List ;
13
13
import java .util .Map ;
14
+ import java .util .concurrent .CompletableFuture ;
14
15
import java .util .concurrent .CompletionException ;
15
16
import java .util .concurrent .CompletionStage ;
17
+ import java .util .concurrent .Executor ;
16
18
import java .util .concurrent .TimeUnit ;
17
19
import java .util .concurrent .TimeoutException ;
20
+ import java .util .function .BiConsumer ;
21
+ import java .util .function .BiFunction ;
18
22
import java .util .function .Function ;
19
23
import java .util .function .Supplier ;
20
24
import java .util .stream .Collectors ;
21
25
import java .util .stream .IntStream ;
26
+ import java .util .stream .Stream ;
22
27
23
28
import org .junit .Assert ;
24
29
import org .junit .Test ;
31
36
32
37
@ RunWith (Parameterized .class )
33
38
public class CombinatorsTest {
39
+ private static final Executor ASYNC = CompletableFuture ::runAsync ;
40
+
34
41
private static class TestException extends RuntimeException {
35
42
private static final long serialVersionUID = 1L ;
36
43
}
@@ -155,9 +162,9 @@ public void testAllOfError() {
155
162
Arrays .asList (
156
163
getCompletedStage (1 ),
157
164
getExceptionalStage (new TestException ()));
158
- assertError (Combinators .allOf (futures ));
159
- assertError (Combinators .collect (futures ));
160
- assertError (Combinators .collect (futures , Collectors .toList ()));
165
+ CombinatorsTest . assertError (Combinators .allOf (futures ));
166
+ CombinatorsTest . assertError (Combinators .collect (futures ));
167
+ CombinatorsTest . assertError (Combinators .collect (futures , Collectors .toList ()));
161
168
}
162
169
163
170
@ Test
@@ -173,15 +180,15 @@ public void testAllOfErrorNoShortCircuit() {
173
180
final CompletionStage <List <Integer >> collCollect =
174
181
Combinators .collect (futures , Collectors .toList ());
175
182
176
- assertIncomplete (voidAll );
177
- assertIncomplete (collAll );
178
- assertIncomplete (collCollect );
183
+ CombinatorsTest . assertIncomplete (voidAll );
184
+ CombinatorsTest . assertIncomplete (collAll );
185
+ CombinatorsTest . assertIncomplete (collCollect );
179
186
180
187
delayed .complete (1 );
181
188
182
- assertError (voidAll );
183
- assertError (collAll );
184
- assertError (collCollect );
189
+ CombinatorsTest . assertError (voidAll );
190
+ CombinatorsTest . assertError (collAll );
191
+ CombinatorsTest . assertError (collCollect );
185
192
}
186
193
187
194
@ Test
@@ -208,7 +215,7 @@ public void testKeyedAllError() {
208
215
}
209
216
return getCompletedStage (i );
210
217
}));
211
- assertError (Combinators .keyedAll (stageMap ));
218
+ CombinatorsTest . assertError (Combinators .keyedAll (stageMap ));
212
219
}
213
220
214
221
@ Test
@@ -220,26 +227,84 @@ public void testKeyedAllErrorNoShortCircuit() {
220
227
final CompletionStage <Map <Integer , Integer >> fut = Combinators .keyedAll (stageMap );
221
228
int i = 0 ;
222
229
for (final CompletableStage <Integer > future : stageMap .values ()) {
223
- assertIncomplete (fut );
230
+ CombinatorsTest . assertIncomplete (fut );
224
231
if (i == 3 ) {
225
232
future .completeExceptionally (new TestException ());
226
233
} else {
227
234
future .complete (i );
228
235
}
229
236
i ++;
230
237
}
231
- assertError (fut );
238
+ CombinatorsTest .assertError (fut );
239
+ }
240
+
241
+ /**
242
+ * Test that CompletionStage methods which depend on both of two stages always wait for both
243
+ * stages to complete.
244
+ */
245
+ @ Test
246
+ public void testCombineEtAl () {
247
+ for (final BiFunction <CompletionStage <?>, CompletionStage <?>, CompletionStage <Void >> combineMethod : Stream
248
+ .<BiFunction <CompletionStage <?>, CompletionStage <?>, CompletionStage <Void >>>of (
249
+ (a , b ) -> a .thenCombine (b , CombinatorsTest .voidFunction ()),
250
+ (a , b ) -> a .thenCombineAsync (b , CombinatorsTest .voidFunction ()),
251
+ (a , b ) -> a .thenCombineAsync (b , CombinatorsTest .voidFunction (), ASYNC ),
252
+ (a , b ) -> a .thenAcceptBoth (b , CombinatorsTest .voidConsumer ()),
253
+ (a , b ) -> a .thenAcceptBothAsync (b , CombinatorsTest .voidConsumer ()),
254
+ (a , b ) -> a .thenAcceptBothAsync (b , CombinatorsTest .voidConsumer (), ASYNC ),
255
+ (a , b ) -> a .runAfterBoth (b , CombinatorsTest .voidRunnable ()),
256
+ (a , b ) -> a .runAfterBothAsync (b , CombinatorsTest .voidRunnable ()),
257
+ (a , b ) -> a .runAfterBothAsync (b , CombinatorsTest .voidRunnable (), ASYNC ))
258
+ // include all the functions in reverse as well, switch a and b in argument order
259
+ .flatMap (function -> Stream .of (function , (a , b ) -> function .apply (b , a )))
260
+ .collect (Collectors .toList ())) {
261
+ {
262
+ final CompletionStage <String > doneNormal = getCompletedStage ("a" );
263
+ final CompletableStage <String > incompleteNormal = getCompletableStage ();
264
+
265
+ final CompletionStage <Void > combine = combineMethod .apply (doneNormal , incompleteNormal );
266
+
267
+ CombinatorsTest .assertIncomplete (combine );
268
+ incompleteNormal .complete ("b" );
269
+ TestUtil .join (combine );
270
+ }
271
+ {
272
+ final CompletionStage <String > doneExceptional = getExceptionalStage (new TestException ());
273
+ final CompletableStage <String > incompleteNormal = getCompletableStage ();
274
+
275
+ final CompletionStage <Void > combine =
276
+ combineMethod .apply (doneExceptional , incompleteNormal );
277
+
278
+ CombinatorsTest .assertIncomplete (combine );
279
+ incompleteNormal .complete ("b" );
280
+ CombinatorsTest .assertError (combine );
281
+ }
282
+ }
283
+ }
284
+
285
+ private static <T , V > BiFunction <T , V , Void > voidFunction () {
286
+ return (ig1 , ig2 ) -> null ;
287
+ }
288
+
289
+ private static <T , V > BiConsumer <T , V > voidConsumer () {
290
+ return (ig1 , ig2 ) -> {
291
+ };
292
+ }
293
+
294
+ private static Runnable voidRunnable () {
295
+ return () -> {
296
+ };
232
297
}
233
298
234
- private <T > void assertError (final CompletionStage <T > stage ) {
299
+ private static <T > void assertError (final CompletionStage <T > stage ) {
235
300
try {
236
301
TestUtil .join (stage );
237
302
} catch (final CompletionException e ) {
238
303
Assert .assertTrue (e .getCause () instanceof TestException );
239
304
}
240
305
}
241
306
242
- private <T > void assertIncomplete (final CompletionStage <T > stage ) {
307
+ private static <T > void assertIncomplete (final CompletionStage <T > stage ) {
243
308
try {
244
309
TestUtil .join (stage , 20 , TimeUnit .MILLISECONDS );
245
310
Assert .fail ("not all futures complete, get should timeout" );
0 commit comments