From 7b2b0a741489bb4889c3ed2ab72a85b1a749ff9d Mon Sep 17 00:00:00 2001 From: Brian Kohan Date: Tue, 28 Jan 2025 14:37:52 -0800 Subject: [PATCH] first cut at duration completer, #16 --- django_typer/completers.py | 31 +++++++++- django_typer/parsers.py | 4 ++ django_typer/utils.py | 52 ++++++++++++++-- .../management/commands/model_fields.py | 13 ++++ .../apps/test_app/migrations/0001_initial.py | 6 +- tests/apps/test_app/models.py | 2 + tests/test_parser_completers.py | 62 ++++++++++++++++++- tests/test_utils.py | 33 ++++++++++ 8 files changed, 193 insertions(+), 10 deletions(-) diff --git a/django_typer/completers.py b/django_typer/completers.py index 208c993..e704f89 100644 --- a/django_typer/completers.py +++ b/django_typer/completers.py @@ -19,7 +19,7 @@ import os import sys import typing as t -from datetime import date, time +from datetime import date, time, timedelta from functools import partial from pathlib import Path from types import MethodType @@ -35,6 +35,7 @@ DateField, DateTimeField, DecimalField, + DurationField, Field, FileField, FilePathField, @@ -82,6 +83,7 @@ class ModelObjectCompleter: - `DateField `_ **(Must use ISO 8601: YYYY-MM-DD)** - `TimeField `_ **(Must use ISO 8601: HH:MM:SS.ssssss)** - `DateTimeField `_ **(Must use ISO 8601: YYYY-MM-DDTHH:MM:SS.ssssss±HH:MM)** + - `DurationField `_ **(Must use ISO 8601: YYYY-MM-DDTHH:MM:SS.ssssss±HH:MM)** - `UUIDField `_ - `FloatField `_ - `DecimalField `_ @@ -178,6 +180,10 @@ def to_str(self, obj: t.Any) -> str: return obj.isoformat() elif isinstance(obj, date): return obj.isoformat() + elif isinstance(obj, timedelta): + from django.utils.duration import duration_iso_string + + return duration_iso_string(obj) return str(obj) def int_query(self, context: Context, parameter: Parameter, incomplete: str) -> Q: @@ -475,6 +481,27 @@ def get_tz_part(dt_str: str) -> str: **{f"{self.lookup_field}__lte": upper_bound} ) + def duration_query( + self, context: Context, parameter: Parameter, incomplete: str + ) -> Q: + """ + Default completion query builder for duratioin fields. This method will return a Q object that + will match any value that is greater than the incomplete duraton string (or less if negative). + All durations must be in ISO8601 format (YYYY-MM-DD). Week specifiers are not supported. + + :param context: The click context. + :param parameter: The click parameter. + :param incomplete: The incomplete string. + :return: A Q object to use for filtering the queryset. + :raises ValueError: If the incomplete string is not a valid partial duration. + :raises AssertionError: If the incomplete string is not a valid partial duration. + """ + from django_typer.utils import parse_iso_duration + + duration = parse_iso_duration(incomplete) + lookup = "gte" if duration >= timedelta() else "lte" + return Q(**{f"{self.lookup_field}__{lookup}": duration}) + def __init__( self, model_or_qry: t.Union[t.Type[Model], QuerySet], @@ -530,6 +557,8 @@ def __init__( self.query = self.date_query elif isinstance(self._field, TimeField): self.query = self.time_query + elif isinstance(self._field, DurationField): + self.query = self.duration_query else: raise ValueError( _("Unsupported lookup field class: {cls}").format( diff --git a/django_typer/parsers.py b/django_typer/parsers.py index 516d693..3e117d6 100644 --- a/django_typer/parsers.py +++ b/django_typer/parsers.py @@ -162,6 +162,10 @@ def convert( value = date.fromisoformat(value) elif isinstance(self._field, models.TimeField): value = time.fromisoformat(value) + elif isinstance(self._field, models.DurationField): + from django_typer.utils import parse_iso_duration + + value = parse_iso_duration(value) return self.model_cls.objects.get( **{f"{self.lookup_field}{self._lookup}": value} ) diff --git a/django_typer/utils.py b/django_typer/utils.py index c654c6c..018de66 100644 --- a/django_typer/utils.py +++ b/django_typer/utils.py @@ -2,25 +2,21 @@ A collection of useful utilities. """ -import importlib import inspect import os -import pkgutil -import shutil import sys import typing as t +from datetime import timedelta from functools import partial from pathlib import Path from threading import local from types import MethodType, ModuleType -from shellingham import ShellDetectionFailure -from shellingham import detect_shell as _detect_shell - from .config import traceback_config # DO NOT IMPORT ANYTHING FROM TYPER HERE - SEE patch.py + __all__ = [ "detect_shell", "get_usage_script", @@ -40,6 +36,9 @@ def detect_shell(max_depth: int = 10) -> t.Tuple[str, str]: :raises ShellDetectionFailure: If the shell cannot be detected :return: A tuple of the shell name and the shell command """ + from shellingham import ShellDetectionFailure + from shellingham import detect_shell as _detect_shell + try: return _detect_shell(max_depth=max_depth) except ShellDetectionFailure: @@ -57,6 +56,8 @@ def get_usage_script(script: t.Optional[str] = None) -> t.Union[Path, str]: :param script: The script name to check. If None the current script is used. :return: The script name or the relative path to the script from cwd. """ + import shutil + cmd_pth = Path(script or sys.argv[0]) on_path: t.Optional[t.Union[str, Path]] = shutil.which(cmd_pth.name) on_path = on_path and Path(on_path) @@ -143,6 +144,8 @@ def ready(self): :param commands: The names of the commands/modules, if not provided, all modules in the package will be registered as plugins """ + import pkgutil + commands = commands or [ module[1].split(".")[-1] for module in pkgutil.iter_modules(package.__path__, f"{package.__name__}.") @@ -163,6 +166,8 @@ def _load_command_plugins(command: str) -> int: """ plugins = _command_plugins.get(command, []) if plugins: + import importlib + for ext_pkg in reversed(plugins): try: importlib.import_module(f"{ext_pkg.__name__}.{command}") @@ -255,6 +260,7 @@ def get_win_shell() -> str: :return: The name of the shell, either 'powershell' or 'pwsh' """ import json + import shutil import subprocess from shellingham import ShellDetectionFailure @@ -288,3 +294,37 @@ def get_win_shell() -> str: raise ShellDetectionFailure("Unable to detect windows shell") from e raise ShellDetectionFailure("Unable to detect windows shell") + + +def parse_iso_duration(duration: str) -> timedelta: + """ + Progressively parse an ISO8601 duration type. + """ + import re + + # Define regex pattern for ISO 8601 duration + pattern = re.compile( + r"([-+])?" + r"(P" # Start with 'P' + r"(?:(?P\d+)D)?" # Capture days (optional) + r"(?:T" # Start time part (optional) + r"(?:(?P\d+)H)?" # Capture hours (optional) + r"(?:(?P\d+)M)?" # Capture minutes (optional) + r"(?:(?P\d+(?:\.\d+)?)S)?)?)?" # Capture seconds (optional, with optional fraction) + ) + + # Match the input string to the pattern + match = pattern.fullmatch(duration) + if not match: + raise ValueError(f"Invalid ISO 8601 duration format: {duration}") + + # Extract matched groups and convert to float/int + days = int(match.group("days")) if match.group("days") else 0 + hours = int(match.group("hours")) if match.group("hours") else 0 + minutes = int(match.group("minutes")) if match.group("minutes") else 0 + seconds = float(match.group("seconds")) if match.group("seconds") else 0.0 + + td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) + if duration.startswith("-"): + return -td + return td diff --git a/tests/apps/test_app/management/commands/model_fields.py b/tests/apps/test_app/management/commands/model_fields.py index 43188e3..8339f7c 100644 --- a/tests/apps/test_app/management/commands/model_fields.py +++ b/tests/apps/test_app/management/commands/model_fields.py @@ -194,6 +194,15 @@ def test( help=t.cast(str, _("Fetch objects by their time fields.")), ), ] = None, + duration: Annotated[ + t.Optional[ShellCompleteTester], + typer.Option( + **model_parser_completer( + ShellCompleteTester, "duration_field", order_by="duration_field" + ), + help=t.cast(str, _("Fetch objects by their duration fields.")), + ), + ] = None, ): assert self.__class__ is Command objects = {} @@ -252,4 +261,8 @@ def test( if time is not None: assert isinstance(time, ShellCompleteTester) objects["time"] = {time.id: str(time.time_field)} + + if duration is not None: + assert isinstance(duration, ShellCompleteTester) + objects["duration"] = {duration.id: str(duration.duration_field)} return json.dumps(objects) diff --git a/tests/apps/test_app/migrations/0001_initial.py b/tests/apps/test_app/migrations/0001_initial.py index 6789cde..2935548 100644 --- a/tests/apps/test_app/migrations/0001_initial.py +++ b/tests/apps/test_app/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.18 on 2025-01-21 04:09 +# Generated by Django 4.2.18 on 2025-01-28 21:51 from django.db import migrations, models @@ -77,6 +77,10 @@ class Migration(migrations.Migration): "time_field", models.TimeField(db_index=True, default=None, null=True), ), + ( + "duration_field", + models.DurationField(db_index=True, default=None, null=True), + ), ], ), ] diff --git a/tests/apps/test_app/models.py b/tests/apps/test_app/models.py index af748b8..fe66ba9 100644 --- a/tests/apps/test_app/models.py +++ b/tests/apps/test_app/models.py @@ -29,3 +29,5 @@ class ShellCompleteTester(models.Model): datetime_field = models.DateTimeField(null=True, default=None, db_index=True) time_field = models.TimeField(null=True, default=None, db_index=True) + + duration_field = models.DurationField(null=True, default=None, db_index=True) diff --git a/tests/test_parser_completers.py b/tests/test_parser_completers.py index 359aa23..bc1f4de 100644 --- a/tests/test_parser_completers.py +++ b/tests/test_parser_completers.py @@ -1,11 +1,10 @@ -import contextlib import json import os import re from decimal import Decimal from io import StringIO from pathlib import Path -from datetime import date, datetime, time +from datetime import date, datetime, time, timedelta from django.apps import apps from django.core.management import CommandError, call_command @@ -150,6 +149,20 @@ class TestShellCompletersAndParsers(ParserCompleterMixin, TestCase): time(22, 30, 46, 999900), time(23, 59, 59, 999999), ], + "duration_field": [ + -timedelta(days=62, hours=13, seconds=5, microseconds=124), + -timedelta(days=51, hours=13, seconds=5, microseconds=124), + -timedelta(days=50, hours=13, seconds=5, microseconds=124), + -timedelta(days=50, hours=12, seconds=5, microseconds=124), + -timedelta(days=50, hours=12, seconds=4, microseconds=124), + -timedelta(days=50, hours=12, seconds=4, microseconds=123456), + timedelta(days=50, hours=12, seconds=4, microseconds=123456), + timedelta(days=50, hours=12, seconds=4, microseconds=124), + timedelta(days=50, hours=12, seconds=5, microseconds=124), + timedelta(days=50, hours=13, seconds=5, microseconds=124), + timedelta(days=51, hours=13, seconds=5, microseconds=124), + timedelta(days=62, hours=13, seconds=5, microseconds=124), + ], } def test_model_object_parser_metavar(self): @@ -751,6 +764,51 @@ def time_vals(completions): }, ) + def test_duration_field(self): + from django.utils.duration import duration_iso_string + + def duration_vals(completions): + return list(get_values(completions)) + + durations = duration_vals( + self.shellcompletion.complete("model_fields test --duration ") + ) + self.assertEqual( + durations, + [ + "-P62DT13H00M05.000124S", + "-P51DT13H00M05.000124S", + "-P50DT13H00M05.000124S", + "-P50DT12H00M05.000124S", + "-P50DT12H00M04.123456S", + "-P50DT12H00M04.000124S", + "P50DT12H00M04.000124S", + "P50DT12H00M04.123456S", + "P50DT12H00M05.000124S", + "P50DT13H00M05.000124S", + "P51DT13H00M05.000124S", + "P62DT13H00M05.000124S", + ], + ) + for duration in self.field_values["duration_field"]: + self.assertEqual( + json.loads( + call_command( + "model_fields", + "test", + "--duration", + duration_iso_string(duration), + ) + ), + { + "duration": { + str( + ShellCompleteTester.objects.get(duration_field=duration).pk + ): str(duration) + } + }, + ) + def test_ip_field(self): result = call_command( "shellcompletion", diff --git a/tests/test_utils.py b/tests/test_utils.py index 6d12e79..3dd67aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ accepts_var_kwargs, get_win_shell, detect_shell, + parse_iso_duration, ) from django.test import override_settings from django.core.management import call_command @@ -106,3 +107,35 @@ def test_detection_env_fallback(): def test_detect_shell(): assert detect_shell(max_depth=256) + + +def test_parse_iso_duration(): + from datetime import timedelta + from django.utils.duration import duration_iso_string + + for duration in [ + timedelta(days=3, hours=4, minutes=30, seconds=15, microseconds=123456), + timedelta(days=1, hours=12, minutes=0, seconds=0), + timedelta(days=0, hours=23, minutes=45, seconds=30), + timedelta(days=5, hours=0, minutes=15, seconds=5, microseconds=987654), + timedelta(days=2, hours=8, minutes=0, seconds=0), + timedelta(days=-3, hours=-4, minutes=-30, seconds=-15, microseconds=-123456), + timedelta(days=-1, hours=-12, minutes=0, seconds=0), + timedelta(days=-2, hours=-20, minutes=-10, seconds=-30), + timedelta(days=-5, hours=-6, minutes=-0, seconds=-50, microseconds=-123000), + timedelta(days=-10, hours=-5, minutes=-55, seconds=-5), + ]: + assert parse_iso_duration(duration_iso_string(duration)) == duration + + assert parse_iso_duration("") == timedelta() + assert parse_iso_duration("-") == -timedelta() + assert parse_iso_duration("+") == timedelta() + + with pytest.raises(ValueError): + parse_iso_duration("?") + + with pytest.raises(ValueError): + parse_iso_duration("=") + + with pytest.raises(ValueError): + parse_iso_duration("10D")