[fix] mooncake: unpack dicts containing tensors to avoid bytes-pool f…#106
[fix] mooncake: unpack dicts containing tensors to avoid bytes-pool f…#106xupinjie wants to merge 1 commit into
Conversation
CLA Signature Guide@xupinjie , thanks for your pull request. The following commit(s) are not associated with a signed Contributor License Agreement (CLA).
To sign CLA, click here. To check if your email is configured correctly, refer to the FAQs. Once you've signed the CLA or updating your email, please comment |
| except ImportError: | ||
| MOONCAKE_STORE_IMPORTED = False | ||
|
|
||
| from tensordict import NonTensorData as _NonTensorData |
There was a problem hiding this comment.
Why we need this rename?
There was a problem hiding this comment.
Pull request overview
This PR fixes a MooncakeStore failure mode where storing dict values containing tensors can route data through Mooncake’s bytes pool, which may return b"" under high concurrent GET pressure and crash training. The client now “unpacks” dicts-with-tensors into multiple synthetic sub-keys so tensor payloads always use the tensor RDMA path, and adds tests to validate round-trip behavior and metadata handling.
Changes:
- Add dict-with-tensor fan-out in
MooncakeStoreClient.put()and corresponding re-folding logic inget()using per-keycustom_backend_meta. - Add expanded-key deletion logic in
clear()to remove dict sub-keys (including the bundled extras blob). - Add a comprehensive new test suite covering helper logic, end-to-end round-trip via a fake Mooncake store, and metadata serialization behavior.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
transfer_queue/storage/clients/mooncake_client.py |
Implements dict unpacking into tensor-path sub-keys, reconstruction in get(), and expanded deletion in clear(). |
tests/test_mooncake_dict_unpack.py |
Adds unit + end-to-end tests validating dict unpack/repack, metadata shape, and clear/get behaviors. |
Comments suppressed due to low confidence (3)
transfer_queue/storage/clients/mooncake_client.py:52
- The synthetic sub-key scheme reserves
__tq_extras__for the bundled non-tensor blob. If an input dict contains a tensor entry whose key equals this reserved name, it will collide with the extras-blob key ({key}::__tq_extras__) and overwrite/corrupt data. Add validation/escaping for dict keys (at least reject reserved names like__tq_extras__and possibly also guard against_TQ_DICT_UNPACK_KEY).
# Separator joining an original key to a dict sub-key (e.g. "5@mmi::pixel_values").
_DICT_SUBKEY_SEP: str = "::"
# Sentinel marker key identifying a per-key dict-unpack meta entry.
_TQ_DICT_UNPACK_KEY: str = "__tq_dict_unpack__"
# Reserved sub-key name for the bundled non-tensor blob (a 1D uint8 tensor that
# carries pickle bytes of all non-tensor entries of the original dict).
_TQ_EXTRAS_SUBKEY: str = "__tq_extras__"
transfer_queue/storage/clients/mooncake_client.py:241
custom_backend_metacurrently storestorch.dtypeobjects (seetensor_dtypes).transfer_queue.utils.serial_utils.encode()explicitly falls back to pickle when payloads containtorch.dtype, which can negate the intended msgpack round-trip and add overhead. Consider storing dtypes as simple msgpack-native values (e.g., dtype name strings) and converting back totorch.dtypeinget.
custom_meta[i] = {
_TQ_DICT_UNPACK_KEY: True,
"key_order": key_order,
"tensor_keys": ts_sub_keys,
"tensor_dtypes": [t.dtype for t in ts_sub_tensors],
"tensor_shapes": [list(t.shape) for t in ts_sub_tensors],
"extras_size": extras_size,
tests/test_mooncake_dict_unpack.py:605
- This test’s rationale claims the dict-unpack meta survives a msgspec/msgpack controller round-trip because it’s a plain dict. However, the meta includes
torch.dtypevalues, andserial_utils.encode()documents that it falls back to pickle whentorch.dtypeis present. Either adjust the explanation/expectations, or store msgpack-native dtype representations (e.g., strings/ints) so the round-trip is truly msgpack-based.
def test_meta_survives_tq_msgpack_pipeline(self):
"""REGRESSION: an earlier implementation made the dict-unpack meta a
``@dataclass``, which msgspec auto-flattened into a typeless dict on
the controller round-trip; ``isinstance`` checks then failed at GET
and the bytes-pool fallback re-triggered the original bug. Using a
plain ``dict`` with a sentinel key sidesteps the issue — dicts are a
native msgpack map type, so the structure (including the
``_TQ_DICT_UNPACK_KEY`` marker and all fields) round-trips
losslessly.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| except ImportError: | ||
| MOONCAKE_STORE_IMPORTED = False | ||
|
|
||
| from tensordict import NonTensorData as _NonTensorData |
| def test_non_tensor_data_wrapped_dict_is_true(self): | ||
| """The KV storage manager hands the client NonTensorData-wrapped dicts; | ||
| the dict-unpack path must unwrap them before classification.""" | ||
| try: | ||
| from tensordict import NonTensorData | ||
| except ImportError: | ||
| pytest.skip("tensordict not installed in this env") | ||
| v = NonTensorData({"a": torch.zeros(3), "b": torch.ones(2, 4)}) | ||
| assert _dict_has_tensor(v) |
| ) | ||
|
|
||
|
|
||
| def _expand_dict_slots_fn( |
There was a problem hiding this comment.
_flatten_dict_slots or _unpack_dict_slots?
| extras_idx = -1 | ||
| extras_size = meta.get("extras_size", 0) | ||
| if extras_size > 0: | ||
| flat_keys.append(f"{key}{_DICT_SUBKEY_SEP}{_TQ_EXTRAS_SUBKEY}") | ||
| flat_shapes.append([extras_size]) | ||
| flat_dtypes.append(torch.uint8) | ||
| extras_idx = len(flat_keys) - 1 |
There was a problem hiding this comment.
Need some comments to explain here
| flat_shapes.append(shapes[i]) | ||
| flat_dtypes.append(dtypes[i]) | ||
| reconstruct.append(("scalar", len(flat_keys) - 1)) | ||
| return flat_keys, flat_shapes, flat_dtypes, reconstruct |
There was a problem hiding this comment.
reconstruct -> rebuild_plan?
| if len(keys) != len(values): | ||
| raise ValueError("Number of keys must match number of values") | ||
|
|
||
| custom_meta: list[Any] = [None] * len(keys) |
There was a problem hiding this comment.
custom_meta -> custom_backend_meta because the structures are different for these two types of meta
There was a problem hiding this comment.
The first one is per-sample while the second one is per-sample-per-field
| # Dict-with-tensor fan-out: avoid the Mooncake bytes pool which | ||
| # silently returns b"" under MB-scale GET pressure (see | ||
| # real_client.cpp:2209 "Failed to allocate buffer"). Each |
There was a problem hiding this comment.
The background info is not needed.
| ``shapes`` and ``dtypes`` describe the expected tensor layout per key | ||
| (use ``None`` for non-tensor slots). ``custom_backend_meta`` carries | ||
| per-key metadata returned by ``put``. Returns values in input order. |
There was a problem hiding this comment.
Suggest using previous format for docstring
|
In general, I think the current implementation will be difficult to maintain. I plan to move the serialization logic in the Yuanrong client (https://github.com/Ascend/TransferQueue/blob/main/transfer_queue/storage/clients/yuanrong_client.py#L288) to a higher level. This will allow it to be shared by all storage backends, utilizing the common |
Bug:
Storing dict values that contain tensors (e.g. Qwen3-VL's multi_modal_inputs) routes them through Mooncake's bytes pool, which silently returns b"" under MB-scale concurrent GET pressure and crashes training.
Fix:
The client now splits such dicts so each sub-tensor rides the working RDMA tensor path under a synthetic sub-key (any non-tensor entries are pickled into one uint8 blob that also rides RDMA), so the buggy bytes pool is never touched.
Res:
