Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 150 additions & 111 deletions ddtrace/internal/wrapping/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sys
from types import CodeType
from types import FunctionType
from typing import Any # noqa:F401
from typing import Callable # noqa:F401
from typing import Dict # noqa:F401
from typing import Optional # noqa:F401
from typing import Protocol # noqa:F401
from typing import Tuple # noqa:F401
from typing import cast # noqa:F401
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import cast

import bytecode as bc
from bytecode import Instr
Expand All @@ -22,23 +24,53 @@
class WrappedFunction(Protocol):
"""A wrapped function."""

__dd_wrapped__ = None # type: Optional[FunctionType]
__dd_wrappers__ = None # type: Optional[Dict[Any, Any]]
__dd_wrapped__: Optional[FunctionType] = None
__dd_wrappers__: Optional[Dict[Any, Any]] = None

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
pass


Wrapper = Callable[[FunctionType, Tuple[Any], Dict[str, Any]], Any]


def _add(lineno):
def _add(lineno: int) -> Instr:
if PY >= (3, 11):
return Instr("BINARY_OP", bc.BinaryOp.ADD, lineno=lineno)

return Instr("INPLACE_ADD", lineno=lineno)


HEAD = Assembly()
if PY >= (3, 13):
HEAD.parse(
r"""
resume 0
load_const {wrapper}
push_null
load_const {wrapped}
"""
)

elif PY >= (3, 11):
HEAD.parse(
r"""
resume 0
push_null
load_const {wrapper}
load_const {wrapped}
"""
)

else:
HEAD.parse(
r"""
load_const {wrapper}
load_const {wrapped}
"""
)


