Skip to content

Commit

Permalink
chg: return nested galaxy entities, optimize insertion. fixes #45
Browse files Browse the repository at this point in the history
  • Loading branch information
righel committed Dec 23, 2024
1 parent 00dea14 commit 0682fce
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 44 deletions.
8 changes: 8 additions & 0 deletions api/app/models/galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ class GalaxyCluster(Base):
extends_version = Column(Integer, nullable=True)
published = Column(Boolean, nullable=False, default=False)
deleted = Column(Boolean, nullable=False, default=False)
elements = relationship("GalaxyElement", lazy="subquery")
relations = relationship(
"GalaxyClusterRelation",
lazy="subquery",
foreign_keys="[GalaxyClusterRelation.galaxy_cluster_id]",
)


class GalaxyElement(Base):
Expand Down Expand Up @@ -108,6 +114,7 @@ class GalaxyClusterRelation(Base):
Integer, ForeignKey("sharing_groups.id"), index=True, nullable=True
)
default = Column(Boolean, nullable=False, default=False)
tags = relationship("GalaxyClusterRelationTag", lazy="subquery")


class GalaxyClusterRelationTag(Base):
Expand All @@ -117,3 +124,4 @@ class GalaxyClusterRelationTag(Base):
Integer, ForeignKey("galaxy_cluster_relations.id"), index=True, nullable=False
)
tag_id = Column(Integer, ForeignKey("tags.id"), index=True, nullable=False)
tag = relationship("Tag", lazy="subquery")
85 changes: 58 additions & 27 deletions api/app/repositories/galaxies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import logging
import os
from datetime import datetime
from uuid import UUID

from app.models import event as events_models
from app.models import galaxy as galaxies_models
Expand All @@ -11,6 +13,8 @@
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


def get_galaxies(db: Session, filter: str = Query(None)) -> galaxies_models.Galaxy:
query = db.query(galaxies_models.Galaxy)
Expand Down Expand Up @@ -70,9 +74,6 @@ def update_galaxies(
created=datetime.now(),
modified=datetime.now(),
)
db.add(galaxy)
db.commit()
db.refresh(galaxy)

