Skip to content

Commit 0c3928a

Browse files
Re-release 2.8.0 (#593)
2 parents fa7fc5e + e7ccc47 commit 0c3928a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+933
-415
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ Released changes are shown in the
1919
### Removed
2020

2121
### Fixed
22+
- Fixed importing the same proposed actions CSV file twice
2223

2324
### Security

azimuth/app.py

+33-31
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def create_app() -> FastAPI:
155155
Returns:
156156
FastAPI.
157157
"""
158-
app = FastAPI(
158+
api = FastAPI(
159159
title="Azimuth API",
160160
description="Azimuth API",
161161
version="1.0",
@@ -171,101 +171,103 @@ def create_app() -> FastAPI:
171171
)
172172

173173
# Setup routes
174-
from azimuth.routers.app import router as app_router
175-
from azimuth.routers.class_overlap import router as class_overlap_router
176-
from azimuth.routers.config import router as config_router
177-
from azimuth.routers.custom_utterances import router as custom_utterances_router
178-
from azimuth.routers.dataset_warnings import router as dataset_warnings_router
179-
from azimuth.routers.export import router as export_router
180-
from azimuth.routers.model_performance.confidence_histogram import (
181-
router as confidence_histogram_router,
174+
from azimuth.routers import (
175+
app,
176+
class_overlap,
177+
config,
178+
custom_utterances,
179+
dataset_warnings,
180+
export,
181+
top_words,
182+
utterances,
183+
)
184+
from azimuth.routers.model_performance import (
185+
confidence_histogram,
186+
confusion_matrix,
187+
metrics,
188+
outcome_count,
189+
utterance_count,
182190
)
183-
from azimuth.routers.model_performance.confusion_matrix import router as confusion_matrix_router
184-
from azimuth.routers.model_performance.metrics import router as metrics_router
185-
from azimuth.routers.model_performance.outcome_count import router as outcome_count_router
186-
from azimuth.routers.model_performance.utterance_count import router as utterance_count_router
187-
from azimuth.routers.top_words import router as top_words_router
188-
from azimuth.routers.utterances import router as utterances_router
189191
from azimuth.utils.routers import require_application_ready, require_available_model
190192

191193
api_router = APIRouter()
192-
api_router.include_router(app_router, prefix="", tags=["App"])
193-
api_router.include_router(config_router, prefix="/config", tags=["Config"])
194+
api_router.include_router(app.router, prefix="", tags=["App"])
195+
api_router.include_router(config.router, prefix="/config", tags=["Config"])
194196
api_router.include_router(
195-
class_overlap_router,
197+
class_overlap.router,
196198
prefix="/dataset_splits/{dataset_split_name}/class_overlap",
197199
tags=["Class Overlap"],
198200
dependencies=[Depends(require_application_ready)],
199201
)
200202
api_router.include_router(
201-
confidence_histogram_router,
203+
confidence_histogram.router,
202204
prefix="/dataset_splits/{dataset_split_name}/confidence_histogram",
203205
tags=["Confidence Histogram"],
204206
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
205207
)
206208
api_router.include_router(
207-
dataset_warnings_router,
209+
dataset_warnings.router,
208210
prefix="/dataset_warnings",
209211
tags=["Dataset Warnings"],
210212
dependencies=[Depends(require_application_ready)],
211213
)
212214
api_router.include_router(
213-
metrics_router,
215+
metrics.router,
214216
prefix="/dataset_splits/{dataset_split_name}/metrics",
215217
tags=["Metrics"],
216218
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
217219
)
218220
api_router.include_router(
219-
outcome_count_router,
221+
outcome_count.router,
220222
prefix="/dataset_splits/{dataset_split_name}/outcome_count",
221223
tags=["Outcome Count"],
222224
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
223225
)
224226
api_router.include_router(
225-
utterance_count_router,
227+
utterance_count.router,
226228
prefix="/dataset_splits/{dataset_split_name}/utterance_count",
227229
tags=["Utterance Count"],
228230
dependencies=[Depends(require_application_ready)],
229231
)
230232
api_router.include_router(
231-
utterances_router,
233+
utterances.router,
232234
prefix="/dataset_splits/{dataset_split_name}/utterances",
233235
tags=["Utterances"],
234236
dependencies=[Depends(require_application_ready)],
235237
)
236238
api_router.include_router(
237-
export_router,
239+
export.router,
238240
prefix="/export",
239241
tags=["Export"],
240242
dependencies=[Depends(require_application_ready)],
241243
)
242244
api_router.include_router(
243-
custom_utterances_router,
245+
custom_utterances.router,
244246
prefix="/custom_utterances",
245247
tags=["Custom Utterances"],
246248
dependencies=[Depends(require_application_ready)],
247249
)
248250
api_router.include_router(
249-
top_words_router,
251+
top_words.router,
250252
prefix="/dataset_splits/{dataset_split_name}/top_words",
251253
tags=["Top Words"],
252254
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
253255
)
254256
api_router.include_router(
255-
confusion_matrix_router,
257+
confusion_matrix.router,
256258
prefix="/dataset_splits/{dataset_split_name}/confusion_matrix",
257259
tags=["Confusion Matrix"],
258260
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
259261
)
260-
app.include_router(api_router)
262+
api.include_router(api_router)
261263

262-
app.add_middleware(
264+
api.add_middleware(
263265
CORSMiddleware,
264266
allow_methods=["*"],
265267
allow_headers=["*"],
266268
)
267269

268-
return app
270+
return api
269271

270272

271273
def load_dataset_split_managers_from_config(

azimuth/config.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,15 @@ def get_project_hash(self):
284284

285285
class ArtifactsConfig(AzimuthBaseSettings, extra=Extra.ignore):
286286
artifact_path: str = Field(
287-
"cache",
287+
default_factory=lambda: os.path.abspath("cache"),
288288
description="Where to store artifacts (Azimuth config history, HDF5 files, HF datasets).",
289289
exclude_from_cache=True,
290290
)
291291

292+
@validator("artifact_path")
293+
def validate_artifact_path(cls, artifact_path):
294+
return os.path.abspath(artifact_path)
295+
292296
def get_config_history_path(self):
293297
return f"{self.artifact_path}/config_history.jsonl"
294298

@@ -329,7 +333,7 @@ class ModelContractConfig(CommonFieldsConfig):
329333
# Uncertainty configuration
330334
uncertainty: UncertaintyOptions = UncertaintyOptions()
331335
# Layer name where to calculate the gradients, normally the word embeddings layer.
332-
saliency_layer: Optional[str] = Field(None, nullable=True)
336+
saliency_layer: Union[Literal["auto"], str, None] = Field("auto", nullable=True)
333337

334338
@validator("pipelines", pre=True)
335339
def _check_pipeline_names(cls, pipeline_definitions):

azimuth/modules/model_contracts/hf_text_classification.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from azimuth.types.task import PredictionResponse, SaliencyResponse
1717
from azimuth.utils.ml.mc_dropout import MCDropout
1818
from azimuth.utils.ml.saliency import (
19+
find_word_embeddings_layer,
1920
get_saliency,
2021
register_embedding_gradient_hook,
2122
register_embedding_list_hook,
@@ -127,6 +128,7 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
127128
raise ValueError("This method should not be called when saliency_layer is not defined.")
128129

129130
hf_pipeline = self.get_model()
131+
hf_model = hf_pipeline.model
130132

131133
inputs = hf_pipeline.tokenizer(
132134
batch[self.config.columns.text_input],
@@ -140,18 +142,19 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
140142
inputs["input_ids"] = inputs["input_ids"].to(hf_pipeline.device)
141143
inputs["attention_mask"] = inputs["attention_mask"].to(hf_pipeline.device)
142144

143-
logits = hf_pipeline.model(**inputs)[0]
145+
logits = hf_model(**inputs)[0]
144146
output = torch.softmax(logits, dim=1).detach().cpu().numpy()
145147
prediction = output.argmax(-1)
146148

147-
embeddings_list: List[np.ndarray] = []
148-
handle = register_embedding_list_hook(
149-
hf_pipeline.model, embeddings_list, self.saliency_layer
149+
embedding_layer = (
150+
hf_model.base_model.get_input_embeddings()
151+
if self.saliency_layer == "auto"
152+
else find_word_embeddings_layer(hf_model, self.saliency_layer)
150153
)
154+
embeddings_list: List[np.ndarray] = []
155+
handle = register_embedding_list_hook(embeddings_list, embedding_layer)
151156
embeddings_gradients: List[np.ndarray] = []
152-
hook = register_embedding_gradient_hook(
153-
hf_pipeline.model, embeddings_gradients, self.saliency_layer
154-
)
157+
hook = register_embedding_gradient_hook(embeddings_gradients, embedding_layer)
155158

156159
filter_class = self.mod_options.filter_class
157160
selected_classes = (
@@ -162,8 +165,8 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
162165
)
163166

164167
# Do backward pass to compute gradients
165-
hf_pipeline.model.zero_grad()
166-
_loss = hf_pipeline.model(**inputs)[0] # loss is at index 0 when passing labels
168+
hf_model.zero_grad()
169+
_loss = hf_model(**inputs)[0] # loss is at index 0 when passing labels
167170
_loss.backward()
168171
handle.remove()
169172
hook.remove()

azimuth/routers/config.py

+40-20
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# Copyright ServiceNow, Inc. 2021 – 2022
22
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
33
# in the root directory of this source tree.
4-
from typing import Any, Dict, List
4+
from typing import Dict, List
55

66
import structlog
77
from fastapi import APIRouter, Body, Depends, HTTPException, Query
8-
from pydantic import ValidationError
9-
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
8+
from starlette.status import (
9+
HTTP_400_BAD_REQUEST,
10+
HTTP_403_FORBIDDEN,
11+
HTTP_500_INTERNAL_SERVER_ERROR,
12+
)
1013

1114
from azimuth.app import (
1215
get_config,
@@ -73,6 +76,25 @@ def get_config_def(
7376
return config
7477

7578

79+
@router.patch(
80+
"/validate",
81+
summary="Validate config",
82+
description="Validate the given partial config update and return the complete config that would"
83+
" result if this update was applied.",
84+
response_model=AzimuthConfig,
85+
dependencies=[Depends(require_editable_config)],
86+
)
87+
def validate_config(
88+
config: AzimuthConfig = Depends(get_config),
89+
partial_config: Dict = Body(...),
90+
) -> AzimuthConfig:
91+
new_config = update_config(old_config=config, partial_config=partial_config)
92+
93+
assert_permission_to_update_config(old_config=config, new_config=new_config)
94+
95+
return new_config
96+
97+
7698
@router.patch(
7799
"",
78100
summary="Update config",
@@ -85,19 +107,17 @@ def patch_config(
85107
config: AzimuthConfig = Depends(get_config),
86108
partial_config: Dict = Body(...),
87109
) -> AzimuthConfig:
88-
if attribute_changed_in_config("artifact_path", partial_config, config):
89-
raise HTTPException(
90-
HTTP_400_BAD_REQUEST,
91-
detail="Cannot edit artifact_path, otherwise config history would become inconsistent.",
92-
)
110+
log.info(f"Validating config change with {partial_config}.")
111+
new_config = update_config(old_config=config, partial_config=partial_config)
112+
113+
assert_permission_to_update_config(old_config=config, new_config=new_config)
114+
115+
if new_config.large_dask_cluster != config.large_dask_cluster:
116+
cluster = default_cluster(new_config.large_dask_cluster)
117+
else:
118+
cluster = task_manager.cluster
93119

94120
try:
95-
log.info(f"Validating config change with {partial_config}.")
96-
new_config = update_config(old_config=config, partial_config=partial_config)
97-
if attribute_changed_in_config("large_dask_cluster", partial_config, config):
98-
cluster = default_cluster(partial_config["large_dask_cluster"])
99-
else:
100-
cluster = task_manager.cluster
101121
run_startup_tasks(new_config, cluster)
102122
log.info(f"Config successfully updated with {partial_config}.")
103123
except Exception as e:
@@ -107,8 +127,6 @@ def patch_config(
107127
log.info("Config update cancelled.")
108128
if isinstance(e, AzimuthValidationError):
109129
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
110-
if isinstance(e, ValidationError):
111-
raise
112130
else:
113131
raise HTTPException(
114132
HTTP_500_INTERNAL_SERVER_ERROR, detail="Error when loading the new config."
@@ -117,7 +135,9 @@ def patch_config(
117135
return new_config
118136

119137

120-
def attribute_changed_in_config(
121-
attribute: str, partial_config: Dict[str, Any], config: AzimuthConfig
122-
) -> bool:
123-
return attribute in partial_config and partial_config[attribute] != getattr(config, attribute)
138+
def assert_permission_to_update_config(*, old_config: AzimuthConfig, new_config: AzimuthConfig):
139+
if old_config.artifact_path != new_config.artifact_path:
140+
raise HTTPException(
141+
HTTP_403_FORBIDDEN,
142+
detail="Cannot edit artifact_path, otherwise config history would become inconsistent.",
143+
)

azimuth/utils/ml/saliency.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import structlog
8+
from torch.nn import Embedding
89

910
from azimuth.types.general.module_arguments import GradientCalculation
1011

@@ -36,14 +37,13 @@ def find_word_embeddings_layer(model: Any, layer_name: str) -> Any:
3637

3738

3839
def register_embedding_list_hook(
39-
model: Any, embeddings_list: List[np.ndarray], layer_name: str
40+
embeddings_list: List[np.ndarray], embedding_layer: Embedding
4041
) -> Any:
4142
"""Register hook to get the embedding values from model.
4243
4344
Args:
44-
model: Model.
4545
embeddings_list: Variable to save values.
46-
layer_name: Name of the embedding layer.
46+
embedding_layer: Embedding layer on which to compute the saliency map.
4747
4848
Returns:
4949
Hook.
@@ -52,21 +52,17 @@ def register_embedding_list_hook(
5252
def forward_hook(module, inputs, output):
5353
embeddings_list.append(output.detach().cpu().clone().numpy())
5454

55-
embedding_layer = find_word_embeddings_layer(model, layer_name)
56-
handle = embedding_layer.register_forward_hook(forward_hook)
57-
58-
return handle
55+
return embedding_layer.register_forward_hook(forward_hook)
5956

6057

6158
def register_embedding_gradient_hook(
62-
model: Any, embeddings_gradients: List[np.ndarray], layer_name: str
59+
embeddings_gradients: List[np.ndarray], embedding_layer: Embedding
6360
) -> Any:
6461
"""Register hook to get the gradient values from the embedding layer.
6562
6663
Args:
67-
model: Model.
6864
embeddings_gradients: Variable to save values.
69-
layer_name: Name of the embedding layer.
65+
embedding_layer: Embedding layer on which to compute the saliency map.
7066
7167
Returns:
7268
Hook.
@@ -76,10 +72,7 @@ def register_embedding_gradient_hook(
7672
def hook_layers(module, grad_in, grad_out):
7773
embeddings_gradients.append(grad_out[0].detach().cpu().clone().numpy())
7874

79-
embedding_layer = find_word_embeddings_layer(model, layer_name)
80-
hook = embedding_layer.register_full_backward_hook(hook_layers)
81-
82-
return hook
75+
return embedding_layer.register_full_backward_hook(hook_layers)
8376

8477

8578
def get_saliency(

config/development/clinc/conf.json

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
}
2121
},
2222
"batch_size": 64,
23-
"saliency_layer": "distilbert.embeddings.word_embeddings",
2423
"model_contract": "hf_text_classification",
2524
"rejection_class": "NO_INTENT"
2625
}

0 commit comments

Comments
 (0)