UPDATE_MAP = Assembly()
if PY >= (3, 12):
UPDATE_MAP.parse(
Expand All @@ -62,6 +94,7 @@ def _add(lineno):
pop_top
"""
)

else:
UPDATE_MAP.parse(
r"""
Expand Down Expand Up @@ -104,8 +137,68 @@ def _add(lineno):
FIRSTLINENO_OFFSET = int(PY >= (3, 11))


def wrap_bytecode(wrapper, wrapped):
# type: (Wrapper, FunctionType) -> bc.Bytecode
def generate_posargs(code: CodeType) -> Generator[Instr, None, None]:
"""Generate the opcodes for building the positional arguments tuple."""
varnames = code.co_varnames
lineno = code.co_firstlineno + FIRSTLINENO_OFFSET
varargs = bool(code.co_flags & bc.CompilerFlags.VARARGS)
nargs = code.co_argcount
varargsname = varnames[nargs + code.co_kwonlyargcount] if varargs else None

if nargs: # posargs [+ varargs]
yield from (
Instr("LOAD_DEREF", bc.CellVar(argname), lineno=lineno)
if PY >= (3, 11) and argname in code.co_cellvars
else Instr("LOAD_FAST", argname, lineno=lineno)
for argname in varnames[:nargs]
)

yield Instr("BUILD_TUPLE", nargs, lineno=lineno)
if varargs:
yield Instr("LOAD_FAST", varargsname, lineno=lineno)
yield _add(lineno)

elif varargs: # varargs
yield Instr("LOAD_FAST", varargsname, lineno=lineno)

else: # ()
yield Instr("BUILD_TUPLE", 0, lineno=lineno)


(PAIR := Assembly()).parse(
r"""
load_const {arg}
load_fast {arg}
"""
)


def generate_kwargs(code: CodeType) -> Generator[Instr, None, None]:
"""Generate the opcodes for building the keyword arguments dictionary."""
flags = code.co_flags
varnames = code.co_varnames
lineno = code.co_firstlineno + FIRSTLINENO_OFFSET
varargs = bool(flags & bc.CompilerFlags.VARARGS)
varkwargs = bool(flags & bc.CompilerFlags.VARKEYWORDS)
nargs = code.co_argcount
kwonlyargs = code.co_kwonlyargcount
varkwargsname = varnames[nargs + kwonlyargs + varargs] if varkwargs else None

if kwonlyargs:
for arg in varnames[nargs : nargs + kwonlyargs]: # kwargs [+ varkwargs]
yield from PAIR.bind({"arg": arg}, lineno=lineno)
yield Instr("BUILD_MAP", kwonlyargs, lineno=lineno)
if varkwargs:
yield from UPDATE_MAP.bind({"varkwargsname": varkwargsname}, lineno=lineno)

elif varkwargs: # varkwargs
yield Instr("LOAD_FAST", varkwargsname, lineno=lineno)

else: # {}
yield Instr("BUILD_MAP", 0, lineno=lineno)


def wrap_bytecode(wrapper: Wrapper, wrapped: FunctionType) -> bc.Bytecode:
"""Wrap a function with a wrapper function.

The wrapper function expects the wrapped function as the first argument,
Expand All @@ -118,97 +211,42 @@ def wrap_bytecode(wrapper, wrapped):

code = wrapped.__code__
lineno = code.co_firstlineno + FIRSTLINENO_OFFSET
varargs = bool(code.co_flags & bc.CompilerFlags.VARARGS)
varkwargs = bool(code.co_flags & bc.CompilerFlags.VARKEYWORDS)
nargs = code.co_argcount
argnames = code.co_varnames[:nargs]
try:
kwonlyargs = code.co_kwonlyargcount
except AttributeError:
kwonlyargs = 0
kwonlyargnames = code.co_varnames[nargs : nargs + kwonlyargs]
varargsname = code.co_varnames[nargs + kwonlyargs] if varargs else None
varkwargsname = code.co_varnames[nargs + kwonlyargs + varargs] if varkwargs else None

# Push the wrapper function that is to be called and the wrapped function to
# be passed as first argument.
instrs = [
bc.Instr("LOAD_CONST", wrapper, lineno=lineno),
bc.Instr("LOAD_CONST", wrapped, lineno=lineno),
]
if PY >= (3, 11):
# From insert_prefix_instructions
instrs[0:0] = [
bc.Instr("RESUME", 0, lineno=lineno - 1),
bc.Instr("PUSH_NULL", lineno=lineno),
]
if PY >= (3, 13):
instrs[1], instrs[2] = instrs[2], instrs[1]

if code.co_cellvars:
instrs[0:0] = [Instr("MAKE_CELL", bc.CellVar(_), lineno=lineno) for _ in code.co_cellvars]
instrs = HEAD.bind({"wrapper": wrapper, "wrapped": wrapped}, lineno=lineno)

if code.co_freevars:
instrs.insert(0, bc.Instr("COPY_FREE_VARS", len(code.co_freevars), lineno=lineno))

# Build the tuple of all the positional arguments
if nargs:
instrs.extend(
[
Instr("LOAD_DEREF", bc.CellVar(argname), lineno=lineno)
if PY >= (3, 11) and argname in code.co_cellvars
else bc.Instr("LOAD_FAST", argname, lineno=lineno)
for argname in argnames
]
)
instrs.append(bc.Instr("BUILD_TUPLE", nargs, lineno=lineno))
if varargs:
instrs.extend(
[
bc.Instr("LOAD_FAST", varargsname, lineno=lineno),
_add(lineno),
]
)
elif varargs:
instrs.append(bc.Instr("LOAD_FAST", varargsname, lineno=lineno))
else:
instrs.append(bc.Instr("BUILD_TUPLE", 0, lineno=lineno))

# Prepare the keyword arguments
if kwonlyargs:
for arg in kwonlyargnames:
instrs.extend(
[
bc.Instr("LOAD_CONST", arg, lineno=lineno),
bc.Instr("LOAD_FAST", arg, lineno=lineno),
]
)
instrs.append(bc.Instr("BUILD_MAP", kwonlyargs, lineno=lineno))
if varkwargs:
instrs.extend(UPDATE_MAP.bind({"varkwargsname": varkwargsname}, lineno=lineno))
# Add positional arguments
instrs.extend(generate_posargs(code))

elif varkwargs:
instrs.append(bc.Instr("LOAD_FAST", varkwargsname, lineno=lineno))

else:
instrs.append(bc.Instr("BUILD_MAP", 0, lineno=lineno))
# Add keyword arguments
instrs.extend(generate_kwargs(code))

# Call the wrapper function with the wrapped function, the positional and
# keyword arguments, and return the result.
# keyword arguments, and return the result. This is equivalent to
#
# >>> return wrapper(wrapped, args, kwargs)
instrs.extend(CALL_RETURN.bind({"arg": 3}, lineno=lineno))

# Include code for handling free/cell variables, if needed
if PY >= (3, 11):
if code.co_cellvars:
instrs[0:0] = [Instr("MAKE_CELL", bc.CellVar(_), lineno=lineno) for _ in code.co_cellvars]

if code.co_freevars:
instrs.insert(0, Instr("COPY_FREE_VARS", len(code.co_freevars), lineno=lineno))

# If the function has special flags set, like the generator, async generator
# or coroutine, inject unraveling code before the return opcode.
if bc.CompilerFlags.GENERATOR & code.co_flags and not (bc.CompilerFlags.COROUTINE & code.co_flags):
if (bc.CompilerFlags.GENERATOR & code.co_flags) and not (bc.CompilerFlags.COROUTINE & code.co_flags):
wrap_generator(instrs, code, lineno)
else:
wrap_async(instrs, code, lineno)

return bc.Bytecode(instrs)
return instrs


def wrap(f, wrapper):
# type: (FunctionType, Wrapper) -> WrappedFunction
def wrap(f: FunctionType, wrapper: Wrapper) -> WrappedFunction:
"""Wrap a function with a wrapper.

The wrapper expects the function as first argument, followed by the tuple
Expand All @@ -218,7 +256,7 @@ def wrap(f, wrapper):
wrapper function, instead of creating a new function object.
"""
wrapped = FunctionType(
f.__code__,
code := f.__code__,
f.__globals__,
"<wrapped>",
f.__defaults__,
Expand All @@ -232,29 +270,31 @@ def wrap(f, wrapper):

wrapped.__kwdefaults__ = f.__kwdefaults__

code = wrap_bytecode(wrapper, wrapped)
code.freevars = f.__code__.co_freevars
if PY >= (3, 11):
code.cellvars = f.__code__.co_cellvars
code.name = f.__code__.co_name
code.filename = f.__code__.co_filename
code.flags = f.__code__.co_flags
code.argcount = f.__code__.co_argcount
try:
code.posonlyargcount = f.__code__.co_posonlyargcount
except AttributeError:
pass
flags = code.co_flags
nargs = (
(argcount := code.co_argcount)
+ (kwonlycount := code.co_kwonlyargcount)
+ bool(flags & bc.CompilerFlags.VARARGS)
+ bool(flags & bc.CompilerFlags.VARKEYWORDS)
)

nargs = code.argcount
try:
code.kwonlyargcount = f.__code__.co_kwonlyargcount
nargs += code.kwonlyargcount
except AttributeError:
pass
nargs += bool(code.flags & bc.CompilerFlags.VARARGS) + bool(code.flags & bc.CompilerFlags.VARKEYWORDS)
code.argnames = f.__code__.co_varnames[:nargs]
# Wrap the wrapped function with the wrapper
wrapped_code = wrap_bytecode(wrapper, wrapped)

# Copy over the code attributes
wrapped_code.argcount = argcount
wrapped_code.argnames = code.co_varnames[:nargs]
wrapped_code.filename = code.co_filename
wrapped_code.freevars = code.co_freevars
wrapped_code.flags = flags
wrapped_code.kwonlyargcount = kwonlycount
wrapped_code.name = code.co_name
wrapped_code.posonlyargcount = code.co_posonlyargcount
if PY >= (3, 11):
wrapped_code.cellvars = code.co_cellvars

f.__code__ = code.to_code()
# Replace the function code with the trampoline bytecode
f.__code__ = wrapped_code.to_code()

# DEV: Multiple wrapping is implemented as a singly-linked list via the
# __dd_wrapped__ attribute.
Expand Down Expand Up @@ -296,8 +336,7 @@ def is_wrapped_with(f: FunctionType, wrapper: Wrapper) -> bool:
return False


def unwrap(wf, wrapper):
# type: (WrappedFunction, Wrapper) -> FunctionType
def unwrap(wf: WrappedFunction, wrapper: Wrapper) -> FunctionType:
"""Unwrap a wrapped function.

This is the reverse of :func:`wrap`. In case of multiple wrapping layers,
Expand Down
Loading