# parse galaxy clusters file
with open(os.path.join(galaxies_clusters_dir, galaxy_file)) as f:
Expand All @@ -81,7 +82,6 @@ def update_galaxies(
if "values" in clusters_data_raw:
for cluster in clusters_data_raw["values"]:
galaxy_cluster = galaxies_models.GalaxyCluster(
galaxy_id=galaxy.id,
uuid=cluster["uuid"],
value=cluster["value"],
type=(
Expand Down Expand Up @@ -124,25 +124,20 @@ def update_galaxies(
else None
),
)
db.add(galaxy_cluster)
db.flush()
galaxy.clusters.append(galaxy_cluster)

# add galaxy elements
if "meta" in cluster:
for element in cluster["meta"]:
galaxy_element = galaxies_models.GalaxyElement(
galaxy_cluster_id=galaxy_cluster.id,
key=element,
value=(
cluster["meta"][element]
if isinstance(cluster["meta"][element], str)
else json.dumps(cluster["meta"][element])
),
)
db.add(galaxy_element)

# commit galaxy elements
db.commit()
galaxy_cluster.elements.append(galaxy_element)

# add galaxy relations
if "related" in cluster:
Expand All @@ -153,10 +148,20 @@ def update_galaxies(
"dest-uuid" not in relation
or not relation["dest-uuid"]
):
logger.warning(
f"Missing dest-uuid {relation['dest-uuid']} for galaxy {galaxy.name}"
)
continue

try:
UUID(relation["dest-uuid"])
except ValueError:
logger.warning(
f"Invalid dest-uuid {relation['dest-uuid']} for galaxy {galaxy.name}"
)
continue

galaxy_relation = galaxies_models.GalaxyClusterRelation(
galaxy_cluster_id=galaxy_cluster.id,
galaxy_cluster_uuid=cluster["uuid"],
referenced_galaxy_cluster_uuid=relation[
"dest-uuid"
Expand All @@ -165,30 +170,49 @@ def update_galaxies(
default=True,
distribution=events_models.DistributionLevel.ALL_COMMUNITIES,
)
db.add(galaxy_relation)
db.flush()

if "tags" in relation:
for tag in relation["tags"]:
for related_tag in relation["tags"]:
tag = tags_repository.get_tag_by_name(
db, tag_name=tag
db, tag_name=related_tag
)

if tag:
galaxy_relation_tag = galaxies_models.GalaxyClusterRelationTag(
galaxy_cluster_relation_id=galaxy_relation.id,
tag=tag,
if not tag:
logger.warning(
f"Tag {related_tag} not found for galaxy {galaxy.name}"
)
tag = tags_repository.create_tag(
db,
tag=tags_repository.tag_schemas.TagCreate(
name=related_tag,
colour="#000000",
exportable=False,
org_id=user.org_id,
user_id=user.id,
hide_tag=False,
is_galaxy=False,
is_custom_galaxy=False,
local_only=False,
),
)

db.add(galaxy_relation_tag)

# commit galaxy relations and tags
db.commit()
galaxy_relation_tag = galaxies_models.GalaxyClusterRelationTag(
tag=tag,
)

# commit galaxy clusters
db.commit()
galaxy_relation.tags.append(
galaxy_relation_tag
)

galaxies.append(galaxy)
galaxy_cluster.relations.append(galaxy_relation)
try:
db.add(galaxy)
db.commit()
db.refresh(galaxy)
galaxies.append(galaxy)
except Exception as e:
logger.error(f"Error creating galaxy {galaxy.name}: {e}")
db.rollback()

# fix galaxy cluster relations references to galaxy clusters
relations = db.query(galaxies_models.GalaxyClusterRelation).all()
Expand All @@ -201,6 +225,13 @@ def update_galaxies(
)
.first()
)

if not galaxy_cluster:
logger.warning(
f"Galaxy cluster {relation.referenced_galaxy_cluster_uuid} not found"
)
continue

relation.referenced_galaxy_cluster_id = galaxy_cluster.id
db.add(relation)
db.commit()
Expand Down
85 changes: 68 additions & 17 deletions api/app/schemas/galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,73 @@
from pydantic import BaseModel, ConfigDict


class GalaxyClusterRelationTagBase(BaseModel):
galaxy_cluster_relation_id: int
tag_id: int


class GalaxyClusterRelationTag(GalaxyClusterRelationTagBase):
galaxy_cluster_relation_id: int
tag_id: int
model_config = ConfigDict(from_attributes=True)


class GalaxyClusterRelationBase(BaseModel):
galaxy_cluster_id: int
referenced_galaxy_cluster_id: Optional[int] = None
referenced_galaxy_cluster_uuid: Optional[UUID] = None
referenced_galaxy_cluster_type: str
galaxy_cluster_uuid: Optional[UUID] = None
distribution: DistributionLevel
sharing_group_id: Optional[int] = None
default: bool


class GalaxyClusterRelation(GalaxyClusterRelationBase):
id: int
tags: list[GalaxyClusterRelationTag] = []
model_config = ConfigDict(from_attributes=True)


class GalaxyElementBase(BaseModel):
key: str
value: str
galaxy_cluster_id: int


class GalaxyElement(GalaxyElementBase):
id: int
model_config = ConfigDict(from_attributes=True)


class GalaxyClusterBase(BaseModel):
uuid: UUID
collection_uuid: Optional[UUID] = None
type: str
value: str
tag_name: str
description: str
galaxy_id: int
source: str
authors: Optional[list] = []
version: Optional[int] = None
distribution: DistributionLevel
sharing_group_id: Optional[int] = None
org_id: int
orgc_id: int
extends_uuid: Optional[UUID] = None
extends_version: Optional[int] = None
published: bool
deleted: bool


class GalaxyCluster(GalaxyClusterBase):
id: int
relations: list[GalaxyClusterRelation] = []
elements: list[GalaxyElement] = []
model_config = ConfigDict(from_attributes=True)


class GalaxyBase(BaseModel):
uuid: Optional[UUID] = None
name: str
Expand All @@ -27,27 +94,11 @@ class GalaxyBase(BaseModel):

class Galaxy(GalaxyBase):
id: int
clusters: list[GalaxyCluster] = []
model_config = ConfigDict(from_attributes=True)


class GalaxyUpdate(BaseModel):
default: Optional[bool] = None
enabled: Optional[bool] = None
local_only: Optional[bool] = None


class GalaxyClusterBase(BaseModel):
name: str
description: str
version: int
icon: str
namespace: str
enabled: bool
local_only: bool
kill_chain_order: Optional[dict] = {}
default: bool
org_id: int
orgc_id: int
created: datetime
modified: datetime
distribution: DistributionLevel

0 comments on commit 0682fce

Please sign in to comment.