Skip to content

Commit 53e4b32

Browse files
authored
lfs: add tests for rate limit retries (#341)
1 parent 5d237d7 commit 53e4b32

File tree

3 files changed

+212
-3
lines changed

3 files changed

+212
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Source = "https://github.com/iterative/scmrepo"
3939

4040
[project.optional-dependencies]
4141
tests = [
42+
"aioresponses==0.7.6",
4243
"pytest==8.1.1",
4344
"pytest-sugar==1.0.0",
4445
"pytest-cov==4.1.0",

src/scmrepo/git/lfs/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def loop(self):
8686
def from_git_url(cls, git_url: str) -> "LFSClient":
8787
if git_url.startswith(("ssh://", "git@")):
8888
return _SSHLFSClient.from_git_url(git_url)
89-
if git_url.startswith("https://"):
89+
if git_url.startswith(("http://", "https://")):
9090
return _HTTPLFSClient.from_git_url(git_url)
9191
raise NotImplementedError(f"Unsupported Git URL: {git_url}")
9292

tests/test_lfs.py

Lines changed: 210 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
# pylint: disable=redefined-outer-name
22
import io
3+
from collections import defaultdict
4+
from collections.abc import Sequence
5+
from http import HTTPStatus
6+
from time import time
7+
from typing import Callable
38

49
import pytest
10+
from aiohttp import ClientResponseError
11+
from aioresponses import CallbackResult, aioresponses
512
from pytest_mock import MockerFixture
613
from pytest_test_utils import TempDirFactory, TmpDir
14+
from yarl import URL
715

816
from scmrepo.git import Git
9-
from scmrepo.git.lfs import LFSStorage, Pointer, smudge
17+
from scmrepo.git.lfs import LFSClient, LFSStorage, Pointer, smudge
1018

1119
FOO_OID = "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae"
20+
FOO_SIZE = 3
1221
FOO_POINTER = (
13-
f"version https://git-lfs.github.com/spec/v1\noid sha256:{FOO_OID}\nsize 3\n"
22+
f"version https://git-lfs.github.com/spec/v1\n"
23+
f"oid sha256:{FOO_OID}\n"
24+
f"size {FOO_SIZE}\n"
1425
).encode()
1526

1627

@@ -74,3 +85,200 @@ def test_lfs(tmp_dir: TmpDir, scm: Git, lfs_objects: TmpDir):
7485
assert fobj.read() == FOO_POINTER
7586
with fs.open("foo.lfs", "rb", raw=False) as fobj:
7687
assert fobj.read() == b"foo"
88+
89+
90+
class CallbackResultRecorder:
91+
def __init__(self) -> None:
92+
self._results: dict[str, list[CallbackResult]] = defaultdict(list)
93+
94+
def record(self, result: CallbackResult) -> Callable[..., CallbackResult]:
95+
def _callback(url: URL, **_) -> CallbackResult:
96+
self._results[str(url)].append(result)
97+
return result
98+
99+
return _callback
100+
101+
def __getitem__(self, url: str) -> Sequence[CallbackResult]:
102+
return self._results[url]
103+
104+
105+
class LFSServerMock:
106+
def __init__(
107+
self,
108+
mocker: aioresponses,
109+
recorder: CallbackResultRecorder,
110+
batch_url: str,
111+
objects_url: str,
112+
) -> None:
113+
self._mocker = mocker
114+
self._recorder = recorder
115+
self.batch_url = batch_url
116+
self.objects_url = objects_url
117+
118+
def get_object_url(self, oid: str) -> str:
119+
return f"{self.objects_url}/{oid}"
120+
121+
def mock_batch_200(self, pointer: Pointer) -> None:
122+
self._mocker.post(
123+
self.batch_url,
124+
callback=self._recorder.record(
125+
CallbackResult(
126+
status=HTTPStatus.OK,
127+
headers={"Content-Type": "application/vnd.git-lfs+json"},
128+
payload={
129+
"transfer": "basic",
130+
"objects": [
131+
{
132+
"oid": pointer.oid,
133+
"size": pointer.size,
134+
"authenticated": True,
135+
"actions": {
136+
"download": {
137+
"href": self.get_object_url(pointer.oid),
138+
}
139+
},
140+
}
141+
],
142+
"hash_algo": "sha256",
143+
},
144+
)
145+
),
146+
)
147+
148+
def mock_batch_429(
149+
self, header: str, value: Callable[[], str], *, repeat: bool = False
150+
) -> None:
151+
self._mocker.post(
152+
self.batch_url,
153+
callback=self._recorder.record(
154+
CallbackResult(
155+
status=HTTPStatus.TOO_MANY_REQUESTS,
156+
headers={header: value()},
157+
reason="Too many requests",
158+
)
159+
),
160+
repeat=repeat,
161+
)
162+
163+
def mock_object_200(self, oid: str) -> None:
164+
self._mocker.get(
165+
self.get_object_url(oid),
166+
callback=self._recorder.record(
167+
CallbackResult(
168+
status=HTTPStatus.OK,
169+
body=f"object {oid} data",
170+
)
171+
),
172+
)
173+
174+
def mock_object_429(
175+
self,
176+
oid: str,
177+
header: str,
178+
value: Callable[[], str],
179+
*,
180+
repeat: bool = False,
181+
) -> None:
182+
self._mocker.get(
183+
self.get_object_url(oid),
184+
callback=self._recorder.record(
185+
CallbackResult(
186+
status=HTTPStatus.TOO_MANY_REQUESTS,
187+
headers={header: value()},
188+
reason="Too many requests",
189+
)
190+
),
191+
repeat=repeat,
192+
)
193+
194+
195+
@pytest.mark.parametrize(
196+
"rate_limit_header, rate_limit_value",
197+
[
198+
("Retry-After", lambda: "1"),
199+
("RateLimit-Reset", lambda: f"{int(time()) + 1}"),
200+
("X-RateLimit-Reset", lambda: f"{int(time()) + 1}"),
201+
],
202+
)
203+
def test_rate_limit_retry(
204+
storage: LFSStorage, rate_limit_header: str, rate_limit_value: Callable[[], str]
205+
):
206+
client = LFSClient.from_git_url("http://git.example.com/namespace/project.git")
207+
recorder = CallbackResultRecorder()
208+
209+
with aioresponses() as m:
210+
lfs_server = LFSServerMock(
211+
m, recorder, f"{client.url}/objects/batch", "http://git-lfs.example.com"
212+
)
213+
lfs_server.mock_batch_429(rate_limit_header, rate_limit_value)
214+
lfs_server.mock_batch_200(Pointer(FOO_OID, FOO_SIZE))
215+
lfs_server.mock_object_429(FOO_OID, rate_limit_header, rate_limit_value)
216+
lfs_server.mock_object_200(FOO_OID)
217+
218+
client.download(storage, [Pointer(oid=FOO_OID, size=FOO_SIZE)])
219+
220+
results = recorder[lfs_server.batch_url]
221+
assert [r.status for r in results] == [429, 200]
222+
223+
results = recorder[lfs_server.get_object_url(FOO_OID)]
224+
assert [r.status for r in results] == [429, 200]
225+
226+
227+
@pytest.mark.parametrize(
228+
"rate_limit_header, rate_limit_value",
229+
[
230+
("Retry-After", lambda: "1"),
231+
("RateLimit-Reset", lambda: f"{int(time()) + 1}"),
232+
("X-RateLimit-Reset", lambda: f"{int(time()) + 1}"),
233+
],
234+
)
235+
def test_rate_limit_max_retries_batch(
236+
storage: LFSStorage, rate_limit_header: str, rate_limit_value: Callable[[], str]
237+
):
238+
client = LFSClient.from_git_url("http://git.example.com/namespace/project.git")
239+
recorder = CallbackResultRecorder()
240+
241+
with aioresponses() as m:
242+
lfs_server = LFSServerMock(
243+
m, recorder, f"{client.url}/objects/batch", "http://git-lfs.example.com"
244+
)
245+
lfs_server.mock_batch_429(rate_limit_header, rate_limit_value, repeat=True)
246+
247+
with pytest.raises(ClientResponseError, match="Too many requests"):
248+
client.download(storage, [Pointer(oid=FOO_OID, size=FOO_SIZE)])
249+
250+
results = recorder[lfs_server.batch_url]
251+
assert [r.status for r in results] == [429] * 5
252+
253+
254+
@pytest.mark.parametrize(
255+
"rate_limit_header, rate_limit_value",
256+
[
257+
("Retry-After", lambda: "1"),
258+
("RateLimit-Reset", lambda: f"{int(time()) + 1}"),
259+
("X-RateLimit-Reset", lambda: f"{int(time()) + 1}"),
260+
],
261+
)
262+
def test_rate_limit_max_retries_objects(
263+
storage: LFSStorage, rate_limit_header: str, rate_limit_value: Callable[[], str]
264+
):
265+
client = LFSClient.from_git_url("http://git.example.com/namespace/project.git")
266+
recorder = CallbackResultRecorder()
267+
268+
with aioresponses() as m:
269+
lfs_server = LFSServerMock(
270+
m, recorder, f"{client.url}/objects/batch", "http://git-lfs.example.com"
271+
)
272+
lfs_server.mock_batch_200(Pointer(FOO_OID, FOO_SIZE))
273+
lfs_server.mock_object_429(
274+
FOO_OID, rate_limit_header, rate_limit_value, repeat=True
275+
)
276+
277+
with pytest.raises(ClientResponseError, match="Too many requests"):
278+
client.download(storage, [Pointer(oid=FOO_OID, size=FOO_SIZE)])
279+
280+
results = recorder[lfs_server.batch_url]
281+
assert [r.status for r in results] == [200]
282+
283+
results = recorder[lfs_server.get_object_url(FOO_OID)]
284+
assert [r.status for r in results] == [429] * 5

0 commit comments

Comments
 (0)