From a80aa5ded9960d4ab7da88149c98527918fab2ec Mon Sep 17 00:00:00 2001 From: elronbandel Date: Tue, 11 Feb 2025 16:45:36 +0200 Subject: [PATCH] Add breakpoint to every unitxt catalog asset at its core logic Signed-off-by: elronbandel --- src/unitxt/debug_utils.py | 28 ++++++++++++++++++++++++++++ src/unitxt/operator.py | 1 + src/unitxt/operators.py | 5 ++++- 3 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 src/unitxt/debug_utils.py diff --git a/src/unitxt/debug_utils.py b/src/unitxt/debug_utils.py new file mode 100644 index 0000000000..b26de05c6c --- /dev/null +++ b/src/unitxt/debug_utils.py @@ -0,0 +1,28 @@ +import inspect + + +def insert_breakpoint(func): + # Retrieve the source code of the function + source_lines, starting_line_no = inspect.getsourcelines(func) + + # Determine the indentation level of the function definition + indent = len(source_lines[0]) - len(source_lines[0].lstrip()) + + # Create the new function source with a breakpoint at the beginning + new_source_lines = [] + for line in source_lines: + new_source_lines.append(line[indent:]) + if line.lstrip().startswith("def "): + # Insert a blank line and the breakpoint after the function definition + new_source_lines.append(" breakpoint()\n") + new_source_lines.append("\n" * (starting_line_no - 2)) + + new_source = "".join(new_source_lines) + + # Compile the new source and execute it in the function's global context + code = compile(new_source, func.__code__.co_filename, "exec") + func_globals = func.__globals__.copy() + exec(code, func_globals) + + # Return the modified function + return func_globals[func.__name__] diff --git a/src/unitxt/operator.py b/src/unitxt/operator.py index d1750aaae3..3f55896661 100644 --- a/src/unitxt/operator.py +++ b/src/unitxt/operator.py @@ -144,6 +144,7 @@ class StreamingOperator(Operator, PackageRequirementsMixin): As a subclass of `Artifact`, every `StreamingOperator` can be saved in a catalog for further usage or reference. """ + breakpoint: bool = FinalField(also_positional=False, default=False) @abstractmethod def __call__(self, streams: Optional[MultiStream] = None) -> MultiStream: diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index c21d8ea71e..194acb4bb5 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -65,6 +65,7 @@ from .artifact import Artifact, fetch_artifact from .dataclass import NonPositionalField, OptionalField +from .debug_utils import insert_breakpoint from .deprecation_utils import deprecation from .dict_utils import dict_delete, dict_get, dict_set, is_subpath from .error_utils import UnitxtError @@ -488,8 +489,11 @@ def process( return instance + class FieldOperator(InstanceFieldOperator): def process_instance_value(self, value: Any, instance: Dict[str, Any]): + if self.breakpoint: + return insert_breakpoint(self.process_value)(self, value) return self.process_value(value) @abstractmethod @@ -503,7 +507,6 @@ class MapValues(FieldOperator): def process_value(self, value: Any) -> Any: return self.mapping[str(value)] - class Rename(FieldOperator): """Renames fields.