Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"datasets",
"docstring-parser",
"einops",
"numpy<2.0.0",
"seaborn",
"tabulate",
"torch>=2.2",
]
Expand All @@ -46,6 +43,12 @@ dynamic = ["version"]
[project.optional-dependencies]
dev = ["check-manifest"]
test = ["pytest"]
analysis = [
"datasets",
"matplotlib",
"pandas",
"seaborn",
]

[tool.setuptools]
packages = ["unit_scaling", "unit_scaling.core", "unit_scaling.transforms"]
Expand All @@ -55,3 +58,7 @@ version = {attr = "unit_scaling._version.__version__"}

[tool.setuptools_scm]
version_file = "unit_scaling/_version.py"

[tool.isort]
profile = "black"
extend_skip = ["unit_scaling/_version.py"]
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
datasets==3.1.0
docstring-parser==0.16
einops==0.8.0
numpy==1.26.4
numpy==2.2.6
seaborn==0.13.2
tabulate==0.9.0
torch==2.5.1+cpu
Expand Down
2 changes: 0 additions & 2 deletions unit_scaling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
TransformerLayer,
)
from ._version import __version__
from .analysis import visualiser
from .core.functional import transformer_residual_scaling_rule
from .parameter import MupType, Parameter

Expand Down Expand Up @@ -58,6 +57,5 @@
# Functions
"Parameter",
"transformer_residual_scaling_rule",
"visualiser",
"__version__",
]
24 changes: 15 additions & 9 deletions unit_scaling/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@
from math import isnan
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

import matplotlib
import matplotlib.colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns # type: ignore[import-untyped]
from datasets import load_dataset # type: ignore[import-untyped]
from torch import Tensor, nn
from torch.fx.graph import Graph
from torch.fx.node import Node
try:
import matplotlib
import matplotlib.colors
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns # type: ignore[import-untyped]
from datasets import load_dataset # type: ignore[import-untyped]
from torch import Tensor, nn
from torch.fx.graph import Graph
from torch.fx.node import Node
except ImportError as e:
raise ImportError(
"Optional dependencies for `unit_scaling.analysis` are missing."
" Please install `unit-scaling[analysis]`"
) from e

from ._internal_utils import generate__all__
from .transforms import (
Expand Down