Skip to content

Commit ad7827c

Browse files
authored
Replace wrapTransversable generators to prevent memory leaks (#2709)
The generator hold a circular reference to the iterator instance.
1 parent 4d2e51a commit ad7827c

File tree

5 files changed

+69
-88
lines changed

5 files changed

+69
-88
lines changed

lib/Doctrine/ODM/MongoDB/Iterator/CachingIterator.php

+16-33
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
namespace Doctrine\ODM\MongoDB\Iterator;
66

77
use Countable;
8-
use Generator;
8+
use Iterator as SPLIterator;
9+
use IteratorIterator;
910
use ReturnTypeWillChange;
1011
use RuntimeException;
1112
use Traversable;
@@ -33,13 +34,11 @@ final class CachingIterator implements Countable, Iterator
3334
/** @var array<mixed, TValue> */
3435
private array $items = [];
3536

36-
/** @var Generator<mixed, TValue>|null */
37-
private ?Generator $iterator;
37+
/** @var SPLIterator<mixed, TValue>|null */
38+
private ?SPLIterator $iterator;
3839

3940
private bool $iteratorAdvanced = false;
4041

41-
private bool $iteratorExhausted = false;
42-
4342
/**
4443
* Initialize the iterator and stores the first item in the cache. This
4544
* effectively rewinds the Traversable and the wrapping Generator, which
@@ -51,7 +50,8 @@ final class CachingIterator implements Countable, Iterator
5150
*/
5251
public function __construct(Traversable $iterator)
5352
{
54-
$this->iterator = $this->wrapTraversable($iterator);
53+
$this->iterator = new IteratorIterator($iterator);
54+
$this->iterator->rewind();
5555
$this->storeCurrentItem();
5656
}
5757

@@ -94,9 +94,10 @@ public function key()
9494
/** @see http://php.net/iterator.next */
9595
public function next(): void
9696
{
97-
if (! $this->iteratorExhausted) {
98-
$this->getIterator()->next();
97+
if ($this->iterator !== null) {
98+
$this->iterator->next();
9999
$this->storeCurrentItem();
100+
$this->iteratorAdvanced = true;
100101
}
101102

102103
next($this->items);
@@ -126,15 +127,13 @@ public function valid(): bool
126127
*/
127128
private function exhaustIterator(): void
128129
{
129-
while (! $this->iteratorExhausted) {
130+
while ($this->iterator !== null) {
130131
$this->next();
131132
}
132-
133-
$this->iterator = null;
134133
}
135134

136-
/** @return Generator<mixed, TValue> */
137-
private function getIterator(): Generator
135+
/** @return SPLIterator<mixed, TValue> */
136+
private function getIterator(): SPLIterator
138137
{
139138
if ($this->iterator === null) {
140139
throw new RuntimeException('Iterator has already been destroyed');
@@ -148,28 +147,12 @@ private function getIterator(): Generator
148147
*/
149148
private function storeCurrentItem(): void
150149
{
151-
$key = $this->getIterator()->key();
150+
$key = $this->iterator->key();
152151

153152
if ($key === null) {
154-
return;
153+
$this->iterator = null;
154+
} else {
155+
$this->items[$key] = $this->getIterator()->current();
155156
}
156-
157-
$this->items[$key] = $this->getIterator()->current();
158-
}
159-
160-
/**
161-
* @param Traversable<mixed, TValue> $traversable
162-
*
163-
* @return Generator<mixed, TValue>
164-
*/
165-
private function wrapTraversable(Traversable $traversable): Generator
166-
{
167-
foreach ($traversable as $key => $value) {
168-
yield $key => $value;
169-
170-
$this->iteratorAdvanced = true;
171-
}
172-
173-
$this->iteratorExhausted = true;
174157
}
175158
}

lib/Doctrine/ODM/MongoDB/Iterator/HydratingIterator.php

+7-18
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
use Doctrine\ODM\MongoDB\Mapping\ClassMetadata;
88
use Doctrine\ODM\MongoDB\UnitOfWork;
9-
use Generator;
109
use Iterator;
10+
use IteratorIterator;
1111
use ReturnTypeWillChange;
1212
use RuntimeException;
1313
use Traversable;
@@ -24,8 +24,8 @@
2424
*/
2525
final class HydratingIterator implements Iterator
2626
{
27-
/** @var Generator<mixed, array<string, mixed>>|null */
28-
private ?Generator $iterator;
27+
/** @var Iterator<mixed, array<string, mixed>>|null */
28+
private ?Iterator $iterator;
2929

3030
/**
3131
* @param Traversable<mixed, array<string, mixed>> $traversable
@@ -34,7 +34,8 @@ final class HydratingIterator implements Iterator
3434
*/
3535
public function __construct(Traversable $traversable, private UnitOfWork $unitOfWork, private ClassMetadata $class, private array $unitOfWorkHints = [])
3636
{
37-
$this->iterator = $this->wrapTraversable($traversable);
37+
$this->iterator = new IteratorIterator($traversable);
38+
$this->iterator->rewind();
3839
}
3940

4041
public function __destruct()
@@ -74,8 +75,8 @@ public function valid(): bool
7475
return $this->key() !== null;
7576
}
7677

77-
/** @return Generator<mixed, array<string, mixed>> */
78-
private function getIterator(): Generator
78+
/** @return Iterator<mixed, array<string, mixed>> */
79+
private function getIterator(): Iterator
7980
{
8081
if ($this->iterator === null) {
8182
throw new RuntimeException('Iterator has already been destroyed');
@@ -93,16 +94,4 @@ private function hydrate(?array $document): ?object
9394
{
9495
return $document !== null ? $this->unitOfWork->getOrCreateDocument($this->class->name, $document, $this->unitOfWorkHints) : null;
9596
}
96-
97-
/**
98-
* @param Traversable<mixed, array<string, mixed>> $traversable
99-
*
100-
* @return Generator<mixed, array<string, mixed>>
101-
*/
102-
private function wrapTraversable(Traversable $traversable): Generator
103-
{
104-
foreach ($traversable as $key => $value) {
105-
yield $key => $value;
106-
}
107-
}
10897
}

lib/Doctrine/ODM/MongoDB/Iterator/UnrewindableIterator.php

+24-37
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
namespace Doctrine\ODM\MongoDB\Iterator;
66

7-
use Generator;
7+
use Iterator as SPLIterator;
8+
use IteratorIterator;
89
use LogicException;
910
use ReturnTypeWillChange;
1011
use RuntimeException;
@@ -23,39 +24,34 @@
2324
*/
2425
final class UnrewindableIterator implements Iterator
2526
{
26-
/** @var Generator<mixed, TValue>|null */
27-
private ?Generator $iterator;
27+
/** @var SPLIterator<mixed, TValue>|null */
28+
private ?SPLIterator $iterator;
2829

2930
private bool $iteratorAdvanced = false;
3031

3132
/**
32-
* Initialize the iterator. This effectively rewinds the Traversable and
33-
* the wrapping Generator, which will execute up to its first yield statement.
34-
* Additionally, this mimics behavior of the SPL iterators and allows users
35-
* to omit an explicit call to rewind() before using the other methods.
33+
* Initialize the iterator. This effectively rewinds the Traversable.
34+
* This mimics behavior of the SPL iterators and allows users to omit an
35+
* explicit call to rewind() before using the other methods.
3636
*
3737
* @param Traversable<mixed, TValue> $iterator
3838
*/
3939
public function __construct(Traversable $iterator)
4040
{
41-
$this->iterator = $this->wrapTraversable($iterator);
42-
$this->iterator->key();
41+
$this->iterator = new IteratorIterator($iterator);
42+
$this->iterator->rewind();
4343
}
4444

4545
public function toArray(): array
4646
{
4747
$this->preventRewinding(__METHOD__);
4848

49-
$toArray = function () {
50-
if (! $this->valid()) {
51-
return;
52-
}
53-
54-
yield $this->key() => $this->current();
55-
yield from $this->getIterator();
56-
};
57-
58-
return iterator_to_array($toArray());
49+
try {
50+
return iterator_to_array($this->getIterator());
51+
} finally {
52+
$this->iteratorAdvanced = true;
53+
$this->iterator = null;
54+
}
5955
}
6056

6157
/** @return TValue|null */
@@ -84,6 +80,13 @@ public function next(): void
8480
}
8581

8682
$this->iterator->next();
83+
$this->iteratorAdvanced = true;
84+
85+
if ($this->iterator->valid()) {
86+
return;
87+
}
88+
89+
$this->iterator = null;
8790
}
8891

8992
/** @see http://php.net/iterator.rewind */
@@ -108,29 +111,13 @@ private function preventRewinding(string $method): void
108111
}
109112
}
110113

111-
/** @return Generator<mixed, TValue> */
112-
private function getIterator(): Generator
114+
/** @return SPLIterator<mixed, TValue> */
115+
private function getIterator(): SPLIterator
113116
{
114117
if ($this->iterator === null) {
115118
throw new RuntimeException('Iterator has already been destroyed');
116119
}
117120

118121
return $this->iterator;
119122
}
120-
121-
/**
122-
* @param Traversable<mixed, TValue> $traversable
123-
*
124-
* @return Generator<mixed, TValue>
125-
*/
126-
private function wrapTraversable(Traversable $traversable): Generator
127-
{
128-
foreach ($traversable as $key => $value) {
129-
yield $key => $value;
130-
131-
$this->iteratorAdvanced = true;
132-
}
133-
134-
$this->iterator = null;
135-
}
136123
}

tests/Doctrine/ODM/MongoDB/Tests/Functional/Iterator/CachingIteratorTest.php

+13
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ public function testIterationWithEmptySet(): void
5959
self::assertFalse($iterator->valid());
6060
}
6161

62+
public function testIterationWithInvalidIterator(): void
63+
{
64+
$mock = $this->createMock(Iterator::class);
65+
// The method next() should not be called on a dead cursor.
66+
$mock->expects(self::never())->method('next');
67+
// The method valid() return false on a dead cursor.
68+
$mock->expects(self::once())->method('valid')->willReturn(false);
69+
70+
$iterator = new CachingIterator($mock);
71+
72+
$this->assertEquals([], $iterator->toArray());
73+
}
74+
6275
public function testPartialIterationDoesNotExhaust(): void
6376
{
6477
$traversable = $this->getTraversableThatThrows([1, 2, new Exception()]);

tests/Doctrine/ODM/MongoDB/Tests/Functional/Iterator/UnrewindableIteratorTest.php

+9
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,15 @@ public function testRewindAfterPartialIteration(): void
100100
iterator_to_array($iterator);
101101
}
102102

103+
public function testRewindAfterToArray(): void
104+
{
105+
$iterator = new UnrewindableIterator($this->getTraversable([1, 2, 3]));
106+
107+
$iterator->toArray();
108+
$this->expectException(LogicException::class);
109+
$iterator->rewind();
110+
}
111+
103112
public function testToArray(): void
104113
{
105114
$iterator = new UnrewindableIterator($this->getTraversable([1, 2, 3]));

0 commit comments

Comments
 (0)