Skip to content

Commit

Permalink
[OM] Add API to iterate over an Object (#5402)
Browse files Browse the repository at this point in the history
Add API to get the field names from an object and add a convenient iterator.
  • Loading branch information
prithayan authored Jun 15, 2023
1 parent a343950 commit 1ecd327
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/circt-c/Dialect/OM.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ MLIR_CAPI_EXPORTED MlirType omEvaluatorObjectGetType(OMObject object);
MLIR_CAPI_EXPORTED OMObjectValue omEvaluatorObjectGetField(OMObject object,
MlirAttribute name);

/// Get all the field names from an Object, can be empty if object has no
/// fields.
MLIR_CAPI_EXPORTED MlirAttribute
omEvaluatorObjectGetFieldNames(OMObject object);

//===----------------------------------------------------------------------===//
// ObjectValue API.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions include/circt/Dialect/OM/Evaluator/Evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ struct Object : std::enable_shared_from_this<Object> {
/// Get a field of the Object by name.
FailureOr<ObjectValue> getField(StringAttr name);

/// Get all the field names of the Object.
ArrayAttr getFieldNames();

private:
/// Allow the instantiate method as a friend to construct Objects.
friend FailureOr<std::shared_ptr<Object>>
Expand Down
5 changes: 5 additions & 0 deletions integration_test/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,8 @@
print(obj.field)
# CHECK: 14
print(obj.child.foo)

for (name, field) in obj:
# CHECK: name: child, field: <circt.dialects.om.Object object
# CHECK: name: field, field: 42
print(f"name: {name}, field: {field}")
15 changes: 15 additions & 0 deletions lib/Bindings/Python/OMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
Expand Down Expand Up @@ -55,6 +56,18 @@ struct Object {
return omEvaluatorObjectValueGetPrimitive(result);
}

// Get a list with the names of all the fields in the Object.
std::vector<std::string> getFieldNames() {
ArrayAttr fieldNames =
cast<ArrayAttr>(unwrap(omEvaluatorObjectGetFieldNames(object)));

std::vector<std::string> slots;
for (auto fieldName : fieldNames.getAsRange<StringAttr>())
slots.push_back(fieldName.str());

return slots;
}

private:
// The underlying CAPI OMObject.
OMObject object;
Expand Down Expand Up @@ -108,6 +121,8 @@ void circt::python::populateDialectOMSubmodule(py::module &m) {
.def(py::init<Object>(), py::arg("object"))
.def("__getattr__", &Object::getField, "Get a field from an Object",
py::arg("name"))
.def_property_readonly("field_names", &Object::getFieldNames,
"Get field names from an Object")
.def_property_readonly("type", &Object::getType,
"The Type of the Object");

Expand Down
5 changes: 5 additions & 0 deletions lib/Bindings/Python/dialects/om.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def __getattr__(self, name: str):
assert isinstance(field, BaseObject)
return Object(field)

# Support iterating over an Object by yielding its fields.
def __iter__(self):
for name in self.field_names:
yield (name, getattr(self, name))


# Define the Evaluator class by inheriting from the base implementation in C++.
class Evaluator(BaseEvaluator):
Expand Down
5 changes: 5 additions & 0 deletions lib/CAPI/Dialect/OM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ MlirType omEvaluatorObjectGetType(OMObject object) {
return wrap(unwrap(object)->getType());
}

/// Get an ArrayAttr with the names of the fields in an Object.
MlirAttribute omEvaluatorObjectGetFieldNames(OMObject object) {
return wrap(unwrap(object)->getFieldNames());
}

/// Get a field from an Object, which must contain a field of that name.
OMObjectValue omEvaluatorObjectGetField(OMObject object, MlirAttribute name) {
// Unwrap the Object and get the field of the name, which the client must
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/OM/Evaluator/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,17 @@ FailureOr<ObjectValue> circt::om::Object::getField(StringAttr name) {
return cls.emitError("field ") << name << " does not exist";
return success(fields[name]);
}

/// Get an ArrayAttr with the names of the fields in the Object. Sort the fields
/// so there is always a stable order.
ArrayAttr circt::om::Object::getFieldNames() {
SmallVector<Attribute> fieldNames;
for (auto &f : fields)
fieldNames.push_back(f.first);

llvm::sort(fieldNames, [](Attribute a, Attribute b) {
return cast<StringAttr>(a).getValue() < cast<StringAttr>(b).getValue();
});

return ArrayAttr::get(cls.getContext(), fieldNames);
}
9 changes: 9 additions & 0 deletions test/CAPI/om.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ void testEvaluator(MlirContext ctx) {

OMObjectValue childField = omEvaluatorObjectGetField(object, childFieldName);

MlirAttribute fieldNamesO = omEvaluatorObjectGetFieldNames(object);
// CHECK: ["child", "field"]
mlirAttributeDump(fieldNamesO);

OMObject child = omEvaluatorObjectValueGetObject(childField);

// CHECK: 0
Expand All @@ -115,6 +119,11 @@ void testEvaluator(MlirContext ctx) {
OMObjectValue foo = omEvaluatorObjectGetField(
child, mlirStringAttrGet(ctx, mlirStringRefCreateFromCString("foo")));

MlirAttribute fieldNamesC = omEvaluatorObjectGetFieldNames(child);

// CHECK: ["foo"]
mlirAttributeDump(fieldNamesC);

// CHECK: child object field is primitive: 1
fprintf(stderr, "child object field is primitive: %d\n",
omEvaluatorObjectValueIsAPrimitive(foo));
Expand Down
13 changes: 13 additions & 0 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "gtest/gtest.h"
#include <mlir/IR/BuiltinAttributes.h>
Expand Down Expand Up @@ -378,6 +381,16 @@ TEST(EvaluatorTests, InstantiateObjectWithChildObjectMemoized) {
auto field2Value = std::get<std::shared_ptr<Object>>(
result.value()->getField(builder.getStringAttr("field2")).value());

auto fieldNames = result.value()->getFieldNames();

ASSERT_TRUE(fieldNames.size() == 2);
StringRef fieldNamesTruth[] = {"field1", "field2"};
for (auto fieldName : llvm::enumerate(fieldNames)) {
auto str = llvm::dyn_cast_or_null<StringAttr>(fieldName.value());
ASSERT_TRUE(str);
ASSERT_EQ(str.getValue(), fieldNamesTruth[fieldName.index()]);
}

ASSERT_TRUE(field1Value);
ASSERT_TRUE(field2Value);

Expand Down

0 comments on commit 1ecd327

Please sign in to comment.