Skip to content

Commit

Permalink
first cut at duration completer, #16
Browse files Browse the repository at this point in the history
  • Loading branch information
bckohan committed Jan 28, 2025
1 parent f0a3b3e commit 7b2b0a7
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 10 deletions.
31 changes: 30 additions & 1 deletion django_typer/completers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import sys
import typing as t
from datetime import date, time
from datetime import date, time, timedelta
from functools import partial
from pathlib import Path
from types import MethodType
Expand All @@ -35,6 +35,7 @@
DateField,
DateTimeField,
DecimalField,
DurationField,
Field,
FileField,
FilePathField,
Expand Down Expand Up @@ -82,6 +83,7 @@ class ModelObjectCompleter:
- `DateField <https://docs.djangoproject.com/en/stable/ref/models/fields/#datefield>`_ **(Must use ISO 8601: YYYY-MM-DD)**
- `TimeField <https://docs.djangoproject.com/en/stable/ref/models/fields/#timefield>`_ **(Must use ISO 8601: HH:MM:SS.ssssss)**
- `DateTimeField <https://docs.djangoproject.com/en/stable/ref/models/fields/#datetimefield>`_ **(Must use ISO 8601: YYYY-MM-DDTHH:MM:SS.ssssss±HH:MM)**
- `DurationField <https://docs.djangoproject.com/en/stable/ref/models/fields/#durationfield>`_ **(Must use ISO 8601: YYYY-MM-DDTHH:MM:SS.ssssss±HH:MM)**
- `UUIDField <https://docs.djangoproject.com/en/stable/ref/models/fields/#uuidfield>`_
- `FloatField <https://docs.djangoproject.com/en/stable/ref/models/fields/#floatfield>`_
- `DecimalField <https://docs.djangoproject.com/en/stable/ref/models/fields/#decimalfield>`_
Expand Down Expand Up @@ -178,6 +180,10 @@ def to_str(self, obj: t.Any) -> str:
return obj.isoformat()
elif isinstance(obj, date):
return obj.isoformat()
elif isinstance(obj, timedelta):
from django.utils.duration import duration_iso_string

return duration_iso_string(obj)
return str(obj)

def int_query(self, context: Context, parameter: Parameter, incomplete: str) -> Q:
Expand Down Expand Up @@ -475,6 +481,27 @@ def get_tz_part(dt_str: str) -> str:
**{f"{self.lookup_field}__lte": upper_bound}
)

def duration_query(
self, context: Context, parameter: Parameter, incomplete: str
) -> Q:
"""
Default completion query builder for duratioin fields. This method will return a Q object that
will match any value that is greater than the incomplete duraton string (or less if negative).
All durations must be in ISO8601 format (YYYY-MM-DD). Week specifiers are not supported.
:param context: The click context.
:param parameter: The click parameter.
:param incomplete: The incomplete string.
:return: A Q object to use for filtering the queryset.
:raises ValueError: If the incomplete string is not a valid partial duration.
:raises AssertionError: If the incomplete string is not a valid partial duration.
"""
from django_typer.utils import parse_iso_duration

duration = parse_iso_duration(incomplete)
lookup = "gte" if duration >= timedelta() else "lte"
return Q(**{f"{self.lookup_field}__{lookup}": duration})

