
Commit a5e866f

zhztheplayer authored and cloud-fan committed
[SPARK-54132][SQL][TESTS] Cover HashedRelation#close in HashedRelationSuite
### What changes were proposed in this pull request?

Add the following code in `HashedRelationSuite`, to cover the API `HashedRelation#close` in the test suite.

```scala
protected override def afterEach(): Unit = {
  super.afterEach()
  assert(umm.executionMemoryUsed === 0)
}
```

### Why are the changes needed?

Doing this will:

1. Ensure `HashedRelation#close` is called in test code, to lower the memory footprint and avoid memory leaks when executing tests.
2. Ensure implementations of `HashedRelation#close` free the allocated memory blocks correctly.

It's an individual effort to improve test quality, but also a prerequisite task for #52817.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

It's a test-only PR.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52830 from zhztheplayer/wip-54132.

Authored-by: Hongze Zhang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>

1 parent daa83fc · commit a5e866f
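For context, here is a minimal sketch (not part of the patch) of what the new `afterEach` assertion observes. The constructor shapes are the ones the diff itself uses; the imports, and the assumption that this runs inside Spark's own test tree (several of these classes are package-private to `org.apache.spark`), are mine.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED
import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.sql.types.LongType

// One UnifiedMemoryManager behind one TaskMemoryManager, as in the suite.
val umm = new UnifiedMemoryManager(
  new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
  Long.MaxValue,
  Long.MaxValue / 2,
  1)
val mm = new TaskMemoryManager(umm, 0)

// Any map built on mm charges its pages against umm...
val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
val map = new LongToUnsafeRowMap(mm, 1)
map.append(1L, unsafeProj(InternalRow(1L)))
assert(umm.executionMemoryUsed > 0)

// ...and releasing the map is what drives the counter back to zero. A missing
// free()/close() leaves it nonzero, which the new hook turns into a test failure.
map.free()
assert(umm.executionMemoryUsed == 0)
```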

File tree

1 file changed: +39, -51 lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala

Lines changed: 39 additions & 51 deletions
```diff
@@ -40,14 +40,13 @@ import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.collection.CompactBuffer
 
 class HashedRelationSuite extends SharedSparkSession {
+  val umm = new UnifiedMemoryManager(
+    new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+    Long.MaxValue,
+    Long.MaxValue / 2,
+    1)
 
-  val mm = new TaskMemoryManager(
-    new UnifiedMemoryManager(
-      new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-      Long.MaxValue,
-      Long.MaxValue / 2,
-      1),
-    0)
+  val mm = new TaskMemoryManager(umm, 0)
 
   val rand = new Random(100)
 
@@ -64,6 +63,11 @@ class HashedRelationSuite extends SharedSparkSession {
   val sparseRows = sparseArray.map(i => projection(InternalRow(i.toLong)).copy())
   val randomRows = randomArray.map(i => projection(InternalRow(i.toLong)).copy())
 
+  protected override def afterEach(): Unit = {
+    super.afterEach()
+    assert(umm.executionMemoryUsed === 0)
+  }
+
   test("UnsafeHashedRelation") {
     val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
```
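The hook itself is ordinary ScalaTest: `afterEach` is provided somewhere in the suite's base classes, presumably via ScalaTest's `BeforeAndAfterEach`. A self-contained sketch of the same shape, with a hypothetical counter standing in for `umm.executionMemoryUsed`:

```scala
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical suite illustrating a per-test leak invariant checked in afterEach.
class LeakCheckSuite extends AnyFunSuite with BeforeAndAfterEach {
  private var acquired = 0 // stand-in for umm.executionMemoryUsed

  private def acquire(): Unit = acquired += 1
  private def release(): Unit = acquired -= 1

  protected override def afterEach(): Unit = {
    super.afterEach() // base-class cleanup first, as in the patch
    assert(acquired == 0, s"test leaked $acquired resource(s)")
  }

  test("balanced acquire/release passes the check") {
    acquire()
    release()
  }
}
```

Because the assertion runs after every test, a failure points at the specific test that forgot to release, instead of surfacing later as suite-wide memory pressure.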
```diff
@@ -87,6 +91,7 @@ class HashedRelationSuite extends SharedSparkSession {
     val out = new ObjectOutputStream(os)
     hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
     out.flush()
+    hashed.close()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
     val hashed2 = new UnsafeHashedRelation()
     hashed2.readExternal(in)
@@ -108,19 +113,13 @@ class HashedRelationSuite extends SharedSparkSession {
   }
 
   test("test serialization empty hash map") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
-    val binaryMap = new BytesToBytesMap(taskMemoryManager, 1, 1)
+    val binaryMap = new BytesToBytesMap(mm, 1, 1)
     val os = new ByteArrayOutputStream()
     val out = new ObjectOutputStream(os)
     val hashed = new UnsafeHashedRelation(1, 1, binaryMap)
     hashed.writeExternal(out)
     out.flush()
+    hashed.close()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
     val hashed2 = new UnsafeHashedRelation()
     hashed2.readExternal(in)
```
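The ordering these hunks apply (`writeExternal`, `flush()`, then `close()` on the source before `readExternal` into a fresh instance) is safe because the byte stream already holds a complete copy by the time the source is released. A hedged sketch with a hypothetical `Relation` stand-in, using plain `java.io` plumbing rather than the real `UnsafeHashedRelation`:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical stand-in for a relation that owns releasable storage.
final class Relation(private var data: Array[Long]) {
  def writeExternal(out: ObjectOutputStream): Unit = {
    out.writeInt(data.length)
    data.foreach(out.writeLong)
  }
  def readExternal(in: ObjectInputStream): Unit = {
    data = Array.fill(in.readInt())(in.readLong())
  }
  def close(): Unit = { data = null } // stand-in for freeing memory pages
  def values: Seq[Long] = data.toSeq
}

val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
val r = new Relation(Array(1L, 2L, 3L))
r.writeExternal(out)
out.flush()
r.close() // safe: os already holds a full copy, and r is never read again

val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val r2 = new Relation(Array.empty)
r2.readExternal(in)
assert(r2.values == Seq(1L, 2L, 3L))
```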
```diff
@@ -149,9 +148,10 @@ class HashedRelationSuite extends SharedSparkSession {
       assert(row.getLong(0) === i)
       assert(row.getInt(1) === i + 1)
     }
+    longRelation.close()
 
     val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
-        .asInstanceOf[LongHashedRelation]
+      .asInstanceOf[LongHashedRelation]
     assert(!longRelation2.keyIsUnique)
     (0 until 100).foreach { i =>
       val rows = longRelation2.get(i).toArray
@@ -166,6 +166,7 @@ class HashedRelationSuite extends SharedSparkSession {
     val out = new ObjectOutputStream(os)
     longRelation2.writeExternal(out)
     out.flush()
+    longRelation2.close()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
     val relation = new LongHashedRelation()
     relation.readExternal(in)
@@ -181,19 +182,12 @@ class HashedRelationSuite extends SharedSparkSession {
   }
 
   test("LongToUnsafeRowMap with very wide range") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
 
     {
       // SPARK-16740
       val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
-      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      val map = new LongToUnsafeRowMap(mm, 1)
       keys.foreach { k =>
         map.append(k, unsafeProj(InternalRow(k)))
       }
@@ -210,7 +204,7 @@ class HashedRelationSuite extends SharedSparkSession {
     {
       // SPARK-16802
       val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
-      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      val map = new LongToUnsafeRowMap(mm, 1)
       keys.foreach { k =>
         map.append(k, unsafeProj(InternalRow(k)))
       }
@@ -226,20 +220,13 @@ class HashedRelationSuite extends SharedSparkSession {
   }
 
   test("LongToUnsafeRowMap with random keys") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
 
     val N = 1000000
     val rand = new Random
     val keys = (0 to N).map(x => rand.nextLong()).toArray
 
-    val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
+    val map = new LongToUnsafeRowMap(mm, 10)
     keys.foreach { k =>
       map.append(k, unsafeProj(InternalRow(k)))
     }
@@ -249,8 +236,9 @@ class HashedRelationSuite extends SharedSparkSession {
     val out = new ObjectOutputStream(os)
     map.writeExternal(out)
     out.flush()
+    map.free()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
-    val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
+    val map2 = new LongToUnsafeRowMap(mm, 1)
     map2.readExternal(in)
 
     val row = unsafeProj(InternalRow(0L)).copy()
```
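The hunks around here also fix which map gets freed: the old code freed `map` at the end of the test but never `map2`, the copy rebuilt via `readExternal`, and under the new `afterEach` accounting that leak would now fail the test. A hedged sketch of the corrected pairing, reusing `umm`, `mm`, `unsafeProj`, and the imports from the first sketch above:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

val src = new LongToUnsafeRowMap(mm, 1)
src.append(42L, unsafeProj(InternalRow(42L)))

val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
src.writeExternal(out)
out.flush()
src.free() // source released as soon as its bytes are out

val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val copy = new LongToUnsafeRowMap(mm, 1)
copy.readExternal(in)
copy.free() // the rebuilt copy needs its own free()

assert(umm.executionMemoryUsed == 0) // both allocations accounted for
```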
```diff
@@ -276,19 +264,12 @@ class HashedRelationSuite extends SharedSparkSession {
       }
       i += 1
     }
-    map.free()
+    map2.free()
   }
 
   test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
-    val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+    val map = new LongToUnsafeRowMap(mm, 1)
 
     val key = 0L
     // the page array is initialized with length 1 << 17 (1M bytes),
@@ -343,6 +324,7 @@ class HashedRelationSuite extends SharedSparkSession {
     val rows = (0 until 100).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy())
     val longRelation = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
     val longRelation2 = ser.deserialize[LongHashedRelation](ser.serialize(longRelation))
+    longRelation.close()
     (0 until 100).foreach { i =>
       val rows = longRelation2.get(i).toArray
       assert(rows.length === 2)
@@ -359,6 +341,7 @@ class HashedRelationSuite extends SharedSparkSession {
     unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
     out.flush()
     val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+    unsafeHashed.close()
     val os2 = new ByteArrayOutputStream()
     val out2 = new ObjectOutputStream(os2)
     unsafeHashed2.writeExternal(out2)
@@ -398,6 +381,7 @@ class HashedRelationSuite extends SharedSparkSession {
     thread2.join()
 
     val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
+    unsafeHashed.close()
     val os2 = new ByteArrayOutputStream()
     val out2 = new ObjectOutputStream(os2)
     unsafeHashed2.writeExternal(out2)
@@ -452,18 +436,21 @@ class HashedRelationSuite extends SharedSparkSession {
     val hashedRelation = UnsafeHashedRelation(contiguousRows.iterator, singleKey, 1, mm)
     val keyIterator = hashedRelation.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+    hashedRelation.close()
   }
 
   test("UnsafeHashedRelation: key set iterator on a sparse array of keys") {
     val hashedRelation = UnsafeHashedRelation(sparseRows.iterator, singleKey, 1, mm)
     val keyIterator = hashedRelation.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === sparseArray)
+    hashedRelation.close()
   }
 
   test("LongHashedRelation: key set iterator on a contiguous array of keys") {
     val longRelation = LongHashedRelation(contiguousRows.iterator, singleKey, 1, mm)
     val keyIterator = longRelation.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+    longRelation.close()
   }
 
   test("LongToUnsafeRowMap: key set iterator on a contiguous array of keys") {
@@ -478,6 +465,7 @@ class HashedRelationSuite extends SharedSparkSession {
     rowMap.optimize()
     keyIterator = rowMap.keys()
     assert(keyIterator.map(key => key.getLong(0)).toArray === contiguousArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: key set iterator on a sparse array with equidistant keys") {
@@ -490,6 +478,7 @@ class HashedRelationSuite extends SharedSparkSession {
     rowMap.optimize()
     keyIterator = rowMap.keys()
     assert(keyIterator.map(_.getLong(0)).toArray === sparseArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: key set iterator on an array with a single key") {
@@ -530,6 +519,7 @@ class HashedRelationSuite extends SharedSparkSession {
       buffer.append(keyIterator.next().getLong(0))
     }
     assert(buffer === randomArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: no explicit hasNext calls on the key iterator") {
@@ -560,6 +550,7 @@ class HashedRelationSuite extends SharedSparkSession {
       buffer.append(keyIterator.next().getLong(0))
     }
     assert(buffer === randomArray)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: call hasNext at the end of the iterator") {
@@ -577,6 +568,7 @@ class HashedRelationSuite extends SharedSparkSession {
     assert(keyIterator.map(key => key.getLong(0)).toArray === sparseArray)
     assert(keyIterator.hasNext == false)
     assert(keyIterator.hasNext == false)
+    rowMap.free()
   }
 
   test("LongToUnsafeRowMap: random sequence of hasNext and next() calls on the key iterator") {
@@ -607,6 +599,7 @@ class HashedRelationSuite extends SharedSparkSession {
       }
     }
     assert(buffer === randomArray)
+    rowMap.free()
   }
 
   test("HashJoin: packing and unpacking with the same key type in a LongType") {
@@ -661,6 +654,7 @@ class HashedRelationSuite extends SharedSparkSession {
     assert(hashed.keys().isEmpty)
     assert(hashed.keyIsUnique)
     assert(hashed.estimatedSize == 0)
+    hashed.close()
   }
 
   test("SPARK-32399: test methods related to key index") {
@@ -739,20 +733,14 @@ class HashedRelationSuite extends SharedSparkSession {
       val actualValues = row.map(_._2.getInt(1))
       assert(actualValues === expectedValues)
     }
+    unsafeRelation.close()
   }
 
   test("LongToUnsafeRowMap support ignoresDuplicatedKey") {
-    val taskMemoryManager = new TaskMemoryManager(
-      new UnifiedMemoryManager(
-        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
-        Long.MaxValue,
-        Long.MaxValue / 2,
-        1),
-      0)
     val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
     val keys = Seq(1L, 1L, 1L)
     Seq(true, false).foreach { ignoresDuplicatedKey =>
-      val map = new LongToUnsafeRowMap(taskMemoryManager, 1, ignoresDuplicatedKey)
+      val map = new LongToUnsafeRowMap(mm, 1, ignoresDuplicatedKey)
      keys.foreach { k =>
         map.append(k, unsafeProj(InternalRow(k)))
       }
```
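The last hunk also routes the `ignoresDuplicatedKey` constructor variant through the shared `mm`. A hedged usage sketch, again reusing `umm`, `mm`, `unsafeProj`, and the imports from the first sketch; the `get(key, resultRow)` signature and the deduplication behavior are assumptions on my part:

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// With the flag on, repeated appends to one key should keep a single row.
val dedupMap = new LongToUnsafeRowMap(mm, 1, ignoresDuplicatedKey = true)
Seq(1L, 1L, 1L).foreach { k =>
  dedupMap.append(k, unsafeProj(InternalRow(k)))
}

val resultRow = new UnsafeRow(1)
assert(dedupMap.get(1L, resultRow).size == 1) // duplicates assumed ignored

// Either way, the map ends with free() so the afterEach accounting stays at zero.
dedupMap.free()
assert(umm.executionMemoryUsed == 0)
```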
