diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 827db5f3eba84..b13a429d4a3c0 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,12 +6,10 @@ // //===----------------------------------------------------------------------===// -#include -#include - #include "Globals.h" #include "IRModule.h" #include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" @@ -19,9 +17,14 @@ #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +#include "nanobind/nanobind.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include namespace nb = nanobind; using namespace nb::literals; @@ -1329,11 +1332,11 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject, accum.getUserData()); } -void PyOperationBase::writeBytecode(const nb::object &fileObject, +void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); - PyFileAccumulator accum(fileObject, /*binary=*/true); + PyFileAccumulator accum(fileOrStringObject, /*binary=*/true); if (!bytecodeVersion.has_value()) return mlirOperationWriteBytecode(operation, accum.getCallback(), diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index ee193cf9f8ef8..64ea4329f65f1 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -13,8 +13,13 @@ #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include template <> struct std::iterator_traits { @@ -128,33 +133,59 @@ struct PyPrintAccumulator { } }; -/// Accumulates int a python file-like object, either writing text (default) -/// or binary. +/// Accumulates into a file, either writing text (default) +/// or binary. The file may be a Python file-like object or a path to a file. class PyFileAccumulator { public: - PyFileAccumulator(const nanobind::object &fileObject, bool binary) - : pyWriteFunction(fileObject.attr("write")), binary(binary) {} + PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary) + : binary(binary) { + std::string filePath; + if (nanobind::try_cast(fileOrStringObject, filePath)) { + std::error_code ec; + writeTarget.emplace(filePath, ec); + if (ec) { + throw nanobind::value_error( + (std::string("Unable to open file for writing: ") + ec.message()) + .c_str()); + } + } else { + writeTarget.emplace(fileOrStringObject.attr("write")); + } + } + + MlirStringCallback getCallback() { + return writeTarget.index() == 0 ? getPyWriteCallback() + : getOstreamCallback(); + } void *getUserData() { return this; } - MlirStringCallback getCallback() { +private: + MlirStringCallback getPyWriteCallback() { return [](MlirStringRef part, void *userData) { nanobind::gil_scoped_acquire acquire; PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. nanobind::bytes pyBytes(part.data, part.length); - accum->pyWriteFunction(pyBytes); + std::get(accum->writeTarget)(pyBytes); } else { nanobind::str pyStr(part.data, part.length); // Decodes as UTF-8 by default. - accum->pyWriteFunction(pyStr); + std::get(accum->writeTarget)(pyStr); } }; } -private: - nanobind::object pyWriteFunction; + MlirStringCallback getOstreamCallback() { + return [](MlirStringRef part, void *userData) { + PyFileAccumulator *accum = static_cast(userData); + std::get(accum->writeTarget) + .write(part.data, part.length); + }; + } + + std::variant writeTarget; bool binary; }; diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index ab975a6954044..c93de2fe3154e 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -47,7 +47,7 @@ import collections from collections.abc import Callable, Sequence import io from pathlib import Path -from typing import Any, ClassVar, TypeVar, overload +from typing import Any, BinaryIO, ClassVar, TypeVar, overload __all__ = [ "AffineAddExpr", @@ -285,12 +285,12 @@ class _OperationBase: """ Verify the operation. Raises MLIRError if verification fails, and returns true otherwise. """ - def write_bytecode(self, file: Any, desired_version: int | None = None) -> None: + def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None: """ Write the bytecode form of the operation to a file like object. Args: - file: The file like object to write to. + file: The file like object or path to write to. desired_version: The version of bytecode to emit. Returns: The bytecode writer status. diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index c2d3aed8808b4..090d0030fb062 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -3,6 +3,7 @@ import gc import io import itertools +from tempfile import NamedTemporaryFile from mlir.ir import * from mlir.dialects.builtin import ModuleOp from mlir.dialects import arith @@ -617,6 +618,12 @@ def testOperationPrint(): module.operation.write_bytecode(bytecode_stream, desired_version=1) bytecode = bytecode_stream.getvalue() assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR" + with NamedTemporaryFile() as tmpfile: + module.operation.write_bytecode(str(tmpfile.name), desired_version=1) + tmpfile.seek(0) + assert tmpfile.read().startswith( + b"ML\xefR" + ), "Expected bytecode to start with MLïR" ctx2 = Context() module_roundtrip = Module.parse(bytecode, ctx2) f = io.StringIO()