1+ import json
12import logging
23import os
4+ import re
35import shutil
6+ from abc import abstractmethod
47from collections .abc import Iterable , Iterator
58from contextlib import AbstractContextManager , contextmanager , suppress
69from tempfile import NamedTemporaryFile
1316from fsspec .implementations .http import HTTPFileSystem
1417from funcy import cached_property
1518
19+ from scmrepo .git .backend .dulwich import _get_ssh_vendor
1620from scmrepo .git .credentials import Credential , CredentialNotFoundError
1721
1822from .exceptions import LFSError
@@ -35,19 +39,12 @@ class LFSClient(AbstractContextManager):
3539 _SESSION_RETRIES = 5
3640 _SESSION_BACKOFF_FACTOR = 0.1
3741
38- def __init__ (
39- self ,
40- url : str ,
41- git_url : Optional [str ] = None ,
42- headers : Optional [dict [str , str ]] = None ,
43- ):
42+ def __init__ (self , url : str ):
4443 """
4544 Args:
4645 url: LFS server URL.
4746 """
4847 self .url = url
49- self .git_url = git_url
50- self .headers : dict [str , str ] = headers or {}
5148
5249 def __exit__ (self , * args , ** kwargs ):
5350 self .close ()
@@ -84,23 +81,18 @@ def loop(self):
8481
8582 @classmethod
8683 def from_git_url (cls , git_url : str ) -> "LFSClient" :
87- if git_url .endswith ( ". git" ):
88- url = f" { git_url } /info/lfs"
89- else :
90- url = f" { git_url } .git/info/lfs"
91- return cls ( url , git_url = git_url )
84+ if git_url .startswith (( "ssh://" , " git@" ) ):
85+ return _SSHLFSClient . from_git_url ( git_url )
86+ if git_url . startswith ( "https://" ) :
87+ return _HTTPLFSClient . from_git_url ( git_url )
88+ raise NotImplementedError ( f"Unsupported Git URL: { git_url } " )
9289
9390 def close (self ):
9491 pass
9592
96- def _get_auth (self ) -> Optional [aiohttp .BasicAuth ]:
97- try :
98- creds = Credential (url = self .git_url ).fill ()
99- if creds .username and creds .password :
100- return aiohttp .BasicAuth (creds .username , creds .password )
101- except CredentialNotFoundError :
102- pass
103- return None
93+ @abstractmethod
94+ def _get_auth_header (self , * , upload : bool ) -> dict :
95+ ...
10496
10597 async def _batch_request (
10698 self ,
@@ -120,9 +112,10 @@ async def _batch_request(
120112 if ref :
121113 body ["ref" ] = [{"name" : ref }]
122114 session = await self ._fs .set_session ()
123- headers = dict (self .headers )
124- headers ["Accept" ] = self .JSON_CONTENT_TYPE
125- headers ["Content-Type" ] = self .JSON_CONTENT_TYPE
115+ headers = {
116+ "Accept" : self .JSON_CONTENT_TYPE ,
117+ "Content-Type" : self .JSON_CONTENT_TYPE ,
118+ }
126119 try :
127120 async with session .post (
128121 url ,
@@ -134,13 +127,12 @@ async def _batch_request(
134127 except aiohttp .ClientResponseError as exc :
135128 if exc .status != 401 :
136129 raise
137- auth = self ._get_auth ( )
138- if auth is None :
130+ auth_header = self ._get_auth_header ( upload = upload )
131+ if not auth_header :
139132 raise
140133 async with session .post (
141134 url ,
142- auth = auth ,
143- headers = headers ,
135+ headers = {** headers , ** auth_header },
144136 json = body ,
145137 raise_for_status = True ,
146138 ) as resp :
@@ -186,6 +178,85 @@ async def _get_one(from_path: str, to_path: str, **kwargs):
186178 download = sync_wrapper (_download )
187179
188180
181+ class _HTTPLFSClient (LFSClient ):
182+ def __init__ (self , url : str , git_url : str ):
183+ """
184+ Args:
185+ url: LFS server URL.
186+ git_url: Git HTTP URL.
187+ """
188+ super ().__init__ (url )
189+ self .git_url = git_url
190+
191+ @classmethod
192+ def from_git_url (cls , git_url : str ) -> "_HTTPLFSClient" :
193+ if git_url .endswith (".git" ):
194+ url = f"{ git_url } /info/lfs"
195+ else :
196+ url = f"{ git_url } .git/info/lfs"
197+ return cls (url , git_url = git_url )
198+
199+ def _get_auth_header (self , * , upload : bool ) -> dict :
200+ try :
201+ creds = Credential (url = self .git_url ).fill ()
202+ if creds .username and creds .password :
203+ return {
204+ aiohttp .hdrs .AUTHORIZATION : aiohttp .BasicAuth (
205+ creds .username , creds .password
206+ ).encode ()
207+ }
208+ except CredentialNotFoundError :
209+ pass
210+ return {}
211+
212+
213+ class _SSHLFSClient (LFSClient ):
214+ _URL_PATTERN = re .compile (
215+ r"(?:ssh://)?git@(?P<host>\S+?)(?::(?P<port>\d+))?(?:[:/])(?P<path>\S+?)\.git"
216+ )
217+
218+ def __init__ (self , url : str , host : str , port : int , path : str ):
219+ """
220+ Args:
221+ url: LFS server URL.
222+ host: Git SSH server host.
223+ port: Git SSH server port.
224+ path: Git project path.
225+ """
226+ super ().__init__ (url )
227+ self .host = host
228+ self .port = port
229+ self .path = path
230+ self ._ssh = _get_ssh_vendor ()
231+
232+ @classmethod
233+ def from_git_url (cls , git_url : str ) -> "_SSHLFSClient" :
234+ result = cls ._URL_PATTERN .match (git_url )
235+ if not result :
236+ raise ValueError (f"Invalid Git SSH URL: { git_url } " )
237+ host , port , path = result .group ("host" , "port" , "path" )
238+ url = f"https://{ host } /{ path } .git/info/lfs"
239+ return cls (url , host , int (port or 22 ), path )
240+
241+ def _get_auth_header (self , * , upload : bool ) -> dict :
242+ return self ._git_lfs_authenticate (
243+ self .host , self .port , f"{ self .path } .git" , upload = upload
244+ ).get ("header" , {})
245+
246+ def _git_lfs_authenticate (
247+ self , host : str , port : int , path : str , * , upload : bool = False
248+ ) -> dict :
249+ action = "upload" if upload else "download"
250+ return json .loads (
251+ self ._ssh .run_command (
252+ command = f"git-lfs-authenticate { path } { action } " ,
253+ host = host ,
254+ port = port ,
255+ username = "git" ,
256+ ).read ()
257+ )
258+
259+
189260@contextmanager
190261def _as_atomic (to_info : str , create_parents : bool = False ) -> Iterator [str ]:
191262 parent = os .path .dirname (to_info )
0 commit comments