diff --git a/pyproject.toml b/pyproject.toml index 5997432..eb47592 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,11 +28,8 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "datasets", "docstring-parser", "einops", - "numpy<2.0.0", - "seaborn", "tabulate", "torch>=2.2", ] @@ -46,6 +43,12 @@ dynamic = ["version"] [project.optional-dependencies] dev = ["check-manifest"] test = ["pytest"] +analysis = [ + "datasets", + "matplotlib", + "pandas", + "seaborn", +] [tool.setuptools] packages = ["unit_scaling", "unit_scaling.core", "unit_scaling.transforms"] @@ -55,3 +58,7 @@ version = {attr = "unit_scaling._version.__version__"} [tool.setuptools_scm] version_file = "unit_scaling/_version.py" + +[tool.isort] +profile = "black" +extend_skip = ["unit_scaling/_version.py"] diff --git a/requirements-dev.txt b/requirements-dev.txt index e8946f6..3318a98 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ datasets==3.1.0 docstring-parser==0.16 einops==0.8.0 -numpy==1.26.4 +numpy==2.2.6 seaborn==0.13.2 tabulate==0.9.0 torch==2.5.1+cpu diff --git a/unit_scaling/__init__.py b/unit_scaling/__init__.py index 42357ae..cab0a46 100644 --- a/unit_scaling/__init__.py +++ b/unit_scaling/__init__.py @@ -26,7 +26,6 @@ TransformerLayer, ) from ._version import __version__ -from .analysis import visualiser from .core.functional import transformer_residual_scaling_rule from .parameter import MupType, Parameter @@ -58,6 +57,5 @@ # Functions "Parameter", "transformer_residual_scaling_rule", - "visualiser", "__version__", ] diff --git a/unit_scaling/analysis.py b/unit_scaling/analysis.py index 1b47001..6ddb092 100644 --- a/unit_scaling/analysis.py +++ b/unit_scaling/analysis.py @@ -8,15 +8,21 @@ from math import isnan from typing import TYPE_CHECKING, Any, List, Optional, Tuple -import matplotlib -import matplotlib.colors -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns # type: ignore[import-untyped] -from datasets import load_dataset # type: ignore[import-untyped] -from torch import Tensor, nn -from torch.fx.graph import Graph -from torch.fx.node import Node +try: + import matplotlib + import matplotlib.colors + import matplotlib.pyplot as plt + import pandas as pd + import seaborn as sns # type: ignore[import-untyped] + from datasets import load_dataset # type: ignore[import-untyped] + from torch import Tensor, nn + from torch.fx.graph import Graph + from torch.fx.node import Node +except ImportError as e: + raise ImportError( + "Optional dependencies for `unit_scaling.analysis` are missing." + " Please install `unit-scaling[analysis]`" + ) from e from ._internal_utils import generate__all__ from .transforms import (