Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,19 @@ uv sync
- basic: `uv run python src/train.py --n-ncas 3 --epochs 1000 --device cpu`
- wandb logging: `uv run python src/train.py --n-ncas 3 --epochs 10000 --device cuda --wandb`
- run with config: `uv run python src/train.py --config configs/example.json`
- live viz training: `uv run python src/train.py --n-ncas 3 --epochs 1000 --device cpu --live-viz`
- 3D example: `uv run python src/train.py --n-ncas 3 --epochs 200 --device cpu --live-viz --viz-slice-axis depth`
- 3D grid size via CLI: `uv run python src/train.py --n-ncas 3 --epochs 200 --grid-size 10 10 10`
- 3D view with grid override: `uv run python src/visualize_trained.py --model-path <run_dir> --grid-size 10 10 10`
- 3D plotly view with mouse controls + playback: `uv run python src/visualize_trained.py --model-path <run_dir> --plotly`
- Plotly playback speed: `uv run python src/visualize_trained.py --model-path <run_dir> --plotly --plotly-speed 0.5`

3D visualization options (viz-only) include:
`--viz-slice-axis`, `--viz-slice-stride`, `--viz-slice-spacing`, `--viz-slice-alpha`, `--viz-max-slices`
for training, and `--slice-axis`, `--slice-stride`, `--slice-spacing`, `--slice-alpha`, `--max-slices`
for `visualize_trained.py`.

## Configs

For additional configurations, you can load a JSON config file. Any parameters not specified in the config file will be set their default value in `src/config.py`
For additional configurations, you can load a JSON config file. Any parameters not specified in the config file will be set to their default value in `src/config.py`.
`grid_size` is `(D, H, W)`; 2D configs `(H, W)` are still accepted and promoted to `(1, H, W)`.
8 changes: 7 additions & 1 deletion configs/example.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"grid_size": [
1,
50,
50
],
Expand All @@ -26,11 +27,16 @@
"epochs": 1000,
"log_every": 10,
"wandb": false,
"viz_slice_axis": "depth",
"viz_slice_stride": 1,
"viz_slice_spacing": 1.2,
"viz_slice_alpha": 0.9,
"viz_max_slices": null,
"sun_update_epoch_wait": 0,
"steps_before_update": 0,
"steps_per_update": 1,
"device": "mps",
"seed": 42,
"mode": "train",
"run_name": "debug"
}
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"nbstripout>=0.8.1",
"numpy>=2.3.3",
"pillow>=11.3.0",
"plotly>=6.5.2",
"torch>=2.8.0",
"wandb>=0.21.3",
]
27 changes: 23 additions & 4 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class Config:
validation, and seed management for reproducible experiments.
"""

# Grid
grid_size: tuple[int, int] = (10, 10)
# Grid (D, H, W); 2D configs (H, W) are promoted to (1, H, W)
grid_size: tuple[int, int] | tuple[int, int, int] = (10, 10)
n_seeds: int = 1

# World
Expand Down Expand Up @@ -53,6 +53,14 @@ class Config:
epochs: int = 1_000
log_every: int = 100
wandb: bool = False
live_viz: bool = False

# Visualization (viz-only; does not affect training)
viz_slice_axis: Literal["depth", "height", "width"] = "depth"
viz_slice_stride: int = 1
viz_slice_spacing: float = 1.2
viz_slice_alpha: float = 0.9
viz_max_slices: int | None = None

# Sun
sun_update_epoch_wait: int = 0
Expand All @@ -75,12 +83,23 @@ def __post_init__(self) -> None:
Raises:
AssertionError: If cell_state_dim is not even or batch_size > pool_size.
"""
if isinstance(self.grid_size, list):
object.__setattr__(self, "grid_size", tuple(self.grid_size))
if len(self.grid_size) == 2:
object.__setattr__(self, "grid_size", (1, *self.grid_size))
assert len(self.grid_size) == 3, "[config] grid_size must be (D, H, W)"

assert self.cell_state_dim % 2 == 0, "[config] cell_state_dim must be even"
assert self.batch_size <= self.pool_size, "[config] batch_size > pool_size"
assert self.n_seeds * self.n_ncas <= self.total_grid_size, (
"[config] n_seeds * n_ncas > self.total_grid_size"
)
assert self.softmax_temp > 0, "[config] softmax_temp <= 0"
assert self.viz_slice_stride > 0, "[config] viz_slice_stride must be > 0"
assert self.viz_slice_spacing > 0, "[config] viz_slice_spacing must be > 0"
assert 0.0 <= self.viz_slice_alpha <= 1.0, "[config] viz_slice_alpha must be [0, 1]"
if self.viz_max_slices is not None:
assert self.viz_max_slices > 0, "[config] viz_max_slices must be > 0"

