@@ -40,14 +40,13 @@ import org.apache.spark.util.ArrayImplicits._
4040import org .apache .spark .util .collection .CompactBuffer
4141
4242class HashedRelationSuite extends SharedSparkSession {
// Suite-level UnifiedMemoryManager, configured with MEMORY_OFFHEAP_ENABLED set to false.
// It is a separate val so that afterEach() can assert on its executionMemoryUsed.
43+ val umm = new UnifiedMemoryManager (
44+ new SparkConf ().set(MEMORY_OFFHEAP_ENABLED .key, " false" ),
45+ Long .MaxValue ,
46+ Long .MaxValue / 2 ,
47+ 1 )
4348
44- val mm = new TaskMemoryManager (
45- new UnifiedMemoryManager (
46- new SparkConf ().set(MEMORY_OFFHEAP_ENABLED .key, " false" ),
47- Long .MaxValue ,
48- Long .MaxValue / 2 ,
49- 1 ),
50- 0 )
// TaskMemoryManager built on the suite-level `umm` and passed to the relations and maps under test.
// NOTE(review): presumably memory acquired through `mm` is what umm.executionMemoryUsed reports — confirm.
49+ val mm = new TaskMemoryManager (umm, 0 )
5150
5251 val rand = new Random (100 )
5352
@@ -64,6 +63,11 @@ class HashedRelationSuite extends SharedSparkSession {
6463 val sparseRows = sparseArray.map(i => projection(InternalRow (i.toLong)).copy())
6564 val randomRows = randomArray.map(i => projection(InternalRow (i.toLong)).copy())
6665
/**
 * Leak check that runs after every test in this suite.
 *
 * Invokes the parent hook first, then asserts that the suite-level `umm` reports no
 * execution memory in use. The clue names the amount still reported so that a failure
 * points directly at an unreleased relation or map.
 *
 * NOTE(review): `umm` is a suite-level val, so a test that leaks will presumably keep this
 * check failing for the tests that run after it as well — confirm when triaging.
 */
protected override def afterEach(): Unit = {
  super.afterEach()
  val stillUsed = umm.executionMemoryUsed
  assert(stillUsed === 0,
    s"execution memory still reported as used after the test: $stillUsed; " +
      "every relation or map created by a test must be closed or freed")
}
70+
6771 test(" UnsafeHashedRelation" ) {
6872 val schema = StructType (StructField (" a" , IntegerType , true ) :: Nil )
6973 val data = Array (InternalRow (0 ), InternalRow (1 ), InternalRow (2 ), InternalRow (2 ))
@@ -87,6 +91,7 @@ class HashedRelationSuite extends SharedSparkSession {
8791 val out = new ObjectOutputStream (os)
8892 hashed.asInstanceOf [UnsafeHashedRelation ].writeExternal(out)
8993 out.flush()
94+ hashed.close()
9095 val in = new ObjectInputStream (new ByteArrayInputStream (os.toByteArray))
9196 val hashed2 = new UnsafeHashedRelation ()
9297 hashed2.readExternal(in)
@@ -108,19 +113,13 @@ class HashedRelationSuite extends SharedSparkSession {
108113 }
109114
110115 test(" test serialization empty hash map" ) {
111- val taskMemoryManager = new TaskMemoryManager (
112- new UnifiedMemoryManager (
113- new SparkConf ().set(MEMORY_OFFHEAP_ENABLED .key, " false" ),
114- Long .MaxValue ,
115- Long .MaxValue / 2 ,
116- 1 ),
117- 0 )
118- val binaryMap = new BytesToBytesMap (taskMemoryManager, 1 , 1 )
116+ val binaryMap = new BytesToBytesMap (mm, 1 , 1 )
119117 val os = new ByteArrayOutputStream ()
120118 val out = new ObjectOutputStream (os)
121119 val hashed = new UnsafeHashedRelation (1 , 1 , binaryMap)
122120 hashed.writeExternal(out)
123121 out.flush()
122+ hashed.close()
124123 val in = new ObjectInputStream (new ByteArrayInputStream (os.toByteArray))
125124 val hashed2 = new UnsafeHashedRelation ()
126125 hashed2.readExternal(in)
@@ -149,9 +148,10 @@ class HashedRelationSuite extends SharedSparkSession {
149148 assert(row.getLong(0 ) === i)
150149 assert(row.getInt(1 ) === i + 1 )
151150 }
151+ longRelation.close()
152152
153153 val longRelation2 = LongHashedRelation (rows.iterator ++ rows.iterator, key, 100 , mm)
154- .asInstanceOf [LongHashedRelation ]
154+ .asInstanceOf [LongHashedRelation ]
155155 assert(! longRelation2.keyIsUnique)
156156 (0 until 100 ).foreach { i =>
157157 val rows = longRelation2.get(i).toArray
@@ -166,6 +166,7 @@ class HashedRelationSuite extends SharedSparkSession {
166166 val out = new ObjectOutputStream (os)
167167 longRelation2.writeExternal(out)
168168 out.flush()
169+ longRelation2.close()
169170 val in = new ObjectInputStream (new ByteArrayInputStream (os.toByteArray))
170171 val relation = new LongHashedRelation ()
171172 relation.readExternal(in)
@@ -181,19 +182,12 @@ class HashedRelationSuite extends SharedSparkSession {
181182 }
182183
183184 test(" LongToUnsafeRowMap with very wide range" ) {
184- val taskMemoryManager = new TaskMemoryManager (
185- new UnifiedMemoryManager (
186- new SparkConf ().set(MEMORY_OFFHEAP_ENABLED .key, " false" ),
187- Long .MaxValue ,
188- Long .MaxValue / 2 ,
189- 1 ),
190- 0 )
191185 val unsafeProj = UnsafeProjection .create(Seq (BoundReference (0 , LongType , false )))
192186
193187 {
194188 // SPARK-16740
195189 val keys = Seq (0L , Long .MaxValue , Long .MaxValue )
196- val map = new LongToUnsafeRowMap (taskMemoryManager , 1 )
190+ val map = new LongToUnsafeRowMap (mm , 1 )
197191 keys.foreach { k =>
198192 map.append(k, unsafeProj(InternalRow (k)))
199193 }
@@ -210,7 +204,7 @@ class HashedRelationSuite extends SharedSparkSession {
210204 {
211205 // SPARK-16802
212206 val keys = Seq (Long .MaxValue , Long .MaxValue - 10 )
213- val map = new LongToUnsafeRowMap (taskMemoryManager , 1 )
207+ val map = new LongToUnsafeRowMap (mm , 1 )
214208 keys.foreach { k =>
215209 map.append(k, unsafeProj(InternalRow (k)))
216210 }
@@ -226,20 +220,13 @@ class HashedRelationSuite extends SharedSparkSession {
226220 }
227221
228222 test(" LongToUnsafeRowMap with random keys" ) {
229- val taskMemoryManager = new TaskMemoryManager (
230- new UnifiedMemoryManager (
231- new SparkConf ().set(MEMORY_OFFHEAP_ENABLED .key, " false" ),
232- Long .MaxValue ,
233- Long .MaxValue / 2 ,
234- 1 ),
235- 0 )
236223 val unsafeProj = UnsafeProjection .create(Seq (BoundReference (0 , LongType , false )))
237224
238225 val N = 1000000
239226 val rand = new Random
240227 val keys = (0 to N ).map(x => rand.nextLong()).toArray
241228
242- val map = new LongToUnsafeRowMap (taskMemoryManager , 10 )
229+ val map = new LongToUnsafeRowMap (mm , 10 )
243230 keys.foreach { k =>
244231 map.append(k, unsafeProj(InternalRow (k)))
245232 }
@@ -249,8 +236,9 @@ class HashedRelationSuite extends SharedSparkSession {
249236 val out = new ObjectOutputStream (os)
250237 map.writeExternal(out)
251238 out.flush()
239+ map.free()
252240 val in = new ObjectInputStream (new ByteArrayInputStream (os.toByteArray))
253- val map2 = new LongToUnsafeRowMap (taskMemoryManager , 1 )
241+ val map2 = new LongToUnsafeRowMap (mm , 1 )
254242 map2.readExternal(in)
255243
256244 val row = unsafeProj(InternalRow (0L )).copy()
@@ -276,19 +264,12 @@ class HashedRelationSuite extends SharedSparkSession {
276264 }
277265 i += 1
278266 }
279- map .free()
267+ map2 .free()
280268 }
281269
282270 test(" SPARK-24257: insert big values into LongToUnsafeRowMap" ) {
283- val taskMemoryManager = new TaskMemoryManager (
284- new UnifiedMemoryManager (
285- new SparkConf ().set(MEMORY_OFFHEAP_ENABLED .key, " false" ),
286- Long .MaxValue ,
287- Long .MaxValue / 2 ,
288- 1 ),
289- 0 )
290271 val unsafeProj = UnsafeProjection .create(Array [DataType ](StringType ))
291- val map = new LongToUnsafeRowMap (taskMemoryManager , 1 )
272+ val map = new LongToUnsafeRowMap (mm , 1 )
292273
293274 val key = 0L
294275 // the page array is initialized with length 1 << 17 (1M bytes),
@@ -343,6 +324,7 @@ class HashedRelationSuite extends SharedSparkSession {
343324 val rows = (0 until 100 ).map(i => unsafeProj(InternalRow (Int .int2long(i), i + 1 )).copy())
344325 val longRelation = LongHashedRelation (rows.iterator ++ rows.iterator, key, 100 , mm)
345326 val longRelation2 = ser.deserialize[LongHashedRelation ](ser.serialize(longRelation))
327+ longRelation.close()
346328 (0 until 100 ).foreach { i =>
347329 val rows = longRelation2.get(i).toArray
348330 assert(rows.length === 2 )
@@ -359,6 +341,7 @@ class HashedRelationSuite extends SharedSparkSession {
359341 unsafeHashed.asInstanceOf [UnsafeHashedRelation ].writeExternal(out)
360342 out.flush()
361343 val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation ](ser.serialize(unsafeHashed))
344+ unsafeHashed.close()
362345 val os2 = new ByteArrayOutputStream ()
363346 val out2 = new ObjectOutputStream (os2)
364347 unsafeHashed2.writeExternal(out2)
@@ -398,6 +381,7 @@ class HashedRelationSuite extends SharedSparkSession {
398381 thread2.join()
399382
400383 val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation ](ser.serialize(unsafeHashed))
384+ unsafeHashed.close()
401385 val os2 = new ByteArrayOutputStream ()
402386 val out2 = new ObjectOutputStream (os2)
403387 unsafeHashed2.writeExternal(out2)
@@ -452,18 +436,21 @@ class HashedRelationSuite extends SharedSparkSession {
452436 val hashedRelation = UnsafeHashedRelation (contiguousRows.iterator, singleKey, 1 , mm)
453437 val keyIterator = hashedRelation.keys()
454438 assert(keyIterator.map(key => key.getLong(0 )).toArray === contiguousArray)
439+ hashedRelation.close()
455440 }
456441
// Builds an UnsafeHashedRelation from `sparseRows` and checks that keys() yields exactly
// the values in `sparseArray`, in order.
test(" UnsafeHashedRelation: key set iterator on a sparse array of keys") {
  val hashedRelation = UnsafeHashedRelation(sparseRows.iterator, singleKey, 1, mm)
  try {
    val keyIterator = hashedRelation.keys()
    assert(keyIterator.map(key => key.getLong(0)).toArray === sparseArray)
  } finally {
    // Close even when the assertion above throws; otherwise the relation is never
    // released and the afterEach() memory check fails on top of the real failure.
    hashedRelation.close()
  }
}
462448
// Builds a LongHashedRelation from `contiguousRows` and checks that keys() yields exactly
// the values in `contiguousArray`, in order.
test(" LongHashedRelation: key set iterator on a contiguous array of keys") {
  val longRelation = LongHashedRelation(contiguousRows.iterator, singleKey, 1, mm)
  try {
    val keyIterator = longRelation.keys()
    assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
  } finally {
    // Close even when the assertion above throws; otherwise the relation is never
    // released and the afterEach() memory check fails on top of the real failure.
    longRelation.close()
  }
}
468455
469456 test(" LongToUnsafeRowMap: key set iterator on a contiguous array of keys" ) {
@@ -478,6 +465,7 @@ class HashedRelationSuite extends SharedSparkSession {
478465 rowMap.optimize()
479466 keyIterator = rowMap.keys()
480467 assert(keyIterator.map(key => key.getLong(0 )).toArray === contiguousArray)
468+ rowMap.free()
481469 }
482470
483471 test(" LongToUnsafeRowMap: key set iterator on a sparse array with equidistant keys" ) {
@@ -490,6 +478,7 @@ class HashedRelationSuite extends SharedSparkSession {
490478 rowMap.optimize()
491479 keyIterator = rowMap.keys()
492480 assert(keyIterator.map(_.getLong(0 )).toArray === sparseArray)
481+ rowMap.free()
493482 }
494483
495484 test(" LongToUnsafeRowMap: key set iterator on an array with a single key" ) {
@@ -530,6 +519,7 @@ class HashedRelationSuite extends SharedSparkSession {
530519 buffer.append(keyIterator.next().getLong(0 ))
531520 }
532521 assert(buffer === randomArray)
522+ rowMap.free()
533523 }
534524
535525 test(" LongToUnsafeRowMap: no explicit hasNext calls on the key iterator" ) {
@@ -560,6 +550,7 @@ class HashedRelationSuite extends SharedSparkSession {
560550 buffer.append(keyIterator.next().getLong(0 ))
561551 }
562552 assert(buffer === randomArray)
553+ rowMap.free()
563554 }
564555
565556 test(" LongToUnsafeRowMap: call hasNext at the end of the iterator" ) {
@@ -577,6 +568,7 @@ class HashedRelationSuite extends SharedSparkSession {
577568 assert(keyIterator.map(key => key.getLong(0 )).toArray === sparseArray)
578569 assert(keyIterator.hasNext == false )
579570 assert(keyIterator.hasNext == false )
571+ rowMap.free()
580572 }
581573
582574 test(" LongToUnsafeRowMap: random sequence of hasNext and next() calls on the key iterator" ) {
@@ -607,6 +599,7 @@ class HashedRelationSuite extends SharedSparkSession {
607599 }
608600 }
609601 assert(buffer === randomArray)
602+ rowMap.free()
610603 }
611604
612605 test(" HashJoin: packing and unpacking with the same key type in a LongType" ) {
@@ -661,6 +654,7 @@ class HashedRelationSuite extends SharedSparkSession {
661654 assert(hashed.keys().isEmpty)
662655 assert(hashed.keyIsUnique)
663656 assert(hashed.estimatedSize == 0 )
657+ hashed.close()
664658 }
665659
666660 test(" SPARK-32399: test methods related to key index" ) {
@@ -739,20 +733,14 @@ class HashedRelationSuite extends SharedSparkSession {
739733 val actualValues = row.map(_._2.getInt(1 ))
740734 assert(actualValues === expectedValues)
741735 }
736+ unsafeRelation.close()
742737 }
743738
744739 test(" LongToUnsafeRowMap support ignoresDuplicatedKey" ) {
745- val taskMemoryManager = new TaskMemoryManager (
746- new UnifiedMemoryManager (
747- new SparkConf ().set(MEMORY_OFFHEAP_ENABLED .key, " false" ),
748- Long .MaxValue ,
749- Long .MaxValue / 2 ,
750- 1 ),
751- 0 )
752740 val unsafeProj = UnsafeProjection .create(Seq (BoundReference (0 , LongType , false )))
753741 val keys = Seq (1L , 1L , 1L )
754742 Seq (true , false ).foreach { ignoresDuplicatedKey =>
755- val map = new LongToUnsafeRowMap (taskMemoryManager , 1 , ignoresDuplicatedKey)
743+ val map = new LongToUnsafeRowMap (mm , 1 , ignoresDuplicatedKey)
756744 keys.foreach { k =>
757745 map.append(k, unsafeProj(InternalRow (k)))
758746 }
0 commit comments