def __init__(
self,
model_or_qry: t.Union[t.Type[Model], QuerySet],
Expand Down Expand Up @@ -530,6 +557,8 @@ def __init__(
self.query = self.date_query
elif isinstance(self._field, TimeField):
self.query = self.time_query
elif isinstance(self._field, DurationField):
self.query = self.duration_query
else:
raise ValueError(
_("Unsupported lookup field class: {cls}").format(
Expand Down
4 changes: 4 additions & 0 deletions django_typer/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def convert(
value = date.fromisoformat(value)
elif isinstance(self._field, models.TimeField):
value = time.fromisoformat(value)
elif isinstance(self._field, models.DurationField):
from django_typer.utils import parse_iso_duration

value = parse_iso_duration(value)
return self.model_cls.objects.get(
**{f"{self.lookup_field}{self._lookup}": value}
)
Expand Down
52 changes: 46 additions & 6 deletions django_typer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,21 @@
A collection of useful utilities.
"""

import importlib
import inspect
import os
import pkgutil
import shutil
import sys
import typing as t
from datetime import timedelta
from functools import partial
from pathlib import Path
from threading import local
from types import MethodType, ModuleType

from shellingham import ShellDetectionFailure
from shellingham import detect_shell as _detect_shell

from .config import traceback_config

# DO NOT IMPORT ANYTHING FROM TYPER HERE - SEE patch.py


__all__ = [
"detect_shell",
"get_usage_script",
Expand All @@ -40,6 +36,9 @@ def detect_shell(max_depth: int = 10) -> t.Tuple[str, str]:
:raises ShellDetectionFailure: If the shell cannot be detected
:return: A tuple of the shell name and the shell command
"""
from shellingham import ShellDetectionFailure
from shellingham import detect_shell as _detect_shell

try:
return _detect_shell(max_depth=max_depth)
except ShellDetectionFailure:
Expand All @@ -57,6 +56,8 @@ def get_usage_script(script: t.Optional[str] = None) -> t.Union[Path, str]:
:param script: The script name to check. If None the current script is used.
:return: The script name or the relative path to the script from cwd.
"""
import shutil

cmd_pth = Path(script or sys.argv[0])
on_path: t.Optional[t.Union[str, Path]] = shutil.which(cmd_pth.name)
on_path = on_path and Path(on_path)
Expand Down Expand Up @@ -143,6 +144,8 @@ def ready(self):
:param commands: The names of the commands/modules, if not provided, all modules
in the package will be registered as plugins
"""
import pkgutil

commands = commands or [
module[1].split(".")[-1]
for module in pkgutil.iter_modules(package.__path__, f"{package.__name__}.")
Expand All @@ -163,6 +166,8 @@ def _load_command_plugins(command: str) -> int:
"""
plugins = _command_plugins.get(command, [])
if plugins:
import importlib

for ext_pkg in reversed(plugins):
try:
importlib.import_module(f"{ext_pkg.__name__}.{command}")
Expand Down Expand Up @@ -255,6 +260,7 @@ def get_win_shell() -> str:
:return: The name of the shell, either 'powershell' or 'pwsh'
"""
import json
import shutil
import subprocess

from shellingham import ShellDetectionFailure
Expand Down Expand Up @@ -288,3 +294,37 @@ def get_win_shell() -> str:
raise ShellDetectionFailure("Unable to detect windows shell") from e

raise ShellDetectionFailure("Unable to detect windows shell")


def parse_iso_duration(duration: str) -> timedelta:
"""
Progressively parse an ISO8601 duration type.
"""
import re

# Define regex pattern for ISO 8601 duration
pattern = re.compile(
r"([-+])?"
r"(P" # Start with 'P'
r"(?:(?P<days>\d+)D)?" # Capture days (optional)
r"(?:T" # Start time part (optional)
r"(?:(?P<hours>\d+)H)?" # Capture hours (optional)
r"(?:(?P<minutes>\d+)M)?" # Capture minutes (optional)
r"(?:(?P<seconds>\d+(?:\.\d+)?)S)?)?)?" # Capture seconds (optional, with optional fraction)
)

# Match the input string to the pattern
match = pattern.fullmatch(duration)
if not match:
raise ValueError(f"Invalid ISO 8601 duration format: {duration}")

# Extract matched groups and convert to float/int
days = int(match.group("days")) if match.group("days") else 0
hours = int(match.group("hours")) if match.group("hours") else 0
minutes = int(match.group("minutes")) if match.group("minutes") else 0
seconds = float(match.group("seconds")) if match.group("seconds") else 0.0

td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
if duration.startswith("-"):
return -td
return td
13 changes: 13 additions & 0 deletions tests/apps/test_app/management/commands/model_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,15 @@ def test(
help=t.cast(str, _("Fetch objects by their time fields.")),
),
] = None,
duration: Annotated[
t.Optional[ShellCompleteTester],
typer.Option(
**model_parser_completer(
ShellCompleteTester, "duration_field", order_by="duration_field"
),
help=t.cast(str, _("Fetch objects by their duration fields.")),
),
] = None,
):
assert self.__class__ is Command
objects = {}
Expand Down Expand Up @@ -252,4 +261,8 @@ def test(
if time is not None:
assert isinstance(time, ShellCompleteTester)
objects["time"] = {time.id: str(time.time_field)}

if duration is not None:
assert isinstance(duration, ShellCompleteTester)
objects["duration"] = {duration.id: str(duration.duration_field)}
return json.dumps(objects)
6 changes: 5 additions & 1 deletion tests/apps/test_app/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 4.2.18 on 2025-01-21 04:09
# Generated by Django 4.2.18 on 2025-01-28 21:51

from django.db import migrations, models

Expand Down Expand Up @@ -77,6 +77,10 @@ class Migration(migrations.Migration):
"time_field",
models.TimeField(db_index=True, default=None, null=True),
),
(
"duration_field",
models.DurationField(db_index=True, default=None, null=True),
),
],
),
]
2 changes: 2 additions & 0 deletions tests/apps/test_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ class ShellCompleteTester(models.Model):
datetime_field = models.DateTimeField(null=True, default=None, db_index=True)

time_field = models.TimeField(null=True, default=None, db_index=True)

duration_field = models.DurationField(null=True, default=None, db_index=True)
62 changes: 60 additions & 2 deletions tests/test_parser_completers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import contextlib
import json
import os
import re
from decimal import Decimal
from io import StringIO
from pathlib import Path
from datetime import date, datetime, time
from datetime import date, datetime, time, timedelta

from django.apps import apps
from django.core.management import CommandError, call_command
Expand Down Expand Up @@ -150,6 +149,20 @@ class TestShellCompletersAndParsers(ParserCompleterMixin, TestCase):
time(22, 30, 46, 999900),
time(23, 59, 59, 999999),
],
"duration_field": [
-timedelta(days=62, hours=13, seconds=5, microseconds=124),
-timedelta(days=51, hours=13, seconds=5, microseconds=124),
-timedelta(days=50, hours=13, seconds=5, microseconds=124),
-timedelta(days=50, hours=12, seconds=5, microseconds=124),
-timedelta(days=50, hours=12, seconds=4, microseconds=124),
-timedelta(days=50, hours=12, seconds=4, microseconds=123456),
timedelta(days=50, hours=12, seconds=4, microseconds=123456),
timedelta(days=50, hours=12, seconds=4, microseconds=124),
timedelta(days=50, hours=12, seconds=5, microseconds=124),
timedelta(days=50, hours=13, seconds=5, microseconds=124),
timedelta(days=51, hours=13, seconds=5, microseconds=124),
timedelta(days=62, hours=13, seconds=5, microseconds=124),
],
}

