|
1 | 1 | import types
|
2 | 2 | from contextlib import contextmanager
|
3 | 3 | from inspect import FrameInfo, stack
|
4 |
| -from typing import List, Optional |
| 4 | +from typing import ( |
| 5 | + Callable, |
| 6 | + ContextManager, |
| 7 | + Iterator, |
| 8 | + List, |
| 9 | + Optional, |
| 10 | + TypeVar, |
| 11 | + Union, |
| 12 | + overload, |
| 13 | +) |
5 | 14 |
|
6 | 15 | from returns.result import _Failure
|
7 | 16 |
|
| 17 | +_FunctionType = TypeVar('_FunctionType', bound=Callable) |
8 | 18 |
|
9 |
| -@contextmanager |
10 |
| -def collect_traces(): |
| 19 | + |
| 20 | +@overload |
| 21 | +def collect_traces() -> ContextManager[None]: |
| 22 | + """Context Manager to active traces collect to the Failures.""" |
| 23 | + |
| 24 | + |
| 25 | +@overload |
| 26 | +def collect_traces(function: _FunctionType) -> _FunctionType: |
| 27 | + """Decorator to active traces collect to the Failures.""" |
| 28 | + |
| 29 | + |
| 30 | +def collect_traces( |
| 31 | + function: Optional[_FunctionType] = None, |
| 32 | +) -> Union[_FunctionType, ContextManager[None]]: # noqa: DAR101, DAR201, DAR301 |
11 | 33 | """
|
12 | 34 | Context Manager/Decorator to active traces collect to the Failures.
|
13 | 35 |
|
@@ -36,13 +58,16 @@ def collect_traces():
|
36 | 58 | # doctest: # noqa: DAR301, E501
|
37 | 59 |
|
38 | 60 | """
|
39 |
| - unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009 |
40 |
| - substitute_get_trace = types.MethodType(_get_trace, _Failure) |
41 |
| - setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010 |
42 |
| - try: |
43 |
| - yield |
44 |
| - finally: |
45 |
| - setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010 |
| 61 | + @contextmanager |
| 62 | + def factory() -> Iterator[None]: |
| 63 | + unpatched_get_trace = getattr(_Failure, '_get_trace') # noqa: B009 |
| 64 | + substitute_get_trace = types.MethodType(_get_trace, _Failure) |
| 65 | + setattr(_Failure, '_get_trace', substitute_get_trace) # noqa: B010 |
| 66 | + try: # noqa: WPS501 |
| 67 | + yield |
| 68 | + finally: |
| 69 | + setattr(_Failure, '_get_trace', unpatched_get_trace) # noqa: B010 |
| 70 | + return factory()(function) if function else factory() |
46 | 71 |
|
47 | 72 |
|
48 | 73 | def _get_trace(_self: _Failure) -> Optional[List[FrameInfo]]:
|
|
0 commit comments