diff --git a/docs/advanced.rst b/docs/advanced.rst index ed833f1f..e067d992 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -440,3 +440,23 @@ option to ``True``, VCR will not save old HTTP interactions if they are not used my_vcr = VCR(drop_unused_requests=True) with my_vcr.use_cassette('fixtures/vcr_cassettes/synopsis.yaml'): ... # your HTTP interactions here + +Metadata +-------------------- + +Sometimes there are external factors that affect the HTTP interactions. In order +to keep track of those, you can store custom metadata in a cassette. For example, +this might be useful to store values generated outside the code, like the resulting +values of a redirection to an external website, or to store seeds to deterministically +generate values that are included in the HTTP requests. + +.. code:: python + + my_vcr = VCR() + with my_vcr.use_cassette('fixtures/vcr_cassettes/synopsis.yaml') as cass: + seed = cass.get_metadata("seed", random.Random().getrandbits(128)) + r = Random(seed) + with mock.patch("uuid.uuid4") as uuid4: + uuid4.side_effect = lambda: uuid.UUID(int=r.getrandbits(128), version=4) + + ... # your HTTP interactions here diff --git a/tests/unit/test_cassettes.py b/tests/unit/test_cassettes.py index f2431e44..cfb38f9a 100644 --- a/tests/unit/test_cassettes.py +++ b/tests/unit/test_cassettes.py @@ -9,11 +9,17 @@ import yaml from vcr.cassette import Cassette -from vcr.errors import UnhandledHTTPRequestError +from vcr.errors import CannotOverwriteExistingCassetteException, UnhandledHTTPRequestError from vcr.patch import force_reset from vcr.request import Request from vcr.stubs import VCRHTTPSConnection +# Use the libYAML versions if possible +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader + def test_cassette_load(tmpdir): a_file = tmpdir.join("test_cassette.yml") @@ -433,3 +439,42 @@ def test_used_interactions(tmpdir): used_interactions = cassette._played_interactions + cassette._new_interactions() assert len(used_interactions) == 2 + + +def test_metadata_write(tmpdir): + file = tmpdir.join("test_cassette.yml") + cassette = Cassette(path=str(file)) + + assert cassette.get_metadata("key", "value") == "value" + assert cassette.get_metadata("key", "default") == "value" + assert cassette.get_metadata("otherkey", "othervalue") == "othervalue" + + cassette._save(force=False) + with open(file) as f: + assert yaml.load(f, Loader=Loader)["metadata"] == { + "key": "value", + "otherkey": "othervalue", + } + + +def test_metadata_load(tmpdir): + file = tmpdir.join("test_cassette.yml") + file.write( + yaml.dump( + { + "interactions": [ + { + "request": {"body": "", "uri": "foo1", "method": "GET", "headers": {}}, + "response": "bar1", + }, + ], + "metadata": {"key": "value"}, + }, + ), + ) + + cassette = Cassette.load(path=str(file)) + assert cassette.get_metadata("key", "default") == "value" + + with pytest.raises(CannotOverwriteExistingCassetteException): + cassette.get_metadata("otherkey", "default") diff --git a/vcr/cassette.py b/vcr/cassette.py index 1ac06dde..1afc19b4 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -3,12 +3,13 @@ import copy import inspect import logging +import typing from inspect import iscoroutinefunction import wrapt from ._handle_coroutine import handle_coroutine -from .errors import UnhandledHTTPRequestError +from .errors import CannotOverwriteExistingCassetteException, UnhandledHTTPRequestError from .matchers import get_matchers_results, method, requests_match, uri from .patch import CassettePatcherBuilder from .persisters.filesystem import CassetteDecodeError, CassetteNotFoundError, FilesystemPersister @@ -17,6 +18,7 @@ from .util import partition_dict log = logging.getLogger(__name__) +T = typing.TypeVar("T") class CassetteContextDecorator: @@ -194,6 +196,7 @@ def __init__( # self.data is the list of (req, resp) tuples self.data = [] + self._metadata = {} self.play_counts = collections.Counter() self.dirty = False self.rewound = False @@ -324,6 +327,15 @@ def find_requests_with_most_matches(self, request): return final_best_matches + def get_metadata(self, key: str, default: T) -> T: + if key in self._metadata: + return self._metadata[key] + if self.write_protected: + raise CannotOverwriteExistingCassetteException(cassette=self, missing_metadata=key) + self.dirty = True + self._metadata[key] = default + return default + def _new_interactions(self): """List of new HTTP interactions (request/response tuples)""" new_interactions = [] @@ -336,15 +348,20 @@ def _new_interactions(self): return new_interactions def _as_dict(self): - return {"requests": self.requests, "responses": self.responses} + cassette_dict = {"requests": self.requests, "responses": self.responses} + if self._metadata: + cassette_dict["metadata"] = self._metadata + return cassette_dict def _build_used_interactions_dict(self): interactions = self._played_interactions + self._new_interactions() - cassete_dict = { + cassette_dict = { "requests": [request for request, _ in interactions], "responses": [response for _, response in interactions], } - return cassete_dict + if self._metadata: + cassette_dict["metadata"] = self._metadata + return cassette_dict def _save(self, force=False): if self.drop_unused_requests and len(self._played_interactions) < len(self._old_interactions): @@ -358,12 +375,18 @@ def _save(self, force=False): def _load(self): try: - requests, responses = self._persister.load_cassette(self._path, serializer=self._serializer) + loaded = self._persister.load_cassette(self._path, serializer=self._serializer) + if len(loaded) == 3: + requests, responses, metadata = loaded + else: + requests, responses = loaded + metadata = None for request, response in zip(requests, responses): self.append(request, response) self._old_interactions.append((request, response)) self.dirty = False self.rewound = True + self._metadata = metadata or {} except (CassetteDecodeError, CassetteNotFoundError): pass diff --git a/vcr/errors.py b/vcr/errors.py index 4072e5f7..133375ae 100644 --- a/vcr/errors.py +++ b/vcr/errors.py @@ -1,8 +1,15 @@ class CannotOverwriteExistingCassetteException(Exception): def __init__(self, *args, **kwargs): self.cassette = kwargs["cassette"] - self.failed_request = kwargs["failed_request"] - message = self._get_message(kwargs["cassette"], kwargs["failed_request"]) + self.failed_request = kwargs.get("failed_request") + self.missing_metadata = kwargs.get("missing_metadata") + message = None + if self.failed_request: + message = self._get_message(self.cassette, self.failed_request) + if self.missing_metadata: + message = f'Missing metadata key "{self.missing_metadata}"' + if not message: + raise ValueError("Invalid kwargs, failed_request or missing_metadata must be supplied") super().__init__(message) @staticmethod diff --git a/vcr/serialize.py b/vcr/serialize.py index 0eec2643..a01cd73b 100644 --- a/vcr/serialize.py +++ b/vcr/serialize.py @@ -44,6 +44,8 @@ def deserialize(cassette_string, serializer): requests = [Request._from_dict(r["request"]) for r in data["interactions"]] responses = [compat.convert_to_bytes(r["response"]) for r in data["interactions"]] + if data.get("metadata"): + return requests, responses, data["metadata"] return requests, responses @@ -56,4 +58,6 @@ def serialize(cassette_dict, serializer): for request, response in zip(cassette_dict["requests"], cassette_dict["responses"]) ] data = {"version": CASSETTE_FORMAT_VERSION, "interactions": interactions} + if cassette_dict.get("metadata"): + data["metadata"] = cassette_dict["metadata"] return serializer.serialize(data)