Skip to content

Add an API to describe variant using labels #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from variantlib.base import PluginBase
from variantlib.config import KeyConfig, ProviderConfig
from variantlib.meta import VariantDescription, VariantMeta
from variantlib.plugins import PluginLoader


Expand All @@ -20,6 +21,12 @@ def get_supported_configs(self) -> Optional[ProviderConfig]:
],
)

def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]:
for meta in variant_desc:
if meta.namespace == self.namespace and meta.key == "key1":
return [meta.value.removeprefix("val")]
return []


# NB: this plugin deliberately does not inherit from PluginBase
# to test that we don't rely on that inheritance
Expand All @@ -34,9 +41,28 @@ def get_supported_configs(self) -> Optional[ProviderConfig]:
],
)

def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]:
if VariantMeta(self.namespace, "key3", "val3a") in variant_desc:
return ["sec"]
return []


class MockedPluginC(PluginBase):
namespace = "incompatible_plugin"
namespace = "other_plugin"

def get_supported_configs(self) -> Optional[ProviderConfig]:
return None

def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]:
ret = []
for meta in variant_desc:
if meta.namespace == self.namespace and meta.value == "on":
ret.append(meta.key)
return ret


class MockedPluginD:
namespace = "plugin_without_labels"

def get_supported_configs(self) -> Optional[ProviderConfig]:
return None
Expand All @@ -48,6 +74,13 @@ class ClashingPlugin(PluginBase):
def get_supported_configs(self) -> Optional[ProviderConfig]:
return None

def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]:
ret = []
for meta in variant_desc:
if meta.namespace == self.namespace and meta.value == "on":
ret.append(meta.key)
return ret


@dataclass
class MockedDistribution:
Expand Down Expand Up @@ -87,6 +120,11 @@ def mocked_plugin_loader(session_mocker):
value="tests.test_plugins:MockedPluginC",
plugin=MockedPluginC,
),
MockedEntryPoint(
name="no_labels",
value="tests.test_plugins:MockedPluginD",
plugin=MockedPluginD,
),
]
yield PluginLoader()

Expand Down Expand Up @@ -137,3 +175,36 @@ def test_namespace_clash(mocker):
assert "same namespace test_plugin" in str(exc)
assert "test-plugin" in str(exc)
assert "clashing-plugin" in str(exc)


@pytest.mark.parametrize("variant_desc,expected",
[
(VariantDescription([
VariantMeta("test_plugin", "key1", "val1a"),
VariantMeta("test_plugin", "key2", "val2b"),
VariantMeta("second_plugin", "key3", "val3a"),
VariantMeta("other_plugin", "flag2", "on"),
]), ["1a", "sec", "flag2"]),
(VariantDescription([
# note that VariantMetas don't actually have to be supported
# by the system in question -- we could be cross-building
# for another system
VariantMeta("test_plugin", "key1", "val1f"),
VariantMeta("test_plugin", "key2", "val2b"),
VariantMeta("second_plugin", "key3", "val3a"),
]), ["1f", "sec"]),
(VariantDescription([
VariantMeta("test_plugin", "key2", "val2b"),
VariantMeta("second_plugin", "key3", "val3a"),
]), ["sec"]),
(VariantDescription([
VariantMeta("test_plugin", "key2", "val2b"),
]), []),
(VariantDescription([
VariantMeta("test_plugin", "key2", "val2b"),
VariantMeta("other_plugin", "flag1", "on"),
VariantMeta("other_plugin", "flag2", "on"),
]), ["flag1", "flag2"]),
])
def test_get_variant_labels(mocked_plugin_loader, variant_desc, expected):
assert mocked_plugin_loader.get_variant_labels(variant_desc) == expected
9 changes: 8 additions & 1 deletion variantlib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Protocol, runtime_checkable

from variantlib.config import ProviderConfig
from variantlib.meta import VariantDescription


@runtime_checkable
class PluginType(Protocol):
"""A protocol for plugin classes"""

Expand All @@ -17,6 +17,10 @@ def get_supported_configs(self) -> ProviderConfig:
"""Get supported configs for the current system"""
...

def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]:
"""Get list of short labels to describe the variant"""
...


class PluginBase(ABC):
"""An abstract base class that can be used to implement plugins"""
Expand All @@ -26,3 +30,6 @@ def namespace(self) -> str: ...

@abstractmethod
def get_supported_configs(self) -> ProviderConfig: ...

def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]:
return []
17 changes: 16 additions & 1 deletion variantlib/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ def load_plugins(self) -> None:

# Instantiate the plugin
plugin_instance = plugin_class()
assert isinstance(plugin_instance, PluginType)

# Check for obligatory members
for attr in ("namespace", "get_supported_configs"):
assert hasattr(
plugin_instance, attr
), f"Plugin is missing required member: {attr}"
except Exception:
logging.exception("An unknown error happened - Ignoring plugin")
else:
Expand Down Expand Up @@ -92,3 +97,13 @@ def get_dist_name_mapping(self) -> dict[str, str]:
"""Get a mapping from plugin names to distribution names"""

return self._dist_names

def get_variant_labels(self, variant_desc: VariantDescription) -> list[str]:
"""Get list of short labels to describe the variant"""

labels = []
for plugin in self._plugins.values():
if hasattr(plugin, "get_variant_labels"):
labels += plugin.get_variant_labels(variant_desc)

return labels
Loading