Skip to content

Commit 061f86a

Browse files
committed
solve by extension class of KeyError
Signed-off-by: dafnapension <[email protected]>
1 parent 68db0ef commit 061f86a

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

src/unitxt/error_utils.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def _get_existing_context(error: Exception):
109109
existing["context_object"],
110110
existing["context"],
111111
)
112-
return str(error), None, {}
112+
return error.original_error if type(error) == ExtKeyError else str(error), None, {}
113113

114114

115115
def _format_object_context(obj: Any) -> Optional[str]:
@@ -239,13 +239,22 @@ def _store_context_attributes(
239239
"original_message": original_message,
240240
}
241241
try:
242-
error.original_error = type(error)(original_message)
242+
error.original_error = (
243+
original_message
244+
if type(error) == KeyError
245+
else type(error)(original_message)
246+
)
243247
except (TypeError, ValueError):
244248
error.original_error = Exception(original_message)
245249
error.context_object = context_object
246250
error.context = context
247251

248252

253+
class ExtKeyError(KeyError):
254+
def __str__(self):
255+
return "\n" + self.args[0]
256+
257+
249258
def _add_context_to_exception(
250259
original_error: Exception, context_object: Any = None, **context
251260
):
@@ -270,6 +279,13 @@ def _add_context_to_exception(
270279
original_error.args = (formatted_message,)
271280
else:
272281
original_error.args = (original_message,)
282+
if type(original_error) == KeyError:
283+
f = ExtKeyError(original_error.args[0])
284+
f.original_error = original_error.original_error
285+
f.context_object = original_error.context_object
286+
f.context = original_error.context
287+
return f
288+
return original_error
273289

274290

275291
@contextmanager
@@ -298,7 +314,5 @@ def error_context(context_object: Any = None, **context):
298314
try:
299315
yield
300316
except Exception as e:
301-
if e.__class__.__name__ == "KeyError":
302-
e = RuntimeError(e.__class__.__name__ + ": '" + e.args[0] + "'")
303-
_add_context_to_exception(e, context_object, **context)
304-
raise e
317+
f = _add_context_to_exception(e, context_object, **context)
318+
raise f from None

tests/library/test_error_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class TestProcessor:
4747

4848
processor = TestProcessor()
4949

50-
with self.assertRaises(RuntimeError) as cm:
50+
with self.assertRaises(KeyError) as cm:
5151
with error_context(processor):
5252
raise KeyError("Missing key")
5353

@@ -186,12 +186,12 @@ class TestProcessor:
186186

187187
def test_error_context_without_object(self):
188188
"""Test error_context without a context object."""
189-
with self.assertRaises(RuntimeError) as cm:
189+
with self.assertRaises(KeyError) as cm:
190190
with error_context(input_file="data.json", line_number=156):
191191
raise KeyError("Missing field")
192192

193193
error = cm.exception
194-
self.assertIsInstance(error, RuntimeError)
194+
self.assertIsInstance(error, KeyError)
195195
self.assertIsNone(error.context_object)
196196
# Context now includes version info plus the specified context
197197
self.assertIn("Unitxt", error.context)

0 commit comments

Comments
 (0)