diff --git a/ddtrace/internal/wrapping/__init__.py b/ddtrace/internal/wrapping/__init__.py index 852c99dc151..53a0b6c35b1 100644 --- a/ddtrace/internal/wrapping/__init__.py +++ b/ddtrace/internal/wrapping/__init__.py @@ -1,12 +1,14 @@ import sys +from types import CodeType from types import FunctionType -from typing import Any # noqa:F401 -from typing import Callable # noqa:F401 -from typing import Dict # noqa:F401 -from typing import Optional # noqa:F401 -from typing import Protocol # noqa:F401 -from typing import Tuple # noqa:F401 -from typing import cast # noqa:F401 +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generator +from typing import Optional +from typing import Protocol +from typing import Tuple +from typing import cast import bytecode as bc from bytecode import Instr @@ -22,23 +24,53 @@ class WrappedFunction(Protocol): """A wrapped function.""" - __dd_wrapped__ = None # type: Optional[FunctionType] - __dd_wrappers__ = None # type: Optional[Dict[Any, Any]] + __dd_wrapped__: Optional[FunctionType] = None + __dd_wrappers__: Optional[Dict[Any, Any]] = None - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: pass Wrapper = Callable[[FunctionType, Tuple[Any], Dict[str, Any]], Any] -def _add(lineno): +def _add(lineno: int) -> Instr: if PY >= (3, 11): return Instr("BINARY_OP", bc.BinaryOp.ADD, lineno=lineno) return Instr("INPLACE_ADD", lineno=lineno) +HEAD = Assembly() +if PY >= (3, 13): + HEAD.parse( + r""" + resume 0 + load_const {wrapper} + push_null + load_const {wrapped} + """ + ) + +elif PY >= (3, 11): + HEAD.parse( + r""" + resume 0 + push_null + load_const {wrapper} + load_const {wrapped} + """ + ) + +else: + HEAD.parse( + r""" + load_const {wrapper} + load_const {wrapped} + """ + ) + + UPDATE_MAP = Assembly() if PY >= (3, 12): UPDATE_MAP.parse( @@ -62,6 +94,7 @@ def _add(lineno): pop_top """ ) + else: UPDATE_MAP.parse( r""" @@ -104,8 +137,68 @@ def _add(lineno): FIRSTLINENO_OFFSET = int(PY >= (3, 11)) -def wrap_bytecode(wrapper, wrapped): - # type: (Wrapper, FunctionType) -> bc.Bytecode +def generate_posargs(code: CodeType) -> Generator[Instr, None, None]: + """Generate the opcodes for building the positional arguments tuple.""" + varnames = code.co_varnames + lineno = code.co_firstlineno + FIRSTLINENO_OFFSET + varargs = bool(code.co_flags & bc.CompilerFlags.VARARGS) + nargs = code.co_argcount + varargsname = varnames[nargs + code.co_kwonlyargcount] if varargs else None + + if nargs: # posargs [+ varargs] + yield from ( + Instr("LOAD_DEREF", bc.CellVar(argname), lineno=lineno) + if PY >= (3, 11) and argname in code.co_cellvars + else Instr("LOAD_FAST", argname, lineno=lineno) + for argname in varnames[:nargs] + ) + + yield Instr("BUILD_TUPLE", nargs, lineno=lineno) + if varargs: + yield Instr("LOAD_FAST", varargsname, lineno=lineno) + yield _add(lineno) + + elif varargs: # varargs + yield Instr("LOAD_FAST", varargsname, lineno=lineno) + + else: # () + yield Instr("BUILD_TUPLE", 0, lineno=lineno) + + +(PAIR := Assembly()).parse( + r""" + load_const {arg} + load_fast {arg} + """ +) + + +def generate_kwargs(code: CodeType) -> Generator[Instr, None, None]: + """Generate the opcodes for building the keyword arguments dictionary.""" + flags = code.co_flags + varnames = code.co_varnames + lineno = code.co_firstlineno + FIRSTLINENO_OFFSET + varargs = bool(flags & bc.CompilerFlags.VARARGS) + varkwargs = bool(flags & bc.CompilerFlags.VARKEYWORDS) + nargs = code.co_argcount + kwonlyargs = code.co_kwonlyargcount + varkwargsname = varnames[nargs + kwonlyargs + varargs] if varkwargs else None + + if kwonlyargs: + for arg in varnames[nargs : nargs + kwonlyargs]: # kwargs [+ varkwargs] + yield from PAIR.bind({"arg": arg}, lineno=lineno) + yield Instr("BUILD_MAP", kwonlyargs, lineno=lineno) + if varkwargs: + yield from UPDATE_MAP.bind({"varkwargsname": varkwargsname}, lineno=lineno) + + elif varkwargs: # varkwargs + yield Instr("LOAD_FAST", varkwargsname, lineno=lineno) + + else: # {} + yield Instr("BUILD_MAP", 0, lineno=lineno) + + +def wrap_bytecode(wrapper: Wrapper, wrapped: FunctionType) -> bc.Bytecode: """Wrap a function with a wrapper function. The wrapper function expects the wrapped function as the first argument, @@ -118,97 +211,42 @@ def wrap_bytecode(wrapper, wrapped): code = wrapped.__code__ lineno = code.co_firstlineno + FIRSTLINENO_OFFSET - varargs = bool(code.co_flags & bc.CompilerFlags.VARARGS) - varkwargs = bool(code.co_flags & bc.CompilerFlags.VARKEYWORDS) - nargs = code.co_argcount - argnames = code.co_varnames[:nargs] - try: - kwonlyargs = code.co_kwonlyargcount - except AttributeError: - kwonlyargs = 0 - kwonlyargnames = code.co_varnames[nargs : nargs + kwonlyargs] - varargsname = code.co_varnames[nargs + kwonlyargs] if varargs else None - varkwargsname = code.co_varnames[nargs + kwonlyargs + varargs] if varkwargs else None # Push the wrapper function that is to be called and the wrapped function to # be passed as first argument. - instrs = [ - bc.Instr("LOAD_CONST", wrapper, lineno=lineno), - bc.Instr("LOAD_CONST", wrapped, lineno=lineno), - ] - if PY >= (3, 11): - # From insert_prefix_instructions - instrs[0:0] = [ - bc.Instr("RESUME", 0, lineno=lineno - 1), - bc.Instr("PUSH_NULL", lineno=lineno), - ] - if PY >= (3, 13): - instrs[1], instrs[2] = instrs[2], instrs[1] - - if code.co_cellvars: - instrs[0:0] = [Instr("MAKE_CELL", bc.CellVar(_), lineno=lineno) for _ in code.co_cellvars] + instrs = HEAD.bind({"wrapper": wrapper, "wrapped": wrapped}, lineno=lineno) - if code.co_freevars: - instrs.insert(0, bc.Instr("COPY_FREE_VARS", len(code.co_freevars), lineno=lineno)) - - # Build the tuple of all the positional arguments - if nargs: - instrs.extend( - [ - Instr("LOAD_DEREF", bc.CellVar(argname), lineno=lineno) - if PY >= (3, 11) and argname in code.co_cellvars - else bc.Instr("LOAD_FAST", argname, lineno=lineno) - for argname in argnames - ] - ) - instrs.append(bc.Instr("BUILD_TUPLE", nargs, lineno=lineno)) - if varargs: - instrs.extend( - [ - bc.Instr("LOAD_FAST", varargsname, lineno=lineno), - _add(lineno), - ] - ) - elif varargs: - instrs.append(bc.Instr("LOAD_FAST", varargsname, lineno=lineno)) - else: - instrs.append(bc.Instr("BUILD_TUPLE", 0, lineno=lineno)) - - # Prepare the keyword arguments - if kwonlyargs: - for arg in kwonlyargnames: - instrs.extend( - [ - bc.Instr("LOAD_CONST", arg, lineno=lineno), - bc.Instr("LOAD_FAST", arg, lineno=lineno), - ] - ) - instrs.append(bc.Instr("BUILD_MAP", kwonlyargs, lineno=lineno)) - if varkwargs: - instrs.extend(UPDATE_MAP.bind({"varkwargsname": varkwargsname}, lineno=lineno)) + # Add positional arguments + instrs.extend(generate_posargs(code)) - elif varkwargs: - instrs.append(bc.Instr("LOAD_FAST", varkwargsname, lineno=lineno)) - - else: - instrs.append(bc.Instr("BUILD_MAP", 0, lineno=lineno)) + # Add keyword arguments + instrs.extend(generate_kwargs(code)) # Call the wrapper function with the wrapped function, the positional and - # keyword arguments, and return the result. + # keyword arguments, and return the result. This is equivalent to + # + # >>> return wrapper(wrapped, args, kwargs) instrs.extend(CALL_RETURN.bind({"arg": 3}, lineno=lineno)) + # Include code for handling free/cell variables, if needed + if PY >= (3, 11): + if code.co_cellvars: + instrs[0:0] = [Instr("MAKE_CELL", bc.CellVar(_), lineno=lineno) for _ in code.co_cellvars] + + if code.co_freevars: + instrs.insert(0, Instr("COPY_FREE_VARS", len(code.co_freevars), lineno=lineno)) + # If the function has special flags set, like the generator, async generator # or coroutine, inject unraveling code before the return opcode. - if bc.CompilerFlags.GENERATOR & code.co_flags and not (bc.CompilerFlags.COROUTINE & code.co_flags): + if (bc.CompilerFlags.GENERATOR & code.co_flags) and not (bc.CompilerFlags.COROUTINE & code.co_flags): wrap_generator(instrs, code, lineno) else: wrap_async(instrs, code, lineno) - return bc.Bytecode(instrs) + return instrs -def wrap(f, wrapper): - # type: (FunctionType, Wrapper) -> WrappedFunction +def wrap(f: FunctionType, wrapper: Wrapper) -> WrappedFunction: """Wrap a function with a wrapper. The wrapper expects the function as first argument, followed by the tuple @@ -218,7 +256,7 @@ def wrap(f, wrapper): wrapper function, instead of creating a new function object. """ wrapped = FunctionType( - f.__code__, + code := f.__code__, f.__globals__, "", f.__defaults__, @@ -232,29 +270,31 @@ def wrap(f, wrapper): wrapped.__kwdefaults__ = f.__kwdefaults__ - code = wrap_bytecode(wrapper, wrapped) - code.freevars = f.__code__.co_freevars - if PY >= (3, 11): - code.cellvars = f.__code__.co_cellvars - code.name = f.__code__.co_name - code.filename = f.__code__.co_filename - code.flags = f.__code__.co_flags - code.argcount = f.__code__.co_argcount - try: - code.posonlyargcount = f.__code__.co_posonlyargcount - except AttributeError: - pass + flags = code.co_flags + nargs = ( + (argcount := code.co_argcount) + + (kwonlycount := code.co_kwonlyargcount) + + bool(flags & bc.CompilerFlags.VARARGS) + + bool(flags & bc.CompilerFlags.VARKEYWORDS) + ) - nargs = code.argcount - try: - code.kwonlyargcount = f.__code__.co_kwonlyargcount - nargs += code.kwonlyargcount - except AttributeError: - pass - nargs += bool(code.flags & bc.CompilerFlags.VARARGS) + bool(code.flags & bc.CompilerFlags.VARKEYWORDS) - code.argnames = f.__code__.co_varnames[:nargs] + # Wrap the wrapped function with the wrapper + wrapped_code = wrap_bytecode(wrapper, wrapped) + + # Copy over the code attributes + wrapped_code.argcount = argcount + wrapped_code.argnames = code.co_varnames[:nargs] + wrapped_code.filename = code.co_filename + wrapped_code.freevars = code.co_freevars + wrapped_code.flags = flags + wrapped_code.kwonlyargcount = kwonlycount + wrapped_code.name = code.co_name + wrapped_code.posonlyargcount = code.co_posonlyargcount + if PY >= (3, 11): + wrapped_code.cellvars = code.co_cellvars - f.__code__ = code.to_code() + # Replace the function code with the trampoline bytecode + f.__code__ = wrapped_code.to_code() # DEV: Multiple wrapping is implemented as a singly-linked list via the # __dd_wrapped__ attribute. @@ -296,8 +336,7 @@ def is_wrapped_with(f: FunctionType, wrapper: Wrapper) -> bool: return False -def unwrap(wf, wrapper): - # type: (WrappedFunction, Wrapper) -> FunctionType +def unwrap(wf: WrappedFunction, wrapper: Wrapper) -> FunctionType: """Unwrap a wrapped function. This is the reverse of :func:`wrap`. In case of multiple wrapping layers,