Skip to content
16 changes: 15 additions & 1 deletion dvc_webdav/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from dvc.utils.objects import cached_property
from dvc_objects.fs.base import FileSystem
from dvc_webdav.bearer_auth_client import BearerAuthClient

logger = logging.getLogger(__name__)
logger = logging.getLogger("dvc")


@wrap_with(threading.Lock())
Expand All @@ -17,6 +18,16 @@ def ask_password(host, user):
return getpass(f"Enter a password for host '{host}' user '{user}':\n")


@wrap_with(threading.Lock())
@memoize
def get_bearer_auth_client(bearer_token_command: str) -> BearerAuthClient:
logger.debug(
"Bearer token command provided, using BearerAuthClient, command: %s",
bearer_token_command,
)
return BearerAuthClient(bearer_token_command)


class WebDAVFileSystem(FileSystem): # pylint:disable=abstract-method
protocol = "webdav"
root_marker = ""
Expand All @@ -37,6 +48,9 @@ def __init__(self, **config):
"timeout": config.get("timeout", 30),
}
)
if bearer_token_command := config.get("bearer_token_command"):
client = get_bearer_auth_client(bearer_token_command)
self.fs_args["http_client"] = client

def unstrip_protocol(self, path: str) -> str:
return self.fs_args["base_url"] + "/" + path
Expand Down
161 changes: 161 additions & 0 deletions dvc_webdav/bearer_auth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import logging
import shlex
import subprocess
import sys
import threading
from typing import Optional, Union

import httpx

logger = logging.getLogger("dvc")


def _log_with_thread(level: int, msg: str, *args) -> None:
"""
Universal helper to inject thread identity into logs.
Output format: [Thread-Name] Message...
"""
if logger.isEnabledFor(level):
thread_name = threading.current_thread().name
log_fmt = f"[{thread_name}] " + msg
logger.log(level, log_fmt, *args)


def execute_command(command: Union[list[str], str], timeout: int = 10) -> str:
"""Executes a command to retrieve the token."""
if isinstance(command, str):
command = shlex.split(command)

try:
result = subprocess.run( # noqa: S603
command,
shell=False,
capture_output=True,
text=True,
check=True,
timeout=timeout,
encoding="utf-8",
)
except (
FileNotFoundError,
subprocess.TimeoutExpired,
subprocess.CalledProcessError,
ValueError,
OSError,
) as e:
error_header = "\n" + "=" * 60
error_msg = (
f"{error_header}\n[CRITICAL] Bearer Token Retrieval Failed.\n"
"DVC may misinterpret this as 'File Not Found' and skip files.\n"
f"Command: {command}\n"
f"Error: {e}"
)

if isinstance(e, subprocess.CalledProcessError):
error_msg += f"\nStderr: {e.stderr.strip()}"

error_msg += f"\n{error_header}\n"

logger.critical(error_msg)
sys.stderr.write(error_msg)
sys.stderr.flush()

# Re-raise the exception so the caller knows it failed.
# DVC might catch this and swallow it, but we've done our duty to notify.
raise

token = result.stdout.strip()
if not token:
raise ValueError("Command executed successfully but returned an empty token.")
return token


class BearerAuthClient(httpx.Client):
"""HTTPX client that adds Bearer token authentication using a command.

Args:
bearer_token_command: The command to run to get the Bearer token.
**kwargs: Additional arguments to pass to the httpx.Client constructor.
"""

def __init__(
self,
bearer_token_command: str,
**kwargs,
):
super().__init__(**kwargs)
if (
not isinstance(bearer_token_command, str)
or not bearer_token_command.strip()
):
raise ValueError(
"[BearerAuthClient] bearer_token_command must be a non-empty string"
)
self.bearer_token_command = bearer_token_command
self._token: Optional[str] = None
self._lock = threading.Lock()

def _refresh_token(self) -> None:
"""Execute token command and update state."""
_log_with_thread(
logging.DEBUG, "[BearerAuthClient] Refreshing token via command..."
)

try:
new_token = execute_command(self.bearer_token_command)
# execute_command guarantees non-empty string or raises ValueError

self._token = new_token
self.headers["Authorization"] = f"Bearer {new_token}"

_log_with_thread(
logging.DEBUG, "[BearerAuthClient] Token refreshed successfully."
)
except Exception:
# Clean up state on failure
self._token = None
raise

def _ensure_token(self) -> None:
"""Ensure a token exists before making requests"""
if self._token:
return

with self._lock:
if not self._token:
self._refresh_token()

def request(self, *args, **kwargs) -> httpx.Response:
"""Wraps httpx.request with auto-refresh logic for 401 Unauthorized."""
self._ensure_token()
response = super().request(*args, **kwargs)

if response.status_code != 401:
return response

_log_with_thread(
logging.DEBUG, "[BearerAuthClient] Received 401. Attempting recovery."
)
sent_auth_header = response.request.headers.get("Authorization")

try:
with self._lock:
current_auth_header = self.headers.get("Authorization")
if sent_auth_header == current_auth_header:
self._refresh_token()
else:
_log_with_thread(
logging.DEBUG,
"[BearerAuthClient] Token already refreshed by another thread. "
"Retrying.",
)
except Exception:
logger.exception(
"[BearerAuthClient] Recovery failed: Token refresh threw exception"
)
return response

# Retry the request with the new valid token
# We must close the old 401 response to free connections
response.close()
return super().request(*args, **kwargs)