This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Weight decay #265

Open · wants to merge 3 commits into main
2 changes: 2 additions & 0 deletions test/test_functional.py
@@ -351,6 +351,8 @@ def test_default(self):
edge_paths=[], # filled in later
checkpoint_path=self.checkpoint_path.name,
workers=2,
wd=0.01,
wd_interval=2,
)
dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
self.addCleanup(dataset.cleanup)
17 changes: 17 additions & 0 deletions torchbiggraph/config.py
@@ -228,6 +228,17 @@ class ConfigSchema(Schema):
regularizer: str = attr.ib(
default="N3", metadata={"help": "Type of regularization to be applied."}
)

wd: float = attr.ib(
default=0,
validator=non_negative,
metadata={"help": "Simple (unweighted) weight decay"},
)
wd_interval: int = attr.ib(
default=100,
validator=non_negative,
metadata={"help": "Interval to amortize weight decay"},
)

# data config

@@ -385,6 +396,12 @@ class ConfigSchema(Schema):
"after each training step."
},
)
early_stopping: bool = attr.ib(
default=False,
metadata={
"help": "Stop training when validation loss increases."
}
)

# expert options

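Note: the three new options are ordinary config keys. Below is a minimal sketch of how they might appear in a PyTorch-BigGraph config file; only the key names `wd`, `wd_interval`, and `early_stopping` come from this PR, while every other key and value is an illustrative assumption:

```python
# Hypothetical config module; all keys other than wd, wd_interval and
# early_stopping are placeholders, not part of this change.
def get_torchbiggraph_config():
    return dict(
        # ... entity, relation and path settings as usual ...
        dimension=200,
        num_epochs=10,
        wd=0.01,          # unweighted L2 weight decay coefficient
        wd_interval=100,  # add the full decay term roughly once every 100 batches
        early_stopping=True,
    )
```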
14 changes: 14 additions & 0 deletions torchbiggraph/model.py
@@ -404,6 +404,8 @@ def __init__(
],
comparator: AbstractComparator,
regularizer: AbstractRegularizer,
wd: float,
wd_interval: int,
global_emb: bool = False,
max_norm: Optional[float] = None,
num_dynamic_rels: int = 0,
@@ -444,6 +446,8 @@ def __init__(
self.max_norm: Optional[float] = max_norm
self.half_precision = half_precision
self.regularizer: Optional[AbstractRegularizer] = regularizer
self.wd = wd
self.wd_interval = wd_interval

def set_embeddings(self, entity: str, side: Side, weights: nn.Parameter) -> None:
if self.entities[entity].featurized:
@@ -762,6 +766,14 @@ def forward(self, edges: EdgeList) -> Scores:
reg,
)


def l2_norm(self):
ret = 0
for e in set(self.lhs_embs.values()) | set(self.rhs_embs.values()):
ret += e.weight.pow(2).sum()
return ret


def forward_direction_agnostic(
self,
src: EntityList,
@@ -921,6 +933,8 @@ def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
rhs_operators=rhs_operators,
comparator=comparator,
regularizer=regularizer,
wd=config.wd,
wd_interval=config.wd_interval,
global_emb=config.global_emb,
max_norm=config.max_norm,
num_dynamic_rels=num_dynamic_rels,
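Note: `l2_norm()` sums the squared weights over the set union of the left- and right-hand-side embedding modules, so a table that serves both sides is penalized only once. A standalone toy sketch of the same dedup-and-sum pattern (entity names and table sizes are invented, not from this PR):

```python
import torch.nn as nn

lhs = {"user": nn.Embedding(4, 2), "item": nn.Embedding(4, 2)}
rhs = {"item": lhs["item"]}  # the same module object is used on both sides

modules = set(lhs.values()) | set(rhs.values())  # the shared table appears once
l2 = sum(m.weight.pow(2).sum() for m in modules)
print(float(l2))  # squared norm over two distinct tables, not three lookups
```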
99 changes: 59 additions & 40 deletions torchbiggraph/train_cpu.py
@@ -9,6 +9,7 @@
import logging
import math
import time
import random
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
@@ -86,7 +87,13 @@ def __init__(
def _process_one_batch(
self, model: MultiRelationEmbedder, batch_edges: EdgeList
) -> Stats:
model.zero_grad()
# Tricky: this is basically like calling `model.zero_grad()`, except that
# `zero_grad` calls `p.grad.zero_()`. When we perform the infrequent
# global L2 regularization, the backward pass makes the embedding gradients
# dense, and they can never go back to being sparse unless we reset them
# to `None` again here.
for p in model.parameters():
p.grad = None

scores, reg = model(batch_edges)

@@ -100,9 +107,10 @@ def _process_one_batch(
count=len(batch_edges),
)
if reg is not None:
(loss + reg).backward()
else:
loss.backward()
loss = loss + reg
if model.wd > 0 and random.random() < 1. / model.wd_interval:
loss = loss + model.wd * model.wd_interval * model.l2_norm()
loss.backward()
self.model_optimizer.step(closure=None)
for optimizer in self.unpartitioned_optimizers.values():
optimizer.step(closure=None)
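Note: two things happen in this hunk. First, gradients are reset to `None` instead of zeroed in place: the global L2 term produces a dense gradient for the embedding tables, and an in-place `zero_()` would keep that dense layout on later batches, defeating sparse gradients. A minimal standalone sketch of the effect (the table size and coefficient are arbitrary, not from this PR):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(1000, 16, sparse=True)

# A loss over a few looked-up rows yields a sparse gradient.
emb(torch.tensor([1, 2, 3])).pow(2).sum().backward()
print(emb.weight.grad.is_sparse)  # True

# A global L2 penalty touches every row, so its gradient is dense.
emb.weight.grad = None
(0.01 * emb.weight.pow(2).sum()).backward()
print(emb.weight.grad.is_sparse)  # False

# Resetting grad to None (rather than calling grad.zero_()) lets the next
# sparse backward allocate a fresh sparse gradient tensor.
```

Second, the decay is amortized: with probability `1 / wd_interval` the batch loss gains `wd * wd_interval * l2_norm()`, which in expectation equals adding `wd * l2_norm()` on every batch while paying for the full-table pass only occasionally.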
@@ -570,6 +578,7 @@ def train(self) -> None:
eval_stats_chunk_avg,
)

last_chunk_loss = float("inf")
for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
logger.info(
f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
@@ -721,10 +730,17 @@ def train(self) -> None:

current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1

self._maybe_write_checkpoint(
all_stats_dicts = self._maybe_write_checkpoint(
epoch_idx, edge_path_idx, edge_chunk_idx, current_index
)

if config.early_stopping:
assert iteration_manager.num_edge_paths == 1
chunk_loss = all_stats_dicts[-1]["eval_stats_chunk_avg"]["metrics"]["loss"]
if chunk_loss > last_chunk_loss:
break
last_chunk_loss = chunk_loss

# now we're sure that all partition files exist,
# so be strict about loading them
self.strict = True
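Note: the early-stopping check reads the chunk-averaged evaluation loss out of the stats that `_maybe_write_checkpoint` now returns and stops at the first chunk whose loss exceeds the previous one; the assert restricts it to a single edge path so consecutive chunks are comparable. A tiny illustration of the stop rule with an invented loss sequence:

```python
last_chunk_loss = float("inf")
for chunk_loss in [0.92, 0.71, 0.64, 0.66, 0.63]:  # made-up per-chunk eval losses
    if chunk_loss > last_chunk_loss:
        break  # 0.66 > 0.64: stop before later chunks are trained
    last_chunk_loss = chunk_loss
print(last_chunk_loss)  # 0.64
```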
@@ -914,7 +930,7 @@ def _maybe_write_checkpoint(
edge_path_idx: int,
edge_chunk_idx: int,
current_index: int,
) -> None:
) -> List[Dict[str, Any]]:

config = self.config

@@ -955,42 +971,43 @@
state_dict, self.trainer.model_optimizer.state_dict()
)

logger.info("Writing the training stats")
all_stats_dicts: List[Dict[str, Any]] = []
bucket_eval_stats_list = []
chunk_stats_dict = {
"epoch_idx": epoch_idx,
"edge_path_idx": edge_path_idx,
"edge_chunk_idx": edge_chunk_idx,
all_stats_dicts: List[Dict[str, Any]] = []
bucket_eval_stats_list = []
chunk_stats_dict = {
"epoch_idx": epoch_idx,
"edge_path_idx": edge_path_idx,
"edge_chunk_idx": edge_chunk_idx,
}
for stats in self.bucket_scheduler.get_stats_for_pass():
stats_dict = {
"lhs_partition": stats.lhs_partition,
"rhs_partition": stats.rhs_partition,
"index": stats.index,
"stats": stats.train.to_dict(),
}
if stats.eval_before is not None:
stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
bucket_eval_stats_list.append(stats.eval_before)

if stats.eval_after is not None:
stats_dict["eval_stats_after"] = stats.eval_after.to_dict()

stats_dict.update(chunk_stats_dict)
all_stats_dicts.append(stats_dict)

if len(bucket_eval_stats_list) != 0:
eval_stats_chunk_avg = Stats.average_list(bucket_eval_stats_list)
self.stats_handler.on_stats(
index=current_index, eval_stats_chunk_avg=eval_stats_chunk_avg
)
chunk_stats_dict["index"] = current_index
chunk_stats_dict[
"eval_stats_chunk_avg"
] = eval_stats_chunk_avg.to_dict()
all_stats_dicts.append(chunk_stats_dict)
if stats.eval_before is not None:
stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
bucket_eval_stats_list.append(stats.eval_before)

if stats.eval_after is not None:
stats_dict["eval_stats_after"] = stats.eval_after.to_dict()

stats_dict.update(chunk_stats_dict)
all_stats_dicts.append(stats_dict)

if len(bucket_eval_stats_list) != 0:
eval_stats_chunk_avg = Stats.average_list(bucket_eval_stats_list)
chunk_stats_dict["index"] = current_index
chunk_stats_dict[
"eval_stats_chunk_avg"
] = eval_stats_chunk_avg.to_dict()
all_stats_dicts.append(chunk_stats_dict)

if self.rank == 0:
logger.info("Writing the training stats")
self.stats_handler.on_stats(
index=current_index, eval_stats_chunk_avg=eval_stats_chunk_avg
)
self.checkpoint_manager.append_stats(all_stats_dicts)

logger.info("Writing the checkpoint")
@@ -1021,3 +1038,5 @@ def _maybe_write_checkpoint(
self.checkpoint_manager.preserve_current_version(config, epoch_idx + 1)
if not preserve_old_checkpoint:
self.checkpoint_manager.remove_old_version(config)

return all_stats_dicts