
Commit 9b34e1f

[feat] MMFT MLP head output key; mix model losses and MMFLosses
This will allow multitasking on a single batch easily through different heads, with a different output_key per task. Mixing allows heads that return losses directly to be used together with heads that don't and instead rely on MMFLoss.

Test Plan: Tests have been added.
1 parent 08f062e commit 9b34e1f
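
For context, a minimal sketch of the multitasking pattern this enables. All names below are illustrative stand-ins, not MMF API: two heads share one batch encoding and write predictions under distinct, configurable output keys, so their outputs merge without collisions.

import torch

# Illustrative stand-in for the MLP head pattern in this commit: the output
# key is configurable per head instead of being hard-coded to "scores".
class ToyHead(torch.nn.Module):
    def __init__(self, hidden_size, num_labels, output_key="scores"):
        super().__init__()
        self.classifier = torch.nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels
        self.output_key = output_key

    def forward(self, pooled_output):
        prediction = self.classifier(pooled_output)
        return {self.output_key: prediction.view(-1, self.num_labels)}

pooled = torch.randn(8, 768)                            # one shared batch encoding
head_a = ToyHead(768, 3129, output_key="task_a_scores") # hypothetical task A
head_b = ToyHead(768, 2, output_key="task_b_scores")    # hypothetical task B
model_output = {**head_a(pooled), **head_b(pooled)}     # no key collision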

3 files changed (+69, -13 lines)

mmf/models/base_model.py

Lines changed: 16 additions & 12 deletions
@@ -272,23 +272,27 @@ def __call__(self, sample_list, *args, **kwargs):
             model_output, collections.abc.Mapping
         ), "A dict must be returned from the forward of the model."
 
+        final_output = {"losses": {}}
+        final_output.update(model_output)
+
         if "losses" in model_output:
-            if not self._logged_warning["losses_present"]:
+            assert isinstance(
+                model_output["losses"], collections.abc.Mapping
+            ), "'losses' returned from the model must be a dict."
+
+        if hasattr(self, "losses"):
+            if "losses" in model_output and not self._logged_warning["losses_present"]:
                 warnings.warn(
-                    "'losses' already present in model output. "
-                    "No calculation will be done in base model."
+                    "'losses' already present in model output and 'loss' key "
+                    "was specified. Results from the two will be merged. "
+                    "If this is not expected, either (i) assign unique keys to "
+                    "losses returned from your model (ii) remove 'loss' key from "
+                    "your model output"
                 )
                 self._logged_warning["losses_present"] = True
+            final_output["losses"].update(self.losses(sample_list, model_output))
 
-            assert isinstance(
-                model_output["losses"], collections.abc.Mapping
-            ), "'losses' must be a dict."
-        elif hasattr(self, "losses"):
-            model_output["losses"] = self.losses(sample_list, model_output)
-        else:
-            model_output["losses"] = {}
-
-        return model_output
+        return final_output
 
     def load_requirements(self, *args, **kwargs):
         requirements = self.config.get("zoo_requirements", [])
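
To make the new control flow concrete, here is a hedged, self-contained sketch of the merge semantics, condensed from the hunk above with the warning bookkeeping omitted: losses returned by forward() and losses computed by a losses attribute land in the same dict, and a duplicate key from the attribute would overwrite the model's own entry, which is exactly what the new warning asks callers to avoid.

import collections.abc

def merge_losses(model, model_output):
    # Condensed from the new BaseModel.__call__ above.
    final_output = {"losses": {}}
    final_output.update(model_output)
    if "losses" in model_output:
        assert isinstance(model_output["losses"], collections.abc.Mapping)
    if hasattr(model, "losses"):
        final_output["losses"].update(model.losses(None, model_output))
    return final_output

class LossStub:
    # Stands in for the MMFLoss instance MMF normally builds from config.
    def losses(self, sample_list, model_output):
        return {"ce": 0.3}

out = merge_losses(LossStub(), {"losses": {"contrastive": 0.7}, "scores": [0.1]})
assert out["losses"] == {"contrastive": 0.7, "ce": 0.3}  # both sources merged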

mmf/models/transformers/heads/mlp.py

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@ class Config(BaseTransformerHead.Config):
         hidden_dropout_prob: float = 0.1
         layer_norm_eps: float = 1e-6
         hidden_act: str = "gelu"
+        output_key: str = "scores"
 
     def __init__(self, config: Config, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
@@ -33,6 +34,7 @@ def __init__(self, config: Config, *args, **kwargs):
         )
         self.num_labels = self.config.num_labels
         self.hidden_size = self.config.hidden_size
+        self.output_key = self.config.get("output_key", "scores")
 
     def forward(
         self,
@@ -46,5 +48,5 @@ def forward(
         output_dict = {}
         pooled_output = self.pooler(sequence_output)
         prediction = self.classifier(pooled_output)
-        output_dict["scores"] = prediction.view(-1, self.num_labels)
+        output_dict[self.output_key] = prediction.view(-1, self.num_labels)
         return output_dict

tests/models/test_base_model.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import unittest
+
+import torch
+from mmf.common.sample import SampleList
+from mmf.models.base_model import BaseModel
+from tests.test_utils import compare_tensors
+
+
+class LocalTestModelWithForwardLoss(BaseModel):
+    def forward(self, *args, **kwargs):
+        return {"losses": {"x": torch.tensor(1.0)}}
+
+
+class LocalTestModelWithNoLoss(BaseModel):
+    def forward(self, *args, **kwargs):
+        return {}
+
+
+class LocalTestModelWithLossAttribute(BaseModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.losses = lambda x, y: {"x": torch.tensor(2.0)}
+
+    def forward(self, *args, **kwargs):
+        return {}
+
+
+class TestBaseModel(unittest.TestCase):
+    def test_forward_loss(self):
+        sample_list = SampleList()
+        sample_list.add_field("x", torch.tensor(1))
+        model = LocalTestModelWithForwardLoss({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertTrue(compare_tensors(output["losses"]["x"], torch.tensor(1.0)))
+
+        model = LocalTestModelWithLossAttribute({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertTrue(compare_tensors(output["losses"]["x"], torch.tensor(2.0)))
+
+        model = LocalTestModelWithNoLoss({})
+        with torch.no_grad():
+            output = model(sample_list)
+        self.assertTrue("losses" in output)
+        self.assertEqual(output["losses"], {})
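
The committed tests cover the forward-only, attribute-only, and no-loss paths. A natural follow-up case, sketched here as a hypothetical addition (not part of this commit), would exercise the merge path where both sources contribute distinct keys:

# Hypothetical extra model, following the conventions of the tests above.
class LocalTestModelWithBothLosses(BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.losses = lambda x, y: {"y": torch.tensor(2.0)}

    def forward(self, *args, **kwargs):
        return {"losses": {"x": torch.tensor(1.0)}}

# Inside TestBaseModel.test_forward_loss one could then assert:
#     model = LocalTestModelWithBothLosses({})
#     with torch.no_grad():
#         output = model(sample_list)
#     self.assertTrue(compare_tensors(output["losses"]["x"], torch.tensor(1.0)))
#     self.assertTrue(compare_tensors(output["losses"]["y"], torch.tensor(2.0)))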