def test_model_object_parser_metavar(self):
Expand Down Expand Up @@ -751,6 +764,51 @@ def time_vals(completions):
},
)

def test_duration_field(self):
from django.utils.duration import duration_iso_string

def duration_vals(completions):
return list(get_values(completions))

durations = duration_vals(
self.shellcompletion.complete("model_fields test --duration ")
)
self.assertEqual(
durations,
[
"-P62DT13H00M05.000124S",
"-P51DT13H00M05.000124S",
"-P50DT13H00M05.000124S",
"-P50DT12H00M05.000124S",
"-P50DT12H00M04.123456S",
"-P50DT12H00M04.000124S",
"P50DT12H00M04.000124S",
"P50DT12H00M04.123456S",
"P50DT12H00M05.000124S",
"P50DT13H00M05.000124S",
"P51DT13H00M05.000124S",
"P62DT13H00M05.000124S",
],
)
for duration in self.field_values["duration_field"]:
self.assertEqual(
json.loads(
call_command(
"model_fields",
"test",
"--duration",
duration_iso_string(duration),
)
),
{
"duration": {
str(
ShellCompleteTester.objects.get(duration_field=duration).pk
): str(duration)
}
},
)

def test_ip_field(self):
result = call_command(
"shellcompletion",
Expand Down
33 changes: 33 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
accepts_var_kwargs,
get_win_shell,
detect_shell,
parse_iso_duration,
)
from django.test import override_settings
from django.core.management import call_command
Expand Down Expand Up @@ -106,3 +107,35 @@ def test_detection_env_fallback():

def test_detect_shell():
assert detect_shell(max_depth=256)


def test_parse_iso_duration():
from datetime import timedelta
from django.utils.duration import duration_iso_string

for duration in [
timedelta(days=3, hours=4, minutes=30, seconds=15, microseconds=123456),
timedelta(days=1, hours=12, minutes=0, seconds=0),
timedelta(days=0, hours=23, minutes=45, seconds=30),
timedelta(days=5, hours=0, minutes=15, seconds=5, microseconds=987654),
timedelta(days=2, hours=8, minutes=0, seconds=0),
timedelta(days=-3, hours=-4, minutes=-30, seconds=-15, microseconds=-123456),
timedelta(days=-1, hours=-12, minutes=0, seconds=0),
timedelta(days=-2, hours=-20, minutes=-10, seconds=-30),
timedelta(days=-5, hours=-6, minutes=-0, seconds=-50, microseconds=-123000),
timedelta(days=-10, hours=-5, minutes=-55, seconds=-5),
]:
assert parse_iso_duration(duration_iso_string(duration)) == duration

assert parse_iso_duration("") == timedelta()
assert parse_iso_duration("-") == -timedelta()
assert parse_iso_duration("+") == timedelta()

with pytest.raises(ValueError):
parse_iso_duration("?")

with pytest.raises(ValueError):
parse_iso_duration("=")

with pytest.raises(ValueError):
parse_iso_duration("10D")

0 comments on commit 7b2b0a7

Please sign in to comment.