Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions src/pytest_ansible_network_integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .exceptions import PytestNetworkError
from .utils import _github_action_log
from .utils import _inventory
from .utils import _inventory_multi
from .utils import _print
from .utils import calculate_ports
from .utils import playbook
Expand Down Expand Up @@ -406,6 +407,138 @@ def _appliance_dhcp_address(env_vars: Dict[str, str]) -> Generator[str, None, No
_github_action_log("::endgroup::")


@pytest.fixture(scope="session", name="appliance_dhcp_map")
def _appliance_dhcp_map(env_vars: Dict[str, str]) -> Generator[Dict[str, str], None, None]:
"""Provision the lab and collect DHCP addresses for all appliances.

Returns a mapping of device name to DHCP IP for all devices in the lab.
"""
_github_action_log("::group::Starting lab provisioning (multi-device)")
_print("Starting lab provisioning (multi-device)")

try:
if not OPTIONS:
raise PytestNetworkError("Missing CML lab options")

lab_file = OPTIONS.cml_lab
if not os.path.exists(lab_file):
raise PytestNetworkError(f"Missing lab file '{lab_file}'")

start = time.time()
cml = CmlWrapper(
host=env_vars["cml_host"],
username=env_vars["cml_ui_user"],
password=env_vars["cml_ui_password"],
)
cml.bring_up(file=lab_file)
lab_id = cml.current_lab_id
logger.debug("Lab ID: %s", lab_id)

virsh = VirshWrapper(
host=env_vars["cml_host"],
user=env_vars["cml_ssh_user"],
password=env_vars["cml_ssh_password"],
port=int(env_vars["cml_ssh_port"]),
)

wait_extra_time = OPTIONS.wait_extra
wait_seconds = 0
if wait_extra_time:
try:
wait_seconds = int(wait_extra_time)
except ValueError:
logger.warning(
"Invalid wait_extra value: '%s'. Expected an integer. Skipping extra wait.",
wait_extra_time,
)
wait_seconds = 0

try:
device_to_ip = virsh.get_dhcp_leases(lab_id, wait_seconds)
except PytestNetworkError as exc:
logger.error("Failed to get DHCP leases for the appliances")
virsh.close()
cml.remove()
raise PytestNetworkError("Failed to get DHCP leases for the appliances") from exc

end = time.time()
elapsed = end - start
_print(f"Elapsed time to provision (multi): {elapsed} seconds")
logger.info("Elapsed time to provision (multi): %s seconds", elapsed)

except PytestNetworkError as exc:
logger.error("Failed to provision lab (multi): %s", exc)
_github_action_log("::endgroup::")
raise

finally:
virsh.close()
_github_action_log("::endgroup::")

yield device_to_ip

_github_action_log("::group::Removing lab (multi)")
try:
cml.remove()
except PytestNetworkError as exc:
logger.error("Failed to remove lab (multi): %s", exc)
raise
finally:
_github_action_log("::endgroup::")


@pytest.fixture
def ansible_project_multi(
appliance_dhcp_map: Dict[str, str],
env_vars: Dict[str, str],
integration_test_path: Path,
tmp_path: Path,
) -> AnsibleProject:
"""Build an Ansible project for all discovered appliances.

Creates a multi-host inventory using all DHCP leases discovered.
"""
logger.info("Building the Ansible project for multiple devices")

inventory = _inventory_multi(
host=env_vars["cml_host"],
device_to_ip=appliance_dhcp_map,
network_os=env_vars["network_os"],
username=env_vars["device_username"],
password=env_vars["device_password"],
)
logger.debug("Generated multi-host inventory: %s", inventory)

inventory_path = tmp_path / "inventory.json"
with inventory_path.open(mode="w", encoding="utf-8") as fh:
json.dump(inventory, fh)
logger.debug("Inventory written to %s", inventory_path)

playbook_contents = playbook(hosts="all", role=str(integration_test_path))
playbook_path = tmp_path / "site.json"
with playbook_path.open(mode="w", encoding="utf-8") as fh:
json.dump(playbook_contents, fh)
logger.debug("Playbook written to %s", playbook_path)

_print(f"Inventory path: {inventory_path}")
_print(f"Playbook path: {playbook_path}")

project = AnsibleProject(
collection_doc_cache=tmp_path / "collection_doc_cache.db",
directory=tmp_path,
inventory=inventory_path,
log_file=Path.home() / "test_logs" / f"{integration_test_path.name}.log",
playbook=playbook_path,
playbook_artifact=Path.home()
/ "test_logs"
/ "{playbook_status}"
/ f"{integration_test_path.name}.json",
role=integration_test_path.name,
)
logger.info("Ansible multi-host project created successfully")
return project


def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
"""Generate tests based on the integration test paths.

Expand Down
101 changes: 99 additions & 2 deletions src/pytest_ansible_network_integration/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,55 @@ def get_dhcp_lease(self, current_lab_id: str, wait_extra: int) -> str:
logger.info("Done waiting, starting to find IPs")

if len(ips) > 1:
logger.error("Found more than one IP: %s", ips)
logger.error("SSSSSSSSSSSSSSSSS Found more than one IP: %s", ips)
raise PytestNetworkError("Found more than one IP")

logger.info("DHCP lease IP found: %s", ips[0])
return ips[0]

def get_dhcp_leases(self, current_lab_id: str, wait_extra: int) -> Dict[str, str]:
"""Get DHCP leases for all devices in the specified lab.

:param current_lab_id: The current lab ID.
:param wait_extra: Extra seconds to wait before resolving leases.
:raises PytestNetworkError: If no leases can be found.
:return: Mapping of device name to its IP address.
"""
logger.info("Getting all current lab domains from virsh")
domains = self._find_current_lab_domains(current_lab_id, 20)

if wait_extra:
logger.info("Waiting for extra %s seconds before resolving leases", wait_extra)
time.sleep(wait_extra)

device_to_ip: Dict[str, str] = {}
for domain in domains:
try:
device_name = domain["domain"]["name"]
except KeyError as e:
logger.error("Failed to extract device name from domain: %s", e)
raise PytestNetworkError(f"Failed to extract device name: {e}") from e

macs = self._extract_macs(domain)
ips = self._find_dhcp_lease(macs, 200)

if not ips:
logger.error("No IP found for device '%s'", device_name)
raise PytestNetworkError(f"No IP found for device '{device_name}'")

if len(ips) > 1:
logger.warning(
"Multiple IPs found for device '%s' (MACs: %s), choosing first: %s",
device_name,
macs,
ips,
)

device_to_ip[device_name] = ips[0]

logger.info("Resolved DHCP leases for devices: %s", device_to_ip)
return device_to_ip

def _find_current_lab(self, current_lab_id: str, max_attempts: int = 20) -> Dict[str, Any]:
"""Find the current lab by its ID.

Expand Down Expand Up @@ -350,6 +393,60 @@ def _find_current_lab(self, current_lab_id: str, max_attempts: int = 20) -> Dict
logger.error("Could not find current lab after %s attempts", attempt)
raise PytestNetworkError("Could not find current lab")

def _find_current_lab_domains(
self, current_lab_id: str, max_attempts: int = 20
) -> List[Dict[str, Any]]:
"""Find all domains for the current lab by its ID.

Iterates over all virsh domains and collects those whose XML includes the
given lab ID. Retries up to max_attempts times.

:param current_lab_id: The current lab ID.
:param max_attempts: Maximum attempts to discover lab domains.
:raises PytestNetworkError: If no domains are found for the lab.
:return: A list of domain XML dicts.
"""
attempt = 0
while attempt < max_attempts:
logger.info("Attempt %s to find all current lab domains", attempt)
stdout, _stderr = self.ssh.execute("sudo virsh list --all")
logger.debug("virsh list output: %s", stdout)
if _stderr:
logger.error("virsh list stderr: %s", _stderr)

virsh_matches = [re.match(r"^\s(?P<id>\d+)", line) for line in stdout.splitlines()]
if not any(virsh_matches):
logger.error("No matching virsh IDs found in the output")
raise PytestNetworkError("No matching virsh IDs found")

try:
virsh_ids = [
virsh_match.groupdict()["id"] for virsh_match in virsh_matches if virsh_match
]
except KeyError as e:
error_message = f"Failed to extract virsh IDs: {e}"
logger.error(error_message)
raise PytestNetworkError(error_message) from e

matched_domains: List[Dict[str, Any]] = []
for virsh_id in virsh_ids:
stdout, _stderr = self.ssh.execute(f"sudo virsh dumpxml {virsh_id}")
if current_lab_id in stdout:
logger.debug(
"Found lab %s in virsh dumpxml for ID %s", current_lab_id, virsh_id
)
xmltodict_data = xmltodict.parse(stdout)
matched_domains.append(xmltodict_data) # type: ignore

if matched_domains:
return matched_domains

attempt += 1
time.sleep(5)

logger.error("Could not find any domains for current lab after %s attempts", attempt)
raise PytestNetworkError("Could not find any domains for current lab")

def _extract_macs(self, current_lab: Dict[str, Any]) -> List[str]:
"""Extract MAC addresses from the current lab.

Expand Down Expand Up @@ -408,7 +505,7 @@ def _find_dhcp_lease(self, macs: List[str], max_attempts: int = 100) -> List[str
return ips

attempt += 1
time.sleep(10)
time.sleep(400)

logger.error("Could not find IPs after %s attempts", attempt)
raise PytestNetworkError("Could not find IPs")
Expand Down
42 changes: 42 additions & 0 deletions src/pytest_ansible_network_integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Mapping


def _print(message: str) -> None:
Expand Down Expand Up @@ -67,6 +68,47 @@ def _inventory(
return inventory


def _inventory_multi(
host: str,
device_to_ip: Mapping[str, str],
network_os: str,
username: str,
password: str,
) -> Dict[str, Any]:
"""Build an ansible inventory for multiple devices.

:param device_to_ip: Mapping of device name to its management IP
:param network_os: The network OS
:param username: Device username
:param password: Device password
:returns: The inventory for all devices under group 'all'
"""
hosts: Dict[str, Any] = {}

for device_name, ip_address in device_to_ip.items():
ports = calculate_ports(ip_address)
host_key = _sanitize_host_key(device_name)
hosts[host_key] = {
"ansible_become": False,
"ansible_host": host,
"ansible_user": username,
"ansible_password": password,
"ansible_port": ports["ssh_port"],
"ansible_httpapi_port": ports["http_port"],
"ansible_connection": "ansible.netcommon.network_cli",
"ansible_network_cli_ssh_type": "libssh",
"ansible_python_interpreter": "python",
"ansible_network_import_modules": True,
}

return {"all": {"hosts": hosts, "vars": {"ansible_network_os": network_os}}}


def _sanitize_host_key(name: str) -> str:
"""Return a safe inventory host key from an arbitrary device name."""
return "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in name)


def playbook(hosts: str, role: str) -> List[Dict[str, object]]:
"""Return the playbook.

Expand Down