From 32f7c62b66ad6af19e18f4fe381e8bf3595b55b4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 31 Mar 2025 20:16:39 +0200 Subject: [PATCH] PEP 782: Add PyBytesWriter C API --- pythoncapi_compat.h | 250 ++++++++++++++++++++++++++++ tests/test_pythoncapi_compat_cext.c | 98 +++++++++++ 2 files changed, 348 insertions(+) diff --git a/pythoncapi_compat.h b/pythoncapi_compat.h index 4b179e4..bf7eb30 100644 --- a/pythoncapi_compat.h +++ b/pythoncapi_compat.h @@ -2199,6 +2199,256 @@ PyConfig_GetInt(const char *name, int *value) #endif // PY_VERSION_HEX > 0x03090000 && !defined(PYPY_VERSION) +#if PY_VERSION_HEX < 0x030E00A7 +typedef struct PyBytesWriter { + char small_buffer[256]; + PyObject *obj; + Py_ssize_t size; +} PyBytesWriter; + +static inline Py_ssize_t +_PyBytesWriter_GetAllocated(PyBytesWriter *writer) +{ + if (writer->obj == NULL) { + return sizeof(writer->small_buffer); + } + else { + return PyBytes_GET_SIZE(writer->obj); + } +} + + +static inline int +_PyBytesWriter_Resize_impl(PyBytesWriter *writer, Py_ssize_t size, + int overallocate) +{ + assert(size >= 0); + + if (size <= _PyBytesWriter_GetAllocated(writer)) { + return 0; + } + + if (overallocate) { +#ifdef MS_WINDOWS + /* On Windows, overallocate by 50% is the best factor */ + if (size <= (PY_SSIZE_T_MAX - size / 2)) { + size += size / 2; + } +#else + /* On Linux, overallocate by 25% is the best factor */ + if (size <= (PY_SSIZE_T_MAX - size / 4)) { + size += size / 4; + } +#endif + } + + if (writer->obj != NULL) { + if (_PyBytes_Resize(&writer->obj, size)) { + return -1; + } + assert(writer->obj != NULL); + } + else { + writer->obj = PyBytes_FromStringAndSize(NULL, size); + if (writer->obj == NULL) { + return -1; + } + assert((size_t)size > sizeof(writer->small_buffer)); + memcpy(PyBytes_AS_STRING(writer->obj), + writer->small_buffer, + sizeof(writer->small_buffer)); + } + return 0; +} + +static inline void* +PyBytesWriter_GetData(PyBytesWriter *writer) +{ + if (writer->obj == NULL) { + return writer->small_buffer; + } + else { + return PyBytes_AS_STRING(writer->obj); + } +} + +static inline Py_ssize_t +PyBytesWriter_GetSize(PyBytesWriter *writer) +{ + return writer->size; +} + +static inline void +PyBytesWriter_Discard(PyBytesWriter *writer) +{ + if (writer == NULL) { + return; + } + + Py_XDECREF(writer->obj); + PyMem_Free(writer); +} + +static inline PyBytesWriter* +PyBytesWriter_Create(Py_ssize_t size) +{ + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "size must be >= 0"); + return NULL; + } + + PyBytesWriter *writer = (PyBytesWriter*)PyMem_Malloc(sizeof(PyBytesWriter)); + if (writer == NULL) { + PyErr_NoMemory(); + return NULL; + } + + writer->obj = NULL; + writer->size = 0; + + if (size >= 1) { + if (_PyBytesWriter_Resize_impl(writer, size, 0) < 0) { + PyBytesWriter_Discard(writer); + return NULL; + } + writer->size = size; + } + return writer; +} + +static inline PyObject* +PyBytesWriter_FinishWithSize(PyBytesWriter *writer, Py_ssize_t size) +{ + PyObject *result; + if (size == 0) { + result = PyBytes_FromStringAndSize("", 0); + } + else if (writer->obj != NULL) { + if (size != PyBytes_GET_SIZE(writer->obj)) { + if (_PyBytes_Resize(&writer->obj, size)) { + goto error; + } + } + result = writer->obj; + writer->obj = NULL; + } + else { + result = PyBytes_FromStringAndSize(writer->small_buffer, size); + } + PyBytesWriter_Discard(writer); + return result; + +error: + PyBytesWriter_Discard(writer); + return NULL; +} + +static inline PyObject* +PyBytesWriter_Finish(PyBytesWriter *writer) +{ + return PyBytesWriter_FinishWithSize(writer, writer->size); +} + +static inline PyObject* +PyBytesWriter_FinishWithPointer(PyBytesWriter *writer, void *buf) +{ + Py_ssize_t size = (char*)buf - (char*)PyBytesWriter_GetData(writer); + if (size < 0 || size > _PyBytesWriter_GetAllocated(writer)) { + PyBytesWriter_Discard(writer); + PyErr_SetString(PyExc_ValueError, "invalid end pointer"); + return NULL; + } + + return PyBytesWriter_FinishWithSize(writer, size); +} + +static inline int +PyBytesWriter_Resize(PyBytesWriter *writer, Py_ssize_t size) +{ + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "size must be >= 0"); + return -1; + } + if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { + return -1; + } + writer->size = size; + return 0; +} + +static inline int +PyBytesWriter_Grow(PyBytesWriter *writer, Py_ssize_t size) +{ + if (size < 0 && writer->size + size < 0) { + PyErr_SetString(PyExc_ValueError, "invalid size"); + return -1; + } + if (size > PY_SSIZE_T_MAX - writer->size) { + PyErr_NoMemory(); + return -1; + } + size = writer->size + size; + + if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { + return -1; + } + writer->size = size; + return 0; +} + +static inline void* +PyBytesWriter_GrowAndUpdatePointer(PyBytesWriter *writer, + Py_ssize_t size, void *buf) +{ + Py_ssize_t pos = (char*)buf - (char*)PyBytesWriter_GetData(writer); + if (PyBytesWriter_Grow(writer, size) < 0) { + return NULL; + } + return (char*)PyBytesWriter_GetData(writer) + pos; +} + +static inline int +PyBytesWriter_WriteBytes(PyBytesWriter *writer, + const void *bytes, Py_ssize_t size) +{ + if (size < 0) { + size_t len = strlen((const char*)bytes); + if (len > (size_t)PY_SSIZE_T_MAX) { + PyErr_NoMemory(); + return -1; + } + size = (Py_ssize_t)len; + } + + Py_ssize_t pos = writer->size; + if (PyBytesWriter_Grow(writer, size) < 0) { + return -1; + } + char *buf = (char*)PyBytesWriter_GetData(writer); + memcpy(buf + pos, bytes, (size_t)size); + return 0; +} + +static inline int +PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) +{ + va_list vargs; + va_start(vargs, format); + PyObject *str = PyBytes_FromFormatV(format, vargs); + va_end(vargs); + + if (str == NULL) { + return -1; + } + int res = PyBytesWriter_WriteBytes(writer, + PyBytes_AS_STRING(str), + PyBytes_GET_SIZE(str)); + Py_DECREF(str); + return res; +} +#endif // PY_VERSION_HEX < 0x030E00A7 + + #ifdef __cplusplus } #endif diff --git a/tests/test_pythoncapi_compat_cext.c b/tests/test_pythoncapi_compat_cext.c index 9d324a4..05db957 100644 --- a/tests/test_pythoncapi_compat_cext.c +++ b/tests/test_pythoncapi_compat_cext.c @@ -2181,6 +2181,103 @@ test_config(PyObject *Py_UNUSED(module), PyObject *Py_UNUSED(args)) #endif +static int +test_byteswriter_highlevel(void) +{ + PyObject *obj; + PyBytesWriter *writer = PyBytesWriter_Create(0); + if (writer == NULL) { + goto error; + } + if (PyBytesWriter_WriteBytes(writer, "Hello", -1) < 0) { + goto error; + } + if (PyBytesWriter_Format(writer, " %s!", "World") < 0) { + goto error; + } + + obj = PyBytesWriter_Finish(writer); + if (obj == NULL) { + return -1; + } + assert(PyBytes_Check(obj)); + assert(strcmp(PyBytes_AS_STRING(obj), "Hello World!") == 0); + Py_DECREF(obj); + return 0; + +error: + PyBytesWriter_Discard(writer); + return -1; +} + +static int +test_byteswriter_abc(void) +{ + PyBytesWriter *writer = PyBytesWriter_Create(3); + if (writer == NULL) { + return -1; + } + + char *str = (char*)PyBytesWriter_GetData(writer); + memcpy(str, "abc", 3); + + PyObject *obj = PyBytesWriter_Finish(writer); + if (obj == NULL) { + return -1; + } + assert(PyBytes_Check(obj)); + assert(strcmp(PyBytes_AS_STRING(obj), "abc") == 0); + Py_DECREF(obj); + return 0; +} + +static int +test_byteswriter_grow(void) +{ + PyBytesWriter *writer = PyBytesWriter_Create(10); + if (writer == NULL) { + return -1; + } + + char *buf = (char*)PyBytesWriter_GetData(writer); + memcpy(buf, "Hello ", strlen("Hello ")); + buf += strlen("Hello "); + + buf = (char*)PyBytesWriter_GrowAndUpdatePointer(writer, 10, buf); + if (buf == NULL) { + PyBytesWriter_Discard(writer); + return -1; + } + + memcpy(buf, "World", strlen("World")); + buf += strlen("World"); + + PyObject *obj = PyBytesWriter_FinishWithPointer(writer, buf); + if (obj == NULL) { + return -1; + } + assert(PyBytes_Check(obj)); + assert(strcmp(PyBytes_AS_STRING(obj), "Hello World") == 0); + Py_DECREF(obj); + return 0; +} + +static PyObject * +test_byteswriter(PyObject *Py_UNUSED(module), PyObject *Py_UNUSED(args)) +{ + if (test_byteswriter_highlevel() < 0) { + return NULL; + } + if (test_byteswriter_abc() < 0) { + return NULL; + } + if (test_byteswriter_grow() < 0) { + return NULL; + } + Py_RETURN_NONE; +} + + static struct PyMethodDef methods[] = { {"test_object", test_object, METH_NOARGS, _Py_NULL}, {"test_py_is", test_py_is, METH_NOARGS, _Py_NULL}, @@ -2232,6 +2329,7 @@ static struct PyMethodDef methods[] = { #if 0x03090000 <= PY_VERSION_HEX && !defined(PYPY_VERSION) {"test_config", test_config, METH_NOARGS, _Py_NULL}, #endif + {"test_byteswriter", test_byteswriter, METH_NOARGS, _Py_NULL}, {_Py_NULL, _Py_NULL, 0, _Py_NULL} };