diff --git a/mmf/common/report.py b/mmf/common/report.py index 179fb6b3c..28526cf11 100644 --- a/mmf/common/report.py +++ b/mmf/common/report.py @@ -98,7 +98,7 @@ def apply_fn(self, fn: Callable, fields: Optional[List[str]] = None): if key not in fields: continue self[key] = fn(self[key]) - if isinstance(self[key], collections.MutableSequence): + if isinstance(self[key], collections.abc.MutableSequence): for idx, item in enumerate(self[key]): self[key][idx] = fn(item) elif isinstance(self[key], dict): diff --git a/mmf/common/sample.py b/mmf/common/sample.py index bafffa460..7b8f96966 100644 --- a/mmf/common/sample.py +++ b/mmf/common/sample.py @@ -424,7 +424,7 @@ def convert_batch_to_sample_list( def to_device( sample_list: Union[SampleList, Dict[str, Any]], device: device_type = "cuda" ) -> SampleList: - if isinstance(sample_list, collections.Mapping): + if isinstance(sample_list, collections.abc.Mapping): sample_list = convert_batch_to_sample_list(sample_list) # to_device is specifically for SampleList # if user is passing something custom built diff --git a/mmf/models/transformers/heads/utils.py b/mmf/models/transformers/heads/utils.py index 20ba23dce..314ad4162 100644 --- a/mmf/models/transformers/heads/utils.py +++ b/mmf/models/transformers/heads/utils.py @@ -147,10 +147,10 @@ def _process_head_output( head_name: str, sample_list: Dict[str, Tensor], ) -> Dict[str, Tensor]: - if isinstance(outputs, collections.MutableMapping) and "losses" in outputs: + if isinstance(outputs, collections.abc.MutableMapping) and "losses" in outputs: return outputs - if isinstance(outputs, collections.MutableMapping) and "scores" in outputs: + if isinstance(outputs, collections.abc.MutableMapping) and "scores" in outputs: logits = outputs["scores"] else: logits = outputs diff --git a/mmf/utils/logger.py b/mmf/utils/logger.py index 9fe03a508..2153f4c4a 100644 --- a/mmf/utils/logger.py +++ b/mmf/utils/logger.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import collections +import collections.abc import functools import json import logging @@ -288,7 +288,7 @@ def log_progress(info: Union[Dict, Any], log_format="simple"): caller, key = _find_caller() logger = logging.getLogger(caller) - if not isinstance(info, collections.Mapping): + if not isinstance(info, collections.abc.Mapping): logger.info(info) if log_format == "simple":