
Commit 203a358

mamtsing authored and quic-mamta committed
FT logger
Signed-off-by: Mamta Singh <[email protected]>
1 parent 1ec0070 commit 203a358

5 files changed, +61 −102 lines changed


QEfficient/cloud/finetune.py

Lines changed: 4 additions & 5 deletions
@@ -110,7 +110,7 @@ def load_model_and_tokenizer(
         - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
         - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
     """
-    logger.log_rank_zero(f"loading HuggingFace model for {train_config.model_name}")
+    logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
     pretrained_model_path = hf_download(train_config.model_name)
     if train_config.task_type == "seq_classification":
         model = AutoModelForSequenceClassification.from_pretrained(
@@ -149,8 +149,7 @@ def load_model_and_tokenizer(
         logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logger.WARNING)
         model.resize_token_embeddings(len(tokenizer))
 
-    # FIXME (Meet): Cover below line inside the logger once it is implemented.
-    print_model_size(model, train_config)
+    print_model_size(model)
 
     # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
     # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
@@ -301,7 +300,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     if train_config.enable_ddp:
         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-    _ = train(
+    results = train(
         model,
         tokenizer,
         train_dataloader,
@@ -313,7 +312,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     )
     if train_config.enable_ddp:
         dist.destroy_process_group()
-    return
+    return results
 
 
 if __name__ == "__main__":
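With this change main() propagates the value returned by train() instead of discarding it. A minimal caller sketch, assuming extra kwargs passed to main() override fields of the training config; the kwarg name and model are illustrative, and the structure of the returned results is not shown in this diff:

# Hypothetical caller; kwarg name and model are illustrative assumptions, not part of this commit.
from QEfficient.cloud.finetune import main

results = main(model_name="meta-llama/Llama-3.2-1B")  # assumed: extra kwargs feed the train config
print(results)  # whatever train() returns; its structure is defined in train_utils.train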

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
                 delimiter=",",
             )
         except Exception as e:
-            logger.error(
+            logger.raise_runtimeerror(
                 "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
             )
             raise e

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 2 additions & 1 deletion
@@ -11,6 +11,7 @@
 
 from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATALOADER_COLLATE_FUNC, DATASET_PREPROC
+from QEfficient.finetune.utils.logging_utils import logger
 
 
 def get_preprocessed_dataset(
@@ -72,7 +73,7 @@ def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"
         print("custom_data_collator is used")
         dl_kwargs["collate_fn"] = custom_data_collator
 
-    print(f"length of dataset_{split}", len(dataset))
+    logger.log_rank_zero(f"Length of {split} dataset is {len(dataset)}")
 
     # Create data loader
     dataloader = torch.utils.data.DataLoader(

QEfficient/finetune/utils/logging_utils.py

Lines changed: 41 additions & 76 deletions
@@ -14,79 +14,44 @@
 from QEfficient.utils.constants import ROOT_DIR
 
 
-class QEffFormatter(logging.Formatter):
-    """
-    Formatter class used to set colors for printing different logging levels of messages on console.
-    """
-
-    cyan: str = "\x1b[38;5;14m"
-    yellow: str = "\x1b[33;20m"
-    red: str = "\x1b[31;20m"
-    bold_red: str = "\x1b[31;1m"
-    reset: str = "\x1b[0m"
-    common_format: str = "%(levelname)s - %(name)s - %(message)s"  # type: ignore
-    format_with_line_info = "%(levelname)s - %(name)s - %(message)s (%(filename)s:%(lineno)d)"  # type: ignore
-
-    FORMATS = {
-        logging.DEBUG: cyan + format_with_line_info + reset,
-        logging.INFO: cyan + common_format + reset,
-        logging.WARNING: yellow + common_format + reset,
-        logging.ERROR: red + format_with_line_info + reset,
-        logging.CRITICAL: bold_red + format_with_line_info + reset,
-    }
-
-    def format(self, record):
-        """
-        Overriding the base class method to Choose format based on log level.
-        """
-        log_fmt = self.FORMATS.get(record.levelno)
-        formatter = logging.Formatter(log_fmt)
-        return formatter.format(record)
-
-
-def create_logger() -> logging.Logger:
-    """
-    Creates a logger object with Colored QEffFormatter.
-    """
-    logger = logging.getLogger("QEfficient")
-
-    # create console handler and set level
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-    ch.setFormatter(QEffFormatter())
-    logger.addHandler(ch)
-
-    return logger
-
-
-class CustomLogger(logging.Logger):
-    def raise_runtimeerror(self, message):
-        self.error(message)
-        raise RuntimeError(message)
-
-    def log_rank_zero(self, msg: str, level: int = logging.INFO) -> None:
-        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
-        if rank != 0:
-            return
-        self.log(level, msg, stacklevel=2)
-
-    def prepare_dump_logs(self, dump_logs=False):
-        if dump_logs:
-            logs_path = os.path.join(ROOT_DIR, "logs")
-            if not os.path.exists(logs_path):
-                os.makedirs(logs_path, exist_ok=True)
-            file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
-            log_file = os.path.join(logs_path, file_name)
-
-            # create file handler and set level
-            fh = logging.FileHandler(log_file)
-            fh.setLevel(logging.INFO)
-            formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s")
-            fh.setFormatter(formatter)
-            logger.addHandler(fh)
-
-
-logging.setLoggerClass(CustomLogger)
-
-# Define the logger object that can be used for logging purposes throughout the module.
-logger = create_logger()
+class FTLogger:
+    def __init__(self, level=logging.DEBUG):
+        self.logger = logging.getLogger("QEfficient")
+        if not getattr(self.logger, "_custom_methods_added", False):
+            self._bind_custom_methods()
+            self.logger._custom_methods_added = True  # Prevent adding handlers/methods twice
+
+    def _bind_custom_methods(self):
+        def raise_runtimeerror(message):
+            self.logger.error(message)
+            raise RuntimeError(message)
+
+        def log_rank_zero(msg: str, level: int = logging.INFO):
+            rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+            if rank != 0:
+                return
+            self.logger.log(level, msg, stacklevel=2)
+
+        def prepare_dump_logs(dump_logs=False, level=logging.INFO):
+            if dump_logs:
+                logs_path = os.path.join(ROOT_DIR, "logs")
+                if not os.path.exists(logs_path):
+                    os.makedirs(logs_path, exist_ok=True)
+                file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
+                log_file = os.path.join(logs_path, file_name)
+
+                fh = logging.FileHandler(log_file)
+                fh.setLevel(level)
+                formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s")
+                fh.setFormatter(formatter)
+                self.logger.addHandler(fh)
+
+        self.logger.raise_runtimeerror = raise_runtimeerror
+        self.logger.log_rank_zero = log_rank_zero
+        self.logger.prepare_dump_logs = prepare_dump_logs
+
+    def get_logger(self):
+        return self.logger
+
+
+logger = FTLogger().get_logger()
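A minimal usage sketch of the rebound logger, using only the method names and signatures visible in the diff above; the messages and the basicConfig call are illustrative, since this commit does not attach a console handler:

import logging

from QEfficient.finetune.utils.logging_utils import logger

logging.basicConfig(level=logging.INFO)  # console output for this sketch only; not part of the commit
logger.prepare_dump_logs(dump_logs=True)  # also mirror logs to <ROOT_DIR>/logs/log-file-<timestamp>.txt
logger.log_rank_zero("Starting finetuning run")  # emitted only on rank 0 when torch.distributed is initialized
logger.log_rank_zero("Gradient checkpointing is disabled", logging.WARNING)
logger.raise_runtimeerror("Dataset could not be loaded")  # logs at ERROR level, then raises RuntimeError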

QEfficient/finetune/utils/train_utils.py

Lines changed: 13 additions & 19 deletions
@@ -84,7 +84,7 @@ def train(
     max_steps_reached = False  # Flag to indicate max training steps reached
 
     tensorboard_updates = None
-    if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
+    if (not train_config.enable_ddp) or (local_rank == 0):
         tensorboard_updates = SummaryWriter()
 
     device_type = torch.device(device).type
@@ -215,7 +215,7 @@ def train(
             else:
                 loss_0_counter = torch.tensor([0]).to(device)
 
-            if (not train_config.enable_ddp) or (train_config.enable_ddp and local_rank == 0):
+            if (not train_config.enable_ddp) or (local_rank == 0):
                 tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
 
             if train_config.save_metrics:
@@ -300,18 +300,10 @@ def train(
         lr_scheduler.step()
 
         if train_config.run_validation:
-            if train_config.enable_ddp:
-                dist.barrier()
-                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                    model, train_config, eval_dataloader, device
-                )
-                if local_rank == 0:
-                    tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-            else:
-                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                    model, train_config, eval_dataloader, device
-                )
+            eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
+                model, train_config, eval_dataloader, device
+            )
+            if (not train_config.enable_ddp) or (local_rank == 0):
                 tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
 
         if train_config.save_metrics:
@@ -385,6 +377,9 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
 
     Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
     """
+    if train_config.enable_ddp:
+        dist.barrier()
+
     model.eval()
 
     if train_config.task_type == "seq_classification":
@@ -457,16 +452,15 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix
 
 
-def print_model_size(model, config) -> None:
+def print_model_size(model) -> None:
     """
     Print model name, the number of trainable parameters and initialization time.
 
     Args:
-        model: The PyTorch model.
-        config : Config of the model.
+        model: PyTorch model.
     """
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.log_rank_zero(f"{config.model_name} has {total_params / 1e6} Million params.")
+    logger.log_rank_zero(f"Model has {total_params / 1e6} Million params.")
 
 
 def print_trainable_parameters(model) -> None:
@@ -478,7 +472,7 @@ def print_trainable_parameters(model) -> None:
     """
     trainable_params, all_param = model.get_nb_trainable_parameters()
     logger.log_rank_zero(
-        f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
+        f"Trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}"
     )
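A quick, illustrative check of the new print_model_size signature; the toy module below is an assumption used only to show the call, not part of this commit:

import torch.nn as nn

from QEfficient.finetune.utils.train_utils import print_model_size

toy = nn.Linear(1024, 1024)  # 1,049,600 trainable parameters
print_model_size(toy)  # logs "Model has 1.0496 Million params." via logger.log_rank_zero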