# Device availability check
if self.device == "cuda" and not torch.cuda.is_available():
Expand Down Expand Up @@ -142,9 +161,9 @@ def total_grid_size(self) -> int:
"""Total number of cells in the grid.

Returns:
Product of grid dimensions (width * height).
Product of grid dimensions (depth * height * width).
"""
return self.grid_size[0] * self.grid_size[1]
return self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

@classmethod
def from_file(cls, path: str) -> "Config":
Expand Down
90 changes: 45 additions & 45 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self.DC = dropout_chance

self.encode = nn.Sequential(
nn.Conv2d(
nn.Conv3d(
self.C,
self.N * self.HD,
self.KS,
Expand All @@ -82,15 +82,15 @@ def __init__(
)
self.reasoning = nn.Sequential(*[self.mid_conv_block() for _ in range(self.NH)])
self.compression = nn.Sequential(
nn.Conv2d(self.N * self.HD, self.N * self.OC, 1, groups=self.N, bias=False),
nn.Conv3d(self.N * self.HD, self.N * self.OC, 1, groups=self.N, bias=False),
nn.Tanh(),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the merged CA model.

Args:
x: Input tensor [B, C, H, W].
x: Input tensor [B, C, D, H, W].

Returns:
Updated tensor of same shape as input.
Expand All @@ -104,7 +104,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def mid_conv_block(self):
return Skipper(
nn.Sequential(
nn.Conv2d(
nn.Conv3d(
self.N * self.HD,
self.N * self.HD,
self.KS,
Expand Down Expand Up @@ -251,7 +251,7 @@ def __init__(self, config: Config) -> None:
self.threshold = torch.tensor(0.4, device=config.device)
self.perspective_mask = torch.eye(self.N, device=self.device)
self.perspective_mask = rearrange(
self.perspective_mask, "n1 n2 -> n1 n2 () () () ()"
self.perspective_mask, "n1 n2 -> n1 n2 () () () () ()"
)

self.per_hid_upd = config.per_hid_upd
Expand All @@ -275,7 +275,7 @@ def _setup_sun_update(self, config: Config) -> None:
sun_vec[self.hidden_idxs - self.N] = 0.0
sun_vec /= sun_vec.norm()

self.sun_update = rearrange(sun_vec, "oc -> () oc () ()")
self.sun_update = rearrange(sun_vec, "oc -> () oc () () ()")
self.sun_update.requires_grad = True

self.sun_optim = torch.optim.AdamW([self.sun_update], lr=config.learning_rate)
Expand All @@ -289,15 +289,15 @@ def _parallel_forward_step(self, x_perspectives: torch.Tensor) -> torch.Tensor:
"""Single forward step for all perspectives in parallel.

Args:
x_perspectives: Individual perspectives [N, B, C, H, W]
x_perspectives: Individual perspectives [N, B, C, D, H, W]
where each should only have gradients for NCA ni.

Returns:
Updated perspectives tensor of same shape.
"""
N, B, C, H, W = x_perspectives.shape
N, B, C, D, H, W = x_perspectives.shape

x_flat = rearrange(x_perspectives, "n b c h w -> (n b) c h w")
x_flat = rearrange(x_perspectives, "n b c d h w -> (n b) c d h w")
if not self.alive_visible:
x_flat = x_flat[:, self.cell_idxs]

Expand All @@ -315,18 +315,18 @@ def _parallel_forward_step(self, x_perspectives: torch.Tensor) -> torch.Tensor:
)
vis_grid[:, : self.cell_state_dim] = 1
vis_grid[xs, ys + self.cell_state_dim] = 1
vis_grid = rearrange(vis_grid, "n c -> () n () c () ()")
vis_grid = rearrange(vis_grid, "n c -> () n () c () () ()")

all_updates = self.models(x_flat) # [N*B, OC*N, H, W]
all_updates = self.models(x_flat) # [N*B, OC*N, D, H, W]
all_updates = rearrange(
all_updates,
"(n b) (oc m) h w -> n m b oc h w",
"(n b) (oc m) d h w -> n m b oc d h w",
n=self.N,
m=self.n_ncas,
)
all_updates = all_updates * vis_grid

sun_update = self.sun_update.expand(self.N, 1, B, self.out_dim, H, W)
sun_update = self.sun_update.expand(self.N, 1, B, self.out_dim, D, H, W)
all_updates = torch.cat([sun_update, all_updates], dim=1) # [N, M, B, OC, H, W]

all_updates = all_updates * self.perspective_mask + all_updates.detach() * (
Expand All @@ -348,23 +348,23 @@ def _run_competition_parallel(
"""Fully parallel competition across all perspectives.

Args:
x_perspectives: Input perspectives [N, B, C, H, W].
x_perspectives: Input perspectives [N, B, C, D, H, W].
all_updates: All proposed updates [N, M, B, OC, D, H, W] where M=N(n_ncas)+1.
NOTE: In this new one where the sun gets updated as well, N and M are both self.N

Returns:
Updated perspectives after competition resolution.
"""
N, B, C, H, W = x_perspectives.shape
N, B, C, D, H, W = x_perspectives.shape

# Since all perspectives are the same, can you just get it for one and that's good enough?
alive_mask_flat = self._get_nca_alive_mask(x_perspectives[0]) # [M, B, H, W]
alive_mask = repeat(alive_mask_flat, "m b h w -> n m b h w", n=N)
alive_mask_flat = self._get_nca_alive_mask(x_perspectives[0]) # [M, B, D, H, W]
alive_mask = repeat(alive_mask_flat, "m b d h w -> n m b d h w", n=N)

all_updates = all_updates * rearrange(alive_mask, "n m b h w -> n m b 1 h w")
all_updates = all_updates * rearrange(alive_mask, "n m b d h w -> n m b 1 d h w")

all_attacks = all_updates[:, :, :, self.att_idxs] # [N, M, B, C_att, H, W]
all_defenses = all_updates[:, :, :, self.def_idxs] # [N, M, B, C_def, H, W]
all_attacks = all_updates[:, :, :, self.att_idxs] # [N, M, B, C_att, D, H, W]
all_defenses = all_updates[:, :, :, self.def_idxs] # [N, M, B, C_def, D, H, W]

att_alive = alive_mask_flat[self.interactions[:, 0]]
def_alive = alive_mask_flat[self.interactions[:, 1]]
Expand All @@ -378,24 +378,24 @@ def _run_competition_parallel(
# cos_sim = F.cosine_similarity(attacks, defenses, dim=3) # [N, I, B, H, W]
cos_sim = F.cosine_similarity(
all_attacks[:, att_idx], all_defenses[:, def_idx], dim=3
) # [N, I, B, H, W]
) # [N, I, B, D, H, W]
defense_cos_sim = F.cosine_similarity(
all_attacks[:, def_idx], all_defenses[:, att_idx], dim=3
) # [N, I, B, H, W]
) # [N, I, B, D, H, W]

if self.mode == "eval":
cos_sim_reduced = reduce(cos_sim.detach(), "n i b h w -> i", "mean")
cos_sim_reduced = reduce(cos_sim.detach(), "n i b d h w -> i", "mean")
self.inter = torch.zeros(self.N, self.N, device=self.device)
self.inter[
self.interactions[touching_ncas, 0], self.interactions[touching_ncas, 1]
] = cos_sim_reduced

strengths = einsum(
cos_sim, self.str_add_idx[touching_ncas], "n i b h w, i m -> n m b h w"
cos_sim, self.str_add_idx[touching_ncas], "n i b d h w, i m -> n m b d h w"
) - einsum(
defense_cos_sim,
self.str_add_idx[touching_ncas],
"n i b h w, i m -> n m b h w",
"n i b d h w, i m -> n m b d h w",
)

# Apply alive mask to this!
Expand All @@ -404,15 +404,15 @@ def _run_competition_parallel(

x_new = torch.zeros_like(x_perspectives)
x_new[:, :, self.cell_idxs] = x_perspectives[:, :, self.cell_idxs] + einsum(
all_updates, strengths, "n m b c h w, n m b h w -> n b c h w"
all_updates, strengths, "n m b c d h w, n m b d h w -> n b c d h w"
)

x_new[:, :, self.ali_idxs] = rearrange(strengths, "n m b h w -> n b m h w").to(
x_new[:, :, self.ali_idxs] = rearrange(strengths, "n m b d h w -> n b m d h w").to(
x_new.dtype
)

x_new[:, :, self.ali_idxs] = torch.where(
rearrange(alive_mask, "n m b h w -> n b m h w"),
rearrange(alive_mask, "n m b d h w -> n b m d h w"),
x_new[:, :, self.ali_idxs],
-torch.inf,
)
Expand All @@ -421,16 +421,16 @@ def _run_competition_parallel(
x_new[:, :, self.ali_idxs] / self.softmax_temp, dim=2
).to(x_new.dtype)

alive_mask_flat = self._get_nca_alive_mask(x_new[0]) # [M, B, H, W]
alive_mask = repeat(alive_mask_flat, "m b h w -> n b m h w", n=N)
alive_mask_flat = self._get_nca_alive_mask(x_new[0]) # [M, B, D, H, W]
alive_mask = repeat(alive_mask_flat, "m b d h w -> n b m d h w", n=N)

# Kill off anything not alive enough
# You need to accomplish some baseline before being able to stay at some cell
x_new[:, :, self.ali_idxs] = x_new[:, :, self.ali_idxs] * alive_mask

# Distribute the remaining aliveness so that it sums to 1
x_new[:, :, self.ali_idxs] = x_new[:, :, self.ali_idxs] / (
reduce(x_new[:, :, self.ali_idxs], "n b c h w -> n b 1 h w", "sum")
reduce(x_new[:, :, self.ali_idxs], "n b c d h w -> n b 1 d h w", "sum")
).to(x_new.dtype)
return x_new

Expand All @@ -440,16 +440,16 @@ def __call__(
"""Run multiple forward steps while maintaining gradient isolation.

Args:
x: Input tensor [B, C, H, W].
x: Input tensor [B, C, D, H, W].
steps: Number of steps to run.

Returns:
Tuple containing:
- x_perspectives: Perspective grids [N, B, C, H, W]
- x_merged: Merged grid [B, C, H, W]
- x_perspectives: Perspective grids [N, B, C, D, H, W]
- x_merged: Merged grid [B, C, D, H, W]
- inter: Interaction statistics (currently None)
"""
x_perspectives = repeat(x, "b c h w -> n b c h w", n=self.N).clone()
x_perspectives = repeat(x, "b c d h w -> n b c d h w", n=self.N).clone()

all_xs = torch.zeros((steps, *x.shape), device=x.device, dtype=x.dtype)

Expand All @@ -466,19 +466,19 @@ def _get_nca_alive_mask(self, x_perspectives: torch.Tensor) -> torch.Tensor:
cell or its 3x3x3 neighborhood (using max pooling).

Args:
x_perspectives: Input tensor [N*B, C, H, W].
x_perspectives: Input tensor [N*B, C, D, H, W].

Returns:
Boolean mask [M, N*B, H, W] where M=N+1, indicating alive cells.
Boolean mask [M, N*B, D, H, W] where M=N+1, indicating alive cells.
"""

NB, C, H, W = x_perspectives.shape
NB, C, D, H, W = x_perspectives.shape

alive_channels = x_perspectives[:, self.ali_idxs] # [NB, M, H, W]
alive_flat = rearrange(alive_channels, "nb m h w -> (nb m) 1 h w")
alive_pooled = F.max_pool2d(alive_flat, 3, stride=1, padding=1)
alive_channels = x_perspectives[:, self.ali_idxs] # [NB, M, D, H, W]
alive_flat = rearrange(alive_channels, "nb m d h w -> (nb m) 1 d h w")
alive_pooled = F.max_pool3d(alive_flat, 3, stride=1, padding=1)
alive_mask = (
rearrange(alive_pooled, "(nb m) 1 h w -> m nb h w", m=self.N)
rearrange(alive_pooled, "(nb m) 1 d h w -> m nb d h w", m=self.N)
> self.threshold
)

Expand All @@ -496,15 +496,15 @@ def update_models(
gradients, and updates model parameters.

Args:
x_perspectives: Perspective grids [N, B, C, H, W]. (includes sun perspective)
x_perspectives: Perspective grids [N, B, C, D, H, W]. (includes sun perspective)

Returns:
Dictionary containing training statistics:
- growth: List of growth percentages for sun and each NCA
- grad_norms: List of gradient norms for monitoring
"""

M, B, C, H, W = x_perspectives.shape
M, B, C, D, H, W = x_perspectives.shape
N = M - 1 # Number of NCAs (excluding sun)

m_idxs = torch.arange(M, device=self.device)
Expand Down Expand Up @@ -542,7 +542,7 @@ def update_models(
# ------------------------------------

# Go down into batch
batch_alive = alivenesses.view(M, B, -1).sum(-1) # [N, B]
batch_alive = alivenesses.view(M, B, -1).sum(-1) # [M, B]

log_growth = torch.asinh(batch_alive + 1e-3).mean(1) # [N]
ind_losses = -log_growth
Expand Down
Loading