Skip to content

Commit 7c12901

Browse files
authored
Merge pull request #25 from GNS-Science/feature/24_gmm_logic_tree_classes
Feature/24 gmm logic tree classes
2 parents ad97850 + 7b65d8d commit 7c12901

8 files changed

+599
-209
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Changelog
22

33

4-
## [0.2.0] - 2024-05-30
4+
## [0.2.0] - 2024-06-07
55
### Changed
66
- Complete reset, no more django
77
- all previous code is mothballed
@@ -13,6 +13,7 @@
1313
- get_model resolver
1414
- get_models resolver
1515
- source logic tree models and resolvers
16+
- gmmm logic tree models and resolvers
1617

1718
## [0.1.3] - 2023-09-04
1819
### Added
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Define graphene model for nzshm_model gmm logic tree classes."""
2+
3+
import json
4+
import logging
5+
from functools import lru_cache
6+
7+
import graphene
8+
from graphene import relay
9+
10+
from .nshm_model_sources_schema import get_model_by_version
11+
12+
log = logging.getLogger(__name__)
13+
14+
15+
# TODO: this method belongs on the nzshm-model gmcm class
16+
@lru_cache
17+
def get_branch_set(model_version, short_name):
18+
glt = get_model_by_version(model_version).gmm_logic_tree
19+
log.debug(f"glt {glt}")
20+
for bs in glt.branch_sets:
21+
if bs.short_name == short_name:
22+
return bs
23+
assert 0, f"branch set {short_name} was not found" # pragma: no cover
24+
25+
26+
# TODO: this method belongs on the nzshm-model gmcm class
27+
@lru_cache
28+
def get_logic_tree_branch(model_version, branch_set_short_name, gsim_name, gsim_args):
29+
log.info(
30+
f"get_logic_tree_branch: {branch_set_short_name} gsim_name: {gsim_name} gsim_args: {gsim_args}"
31+
)
32+
branch_set = get_branch_set(model_version, branch_set_short_name)
33+
for ltb in branch_set.branches:
34+
if (ltb.gsim_name == gsim_name) and (ltb.gsim_args == json.loads(gsim_args)):
35+
return ltb
36+
assert (
37+
0
38+
), f"branch with gsim_name: {gsim_name} gsim_args: {gsim_args} was not found" # pragma: no cover
39+
40+
41+
class GmmLogicTreeBranch(graphene.ObjectType):
42+
class Meta:
43+
interfaces = (relay.Node,)
44+
45+
model_version = graphene.String()
46+
branch_set_short_name = graphene.String()
47+
gsim_name = graphene.String()
48+
gsim_args = graphene.JSONString()
49+
tectonic_region_type = graphene.String() # should be an enum
50+
weight = graphene.Float()
51+
52+
def resolve_id(self, info):
53+
return f"{self.model_version}|{self.branch_set_short_name}|{self.gsim_name}|{json.dumps(self.gsim_args)}"
54+
55+
@classmethod
56+
def get_node(cls, info, node_id: str):
57+
model_version, branch_set_short_name, gsim_name, gsim_args = node_id.split("|")
58+
gltb = get_logic_tree_branch(
59+
model_version, branch_set_short_name, gsim_name, gsim_args
60+
)
61+
return GmmLogicTreeBranch(
62+
model_version=model_version,
63+
branch_set_short_name=branch_set_short_name,
64+
tectonic_region_type=gltb.tectonic_region_type,
65+
gsim_name=gltb.gsim_name,
66+
gsim_args=gltb.gsim_args,
67+
weight=gltb.weight,
68+
)
69+
70+
71+
class GmmBranchSet(graphene.ObjectType):
72+
"""Ground Motion Model branch sets,
73+
74+
to ensure that the wieghts of the enclosed branches sum to 1.0
75+
"""
76+
77+
class Meta:
78+
interfaces = (relay.Node,)
79+
80+
model_version = graphene.String()
81+
short_name = graphene.String()
82+
long_name = graphene.String()
83+
tectonic_region_type = graphene.String()
84+
branches = graphene.List(GmmLogicTreeBranch)
85+
86+
def resolve_id(self, info):
87+
return f"{self.model_version}:{self.short_name}"
88+
89+
@classmethod
90+
def get_node(cls, info, node_id: str):
91+
model_version, short_name = node_id.split(":")
92+
bs = get_branch_set(model_version, short_name)
93+
return GmmBranchSet(
94+
model_version=model_version,
95+
tectonic_region_type=bs.tectonic_region_type,
96+
short_name=bs.short_name,
97+
long_name=bs.long_name,
98+
)
99+
100+
@staticmethod
101+
def resolve_branches(root, info, **kwargs):
102+
log.info(f"resolve_branches root: {root} kwargs: {kwargs}")
103+
bs = get_branch_set(root.model_version, root.short_name)
104+
for ltb in bs.branches:
105+
log.debug(ltb)
106+
yield GmmLogicTreeBranch(
107+
model_version=root.model_version,
108+
branch_set_short_name=root.short_name,
109+
tectonic_region_type=ltb.tectonic_region_type,
110+
weight=ltb.weight,
111+
gsim_name=ltb.gsim_name,
112+
gsim_args=ltb.gsim_args,
113+
)
114+
115+
116+
class GroundMotionModelLogicTree(graphene.ObjectType):
117+
"""A custom Node representing the GMM logic tree of a given model."""
118+
119+
class Meta:
120+
interfaces = (relay.Node,)
121+
122+
model_version = graphene.String()
123+
branch_sets = graphene.List(GmmBranchSet)
124+
125+
def resolve_id(self, info):
126+
return self.model_version
127+
128+
@classmethod
129+
def get_node(cls, info, model_version: str):
130+
return GroundMotionModelLogicTree(model_version=model_version)
131+
132+
@staticmethod
133+
def resolve_branch_sets(root, info, **kwargs):
134+
log.info(f"resolve_branch_sets root: {root} kwargs: {kwargs}")
135+
glt = get_model_by_version(root.model_version).gmm_logic_tree
136+
for bs in glt.branch_sets:
137+
yield GmmBranchSet(
138+
model_version=root.model_version,
139+
short_name=bs.short_name,
140+
long_name=bs.long_name,
141+
tectonic_region_type=bs.tectonic_region_type,
142+
)

nshm_model_graphql_api/schema/nshm_model_schema.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import nzshm_model as nm
88
from graphene import relay
99

10+
from .nshm_model_gmms_schema import GroundMotionModelLogicTree
1011
from .nshm_model_sources_schema import SourceLogicTree
1112

1213
log = logging.getLogger(__name__)
@@ -21,16 +22,20 @@ class Meta:
2122
version = graphene.String()
2223
title = graphene.String()
2324
source_logic_tree = graphene.Field(SourceLogicTree)
25+
gmm_logic_tree = graphene.Field(GroundMotionModelLogicTree)
2426

2527
def resolve_id(self, info):
2628
return self.version
2729

2830
@staticmethod
2931
def resolve_source_logic_tree(root, info, **kwargs):
3032
log.info(f"resolve_source_logic_tree root: {root} kwargs: {kwargs}")
31-
return SourceLogicTree(
32-
model_version=root.version
33-
) # , branch_sets=get_branch_sets(slt))
33+
return SourceLogicTree(model_version=root.version)
34+
35+
@staticmethod
36+
def resolve_gmm_logic_tree(root, info, **kwargs):
37+
log.info(f"resolve_gmm_logic_tree root: {root} kwargs: {kwargs}")
38+
return GroundMotionModelLogicTree(model_version=root.version)
3439

3540
@classmethod
3641
def get_node(cls, info, version: str):

poetry.lock

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import pytest
2+
from graphene.test import Client
3+
from graphql_relay import to_global_id
4+
5+
from nshm_model_graphql_api import schema
6+
7+
8+
@pytest.fixture(scope="module")
9+
def client():
10+
return Client(schema.schema_root)
11+
12+
13+
@pytest.mark.parametrize(
14+
"model_version",
15+
["NSHM_v1.0.0", "NSHM_v1.0.4"],
16+
)
17+
def test_get_model_SourceLogicTree_as_node(client, model_version):
18+
QUERY = """
19+
query {
20+
node(id: "%s")
21+
{
22+
... on Node {
23+
id
24+
}
25+
... on GroundMotionModelLogicTree {
26+
model_version
27+
}
28+
}
29+
}
30+
""" % to_global_id(
31+
"GroundMotionModelLogicTree", model_version
32+
)
33+
print(QUERY)
34+
executed = client.execute(QUERY)
35+
print(executed)
36+
assert executed["data"]["node"]["model_version"] == model_version
37+
assert executed["data"]["node"]["id"] == to_global_id(
38+
"GroundMotionModelLogicTree", model_version
39+
)
40+
41+
42+
@pytest.mark.parametrize(
43+
"model_version, short_name, long_name",
44+
[
45+
("NSHM_v1.0.0", "CRU", "Crustal"),
46+
("NSHM_v1.0.0", "SLAB", "Subduction Intraslab"),
47+
("NSHM_v1.0.4", "CRU", "Crustal"),
48+
("NSHM_v1.0.4", "INTER", "Subduction Interface"),
49+
],
50+
)
51+
def test_get_model_GmmBranchSet_as_node(client, model_version, short_name, long_name):
52+
QUERY = """
53+
query {
54+
node(id: "%s")
55+
{
56+
... on Node {
57+
id
58+
}
59+
... on GmmBranchSet {
60+
model_version
61+
short_name
62+
long_name
63+
tectonic_region_type
64+
}
65+
66+
}
67+
}
68+
""" % to_global_id(
69+
"GmmBranchSet", f"{model_version}:{short_name}"
70+
)
71+
executed = client.execute(QUERY)
72+
print(executed)
73+
assert executed["data"]["node"]["model_version"] == model_version
74+
assert executed["data"]["node"]["short_name"] == short_name
75+
assert executed["data"]["node"]["long_name"] == long_name
76+
assert executed["data"]["node"]["id"] == to_global_id(
77+
"GmmBranchSet", f"{model_version}:{short_name}"
78+
)
79+
80+
81+
@pytest.mark.parametrize(
82+
"model_version, branch_set_short_name, gsim_name, gsim_args, weight",
83+
[
84+
(
85+
"NSHM_v1.0.0",
86+
"CRU",
87+
"Stafford2022",
88+
'{"mu_branch": "Upper"}',
89+
0.117,
90+
),
91+
(
92+
"NSHM_v1.0.4",
93+
"INTER",
94+
"Atkinson2022SInter",
95+
'{"epistemic": "Lower", "modified_sigma": "true"}',
96+
0.081,
97+
),
98+
],
99+
)
100+
def test_get_model_GmmLogicTreeBranch_as_node(
101+
client, model_version, branch_set_short_name, gsim_name, gsim_args, weight
102+
):
103+
QUERY = """
104+
query {
105+
node(id: "%s")
106+
{
107+
... on Node {
108+
id
109+
}
110+
... on GmmLogicTreeBranch {
111+
model_version
112+
branch_set_short_name
113+
gsim_name
114+
gsim_args
115+
weight
116+
}
117+
118+
}
119+
}
120+
""" % to_global_id(
121+
"GmmLogicTreeBranch",
122+
f"{model_version}|{branch_set_short_name}|{gsim_name}|{gsim_args}",
123+
)
124+
executed = client.execute(QUERY)
125+
print(executed)
126+
assert executed["data"]["node"]["id"] == to_global_id(
127+
"GmmLogicTreeBranch",
128+
f"{model_version}|{branch_set_short_name}|{gsim_name}|{gsim_args}",
129+
)
130+
131+
assert executed["data"]["node"]["model_version"] == model_version
132+
assert executed["data"]["node"]["branch_set_short_name"] == branch_set_short_name
133+
assert executed["data"]["node"]["gsim_name"] == gsim_name
134+
assert executed["data"]["node"]["weight"] == weight

0 commit comments

Comments
 (0)