Skip to content

Commit

Permalink
Support log_images for aim tracker (#2257)
Browse files Browse the repository at this point in the history
* support `log_images` for aim tracker

* fix the potential kwargs issue for aim tracker's `log_images`

* remove ambiguous import statement

* use `aim` directly to avoid potential conflict
  • Loading branch information
Justin900429 authored Dec 15, 2023
1 parent 6b2d968 commit 0606784
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,38 @@ def log(self, values: dict, step: Optional[int], **kwargs):
for key, value in values.items():
self.writer.track(value, name=key, step=step, **kwargs)

@on_main_process
def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[Dict[str, dict]] = None):
"""
Logs `images` to the current run.
Args:
values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`):
Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a
tuple is provided, the first element should be the image and the second element should be the caption.
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
kwargs (`Dict[str, dict]`):
Additional key word arguments passed along to the `Run.Image` and `Run.track` method specified by the
keys `aim_image` and `track`, respectively.
"""
import aim

aim_image_kw = {}
track_kw = {}

if kwargs is not None:
aim_image_kw = kwargs.get("aim_image", {})
track_kw = kwargs.get("track", {})

for key, value in values.items():
if isinstance(value, tuple):
img, caption = value
else:
img, caption = value, ""
aim_image = aim.Image(img, caption=caption, **aim_image_kw)
self.writer.track(aim_image, name=key, step=step, **track_kw)

@on_main_process
def finish(self):
"""
Expand Down Expand Up @@ -936,7 +968,8 @@ def finish(self):


def filter_trackers(
log_with: List[Union[str, LoggerType, GeneralTracker]], logging_dir: Union[str, os.PathLike] = None
log_with: List[Union[str, LoggerType, GeneralTracker]],
logging_dir: Union[str, os.PathLike] = None,
):
"""
Takes in a list of potential tracker types and checks that:
Expand Down

0 comments on commit 0606784

Please sign in to comment.