
Commit 3d8a53e

mamtsing authored and quic-mamta committed
modify error handling
Signed-off-by: Mamta Singh <[email protected]>
1 parent 1e1519b commit 3d8a53e

12 files changed: +76 −60 lines changed


QEfficient/cloud/finetune.py

Lines changed: 10 additions & 9 deletions

@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------

+import logging
 import random
 import warnings
 from typing import Any, Dict, Optional, Union
@@ -40,7 +41,7 @@
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.", logging.WARNING)


 # Suppress all warnings
@@ -121,7 +122,7 @@ def load_model_and_tokenizer(
     )

     if not hasattr(model, "base_model_prefix"):
-        logger.raise_runtimeerror("Given huggingface model does not have 'base_model_prefix' attribute.")
+        logger.raise_error("Given huggingface model does not have 'base_model_prefix' attribute.", RuntimeError)

     for param in getattr(model, model.base_model_prefix).parameters():
         param.requires_grad = False
@@ -146,7 +147,7 @@ def load_model_and_tokenizer(
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logger.WARNING)
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)
         model.resize_token_embeddings(len(tokenizer))

     print_model_size(model)
@@ -161,8 +162,8 @@ def load_model_and_tokenizer(
         if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
         else:
-            logger.raise_runtimeerror(
-                "Given model doesn't support gradient checkpointing. Please disable it and run it."
+            logger.raise_error(
+                "Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError
             )

     model = apply_peft(model, train_config, peft_config_file, **kwargs)
@@ -237,8 +238,9 @@ def setup_dataloaders(
     if train_config.run_validation:
         eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
         if len(eval_dataloader) == 0:
-            logger.raise_runtimeerror(
-                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            logger.raise_error(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
+                ValueError,
             )
         else:
             logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")
@@ -280,8 +282,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     dataset_config = generate_dataset_config(train_config.dataset)
     update_config(dataset_config, **kwargs)

-    logger.prepare_dump_logs(train_config.dump_logs)
-    logger.setLevel(train_config.log_level)
+    logger.prepare_for_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)

     setup_distributed_training(train_config)
     setup_seeds(train_config.seed)
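The calling convention this commit applies across the finetune code is small enough to sketch. A minimal illustration follows (both helpers are defined in QEfficient/finetune/utils/logging_utils.py further down in this diff; the try/except here is only so the snippet runs to completion):

import logging

from QEfficient.finetune.utils.logging_utils import logger

# Old pattern (removed in this commit): logger.raise_runtimeerror(msg) logged and always raised RuntimeError.
# New pattern: log the message, then raise a caller-chosen exception type.
try:
    logger.raise_error("Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError)
except RuntimeError:
    pass  # demonstration only; real callers let the exception propagate

# Rank-zero log calls can now pass an explicit logging level for warnings.
logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)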

QEfficient/finetune/configs/training.py

Lines changed: 0 additions & 1 deletion

@@ -95,7 +95,6 @@ class TrainConfig:
     use_profiler: bool = False  # Enable pytorch profiler, can not be used with flop counter at the same time.
     # profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler

-    dump_root_dir: str = "mismatches/step_"
     opByOpVerifier: bool = False

     dump_logs: bool = True

QEfficient/finetune/dataset/alpaca_dataset.py

Lines changed: 9 additions & 1 deletion

@@ -11,6 +11,8 @@
 import torch
 from torch.utils.data import Dataset

+from QEfficient.finetune.utils.logging_utils import logger
+
 PROMPT_DICT = {
     "prompt_input": (
         "Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -27,7 +29,13 @@

 class InstructionDataset(Dataset):
     def __init__(self, dataset_config, tokenizer, partition="train", context_length=None):
-        self.ann = json.load(open(dataset_config.data_path))
+        try:
+            self.ann = json.load(open(dataset_config.data_path))
+        except FileNotFoundError:
+            logger.raise_error(
+                "Loading of alpaca dataset failed! Please use (wget -c https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/refs/heads/main/alpaca_data.json -P dataset/) to download the alpaca dataset.",
+                FileNotFoundError,
+            )
         # Use 5% of the dataset for evaluation
         eval_length = int(len(self.ann) / 20)
         if partition == "train":

QEfficient/finetune/dataset/custom_dataset.py

Lines changed: 11 additions & 6 deletions

@@ -32,18 +32,21 @@ def get_custom_dataset(dataset_config, tokenizer, split: str, context_length=Non
     module_path, func_name = dataset_config.file, "get_custom_dataset"

     if not module_path.endswith(".py"):
-        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)

     module_path = Path(module_path)
     if not module_path.is_file():
-        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_error(
+            f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+        )

     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split, context_length)
     except AttributeError:
-        logger.raise_runtimeerror(
-            f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
+        logger.raise_error(
+            f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).",
+            AttributeError,
         )


@@ -54,11 +57,13 @@ def get_data_collator(dataset_processer, dataset_config):
     module_path, func_name = dataset_config.file, "get_data_collator"

     if not module_path.endswith(".py"):
-        logger.raise_runtimeerror(f"Dataset file {module_path} is not a .py file.")
+        logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)

     module_path = Path(module_path)
     if not module_path.is_file():
-        logger.raise_runtimeerror(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        logger.raise_error(
+            f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+        )

     module = load_module_from_py_file(module_path.as_posix())
     try:

QEfficient/finetune/dataset/grammar_dataset.py

Lines changed: 4 additions & 4 deletions

@@ -21,11 +21,11 @@ def __init__(self, tokenizer, csv_name=None, context_length=None):
                 data_files={"train": [csv_name]},  # "eval": "grammar_validation.csv"},
                 delimiter=",",
             )
-        except Exception as e:
-            logger.raise_runtimeerror(
-                "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset."
+        except FileNotFoundError:
+            logger.raise_error(
+                "Loading of grammar dataset failed! Please check (https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb) for details on how to download the dataset.",
+                FileNotFoundError,
             )
-            raise e

         self.context_length = context_length
         self.tokenizer = tokenizer

QEfficient/finetune/eval.py

Lines changed: 5 additions & 4 deletions

@@ -26,7 +26,7 @@

     device = "qaic:0"
 except ImportError as e:
-    logger.warning(f"{e}. Moving ahead without these qaic modules.")
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.")
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 # Suppress all warnings
@@ -78,16 +78,17 @@ def main(**kwargs):
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        logger.warning("Resizing the embedding matrix to match the tokenizer vocab size.")
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))

     print_model_size(model)

     if train_config.run_validation:
         eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="test")
         if len(eval_dataloader) == 0:
-            raise ValueError(
-                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            logger.raise_error(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
+                ValueError,
             )
         else:
             logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

QEfficient/finetune/utils/config_utils.py

Lines changed: 15 additions & 11 deletions

@@ -44,7 +44,9 @@ def update_config(config, **kwargs):
                 if hasattr(config, param_name):
                     setattr(config, param_name, v)
                 else:
-                    raise ValueError(f"Config '{config_name}' does not have parameter: '{param_name}'")
+                    logger.raise_error(
+                        f"Config '{config_name}' does not have parameter: '{param_name}'", ValueError
+                    )
             else:
                 config_type = type(config).__name__
                 logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")
@@ -70,7 +72,7 @@ def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None
     else:
         config_map = {"lora": (LoraConfig, PeftLoraConfig)}
         if train_config.peft_method not in config_map:
-            raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+            logger.raise_error(f"Peft config not found: {train_config.peft_method}", RuntimeError)

         config_cls, peft_config_cls = config_map[train_config.peft_method]
         if config_cls is None:
@@ -119,7 +121,7 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
        - Ensures types match expected values (int, float, list, etc.).
     """
     if config_type.lower() != "lora":
-        raise ValueError(f"Unsupported config_type: {config_type}. Only 'lora' is supported.")
+        logger.raise_error(f"Unsupported config_type: {config_type}. Only 'lora' is supported.", ValueError)

     required_fields = {
         "r": int,
@@ -136,26 +138,28 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
     # Check for missing required fields
     missing_fields = [field for field in required_fields if field not in config_data]
     if missing_fields:
-        raise ValueError(f"Missing required fields in {config_type} config: {missing_fields}")
+        logger.raise_error(f"Missing required fields in {config_type} config: {missing_fields}", ValueError)

     # Validate types of required fields
     for field, expected_type in required_fields.items():
         if not isinstance(config_data[field], expected_type):
-            raise ValueError(
+            logger.raise_error(
                 f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
-                f"got {type(config_data[field]).__name__}"
+                f"got {type(config_data[field]).__name__}",
+                ValueError,
             )

     # Validate target_modules contains strings
     if not all(isinstance(mod, str) for mod in config_data["target_modules"]):
-        raise ValueError("All elements in 'target_modules' must be strings")
+        logger.raise_error("All elements in 'target_modules' must be strings", ValueError)

     # Validate types of optional fields if present
     for field, expected_type in optional_fields.items():
         if field in config_data and not isinstance(config_data[field], expected_type):
-            raise ValueError(
+            logger.raise_error(
                 f"Field '{field}' in {config_type} config must be of type {expected_type.__name__}, "
-                f"got {type(config_data[field]).__name__}"
+                f"got {type(config_data[field]).__name__}",
+                ValueError,
             )


@@ -173,12 +177,12 @@ def load_config_file(config_path: str) -> Dict[str, Any]:
         ValueError: If the file format is unsupported.
     """
     if not os.path.exists(config_path):
-        raise FileNotFoundError(f"Config file not found: {config_path}")
+        logger.raise_error(f"Config file not found: {config_path}", FileNotFoundError)

     with open(config_path, "r") as f:
         if config_path.endswith(".yaml") or config_path.endswith(".yml"):
             return yaml.safe_load(f)
         elif config_path.endswith(".json"):
             return json.load(f)
         else:
-            raise ValueError("Unsupported config file format. Use .yaml, .yml, or .json")
+            logger.raise_error("Unsupported config file format. Use .yaml, .yml, or .json", ValueError)

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 4 additions & 3 deletions

@@ -18,7 +18,7 @@ def get_preprocessed_dataset(
     tokenizer, dataset_config, split: str = "train", context_length: int = None
 ) -> torch.utils.data.Dataset:
     if dataset_config.dataset not in DATASET_PREPROC:
-        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
+        logger.raise_error(f"{dataset_config.dataset} is not (yet) implemented", NotImplementedError)

     def get_split():
         return dataset_config.train_split if split == "train" else dataset_config.test_split
@@ -39,8 +39,9 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
     if train_config.enable_ddp:
         if train_config.enable_sorting_for_ddp:
             if train_config.context_length:
-                raise ValueError(
-                    "Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding"
+                logger.raise_error(
+                    "Sorting cannot be done with padding, Please disable sorting or pass context_length as None to disable padding",
+                    ValueError,
                 )
             else:
                 kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(

QEfficient/finetune/utils/logging_utils.py

Lines changed: 8 additions & 8 deletions

@@ -10,29 +10,29 @@
 from datetime import datetime

 from QEfficient.finetune.utils.helper import is_rank_zero
-from QEfficient.utils.constants import ROOT_DIR


 class FTLogger:
-    def __init__(self, level=logging.DEBUG):
+    def __init__(self):
         self.logger = logging.getLogger("QEfficient")
         if not getattr(self.logger, "_custom_methods_added", False):
             self._bind_custom_methods()
             self.logger._custom_methods_added = True  # Prevent adding handlers/methods twice

     def _bind_custom_methods(self):
-        def raise_runtimeerror(message):
+        def raise_error(message, errortype=RuntimeError):
             self.logger.error(message)
-            raise RuntimeError(message)
+            raise errortype(message)

         def log_rank_zero(msg: str, level: int = logging.INFO):
             if not is_rank_zero:
                 return
             self.logger.log(level, msg, stacklevel=2)

-        def prepare_dump_logs(dump_logs=False, level=logging.INFO):
+        def prepare_for_logs(output_path, dump_logs=False, level=logging.INFO):
+            self.logger.setLevel(level)
             if dump_logs:
-                logs_path = os.path.join(ROOT_DIR, "logs")
+                logs_path = os.path.join(output_path, "logs")
                 if not os.path.exists(logs_path):
                     os.makedirs(logs_path, exist_ok=True)
                 file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt"
@@ -44,9 +44,9 @@ def prepare_dump_logs(dump_logs=False, level=logging.INFO):
                 fh.setFormatter(formatter)
                 self.logger.addHandler(fh)

-        self.logger.raise_runtimeerror = raise_runtimeerror
+        self.logger.raise_error = raise_error
         self.logger.log_rank_zero = log_rank_zero
-        self.logger.prepare_dump_logs = prepare_dump_logs
+        self.logger.prepare_for_logs = prepare_for_logs

     def get_logger(self):
         return self.logger
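For reference, a minimal sketch of the reworked logging setup from a caller's point of view, mirroring the single prepare_for_logs call now made in QEfficient/cloud/finetune.py. The literal output directory below is an illustrative assumption, not a value from this commit:

import logging

from QEfficient.finetune.utils.logging_utils import logger

output_dir = "./results"  # assumption: stands in for train_config.output_dir

# Sets the logger level and, when dump_logs is True, attaches a file handler
# that writes timestamped log files under <output_dir>/logs/
# (replaces the old prepare_dump_logs + setLevel pair).
logger.prepare_for_logs(output_dir, dump_logs=True, level=logging.INFO)

logger.log_rank_zero("Fine-tuning logs will be dumped under the output directory.")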

QEfficient/finetune/utils/parser.py

Lines changed: 5 additions & 9 deletions

@@ -254,18 +254,14 @@ def get_finetune_parser():
         action="store_true",
         help="Enable distributed data parallel training. This will load the replicas of model on given number of devices and train the model. This should be used using torchrun interface. Please check docs for exact usage.",
     )
-    parser.add_argument(
-        "--dump_root_dir",
-        "--dump-root-dir",
-        required=False,
-        type=str,
-        default="mismatches/step_",
-        help="Directory for mismatch dumps by opByOpVerifier",
-    )
     parser.add_argument(
         "--opByOpVerifier",
         action="store_true",
-        help="Enable operation-by-operation verification w.r.t reference device(cpu). It is a context manager interface that captures and verifies each operator against reference device. In case results of test & reference do not match under given tolerances, a standalone unittest is generated at dump_root_dir.",
+        help=argparse.SUPPRESS,
+        # This is for debugging purpose only.
+        # Enables operation-by-operation verification w.r.t reference device(cpu).
+        # It is a context manager interface that captures and verifies each operator against reference device.
+        # In case results of test & reference do not match under given tolerances, a standalone unittest is generated at dump_root_dir.
     )

     return parser

QEfficient/finetune/utils/plot_metrics.py

Lines changed: 2 additions & 2 deletions

@@ -69,14 +69,14 @@ def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):

 def plot_metrics(file_path):
     if not os.path.exists(file_path):
-        logger.error(f"File {file_path} does not exist.")
+        logger.raise_error(f"File {file_path} does not exist.", FileNotFoundError)
         return

     with open(file_path, "r") as f:
         try:
             data = json.load(f)
         except json.JSONDecodeError:
-            logger.error("Invalid JSON file.")
+            logger.raise_error("Invalid JSON file.", json.JSONDecodeError)
             return

     directory = os.path.dirname(file_path)

QEfficient/finetune/utils/train_utils.py

Lines changed: 3 additions & 2 deletions

@@ -85,8 +85,9 @@ def train(
     max_steps_reached = False  # Flag to indicate max training steps reached

     tensorboard_updates = None
+    tensorboard_log_dir = train_config.output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
     if is_rank_zero():
-        tensorboard_updates = SummaryWriter()
+        tensorboard_updates = SummaryWriter(log_dir=tensorboard_log_dir)

     device_type = torch.device(device).type

@@ -181,7 +182,7 @@ def train(
                         atol=1e-1,
                         use_ref_output_on_mismatch=True,
                         filter_config=qaic_debug.DispatchFilterConfig.default(device),
-                        dump_root_dir=train_config.dump_root_dir + str(step),
+                        dump_root_dir=train_config.output_dir + "/mismatches/step_" + str(step),
                     ) as verifier:
                         model_outputs = model(**batch)
                         loss = model_outputs.loss  # Forward call
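Combined with the logging changes above, the artifacts of a run now live under train_config.output_dir. A small sketch of the paths implied by this diff (the output_dir value is an illustrative assumption):

import os
from datetime import datetime

output_dir = "./results"  # assumption: stands in for train_config.output_dir

# TensorBoard event files, written only by rank zero:
tensorboard_log_dir = output_dir + "/runs/" + f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

# opByOpVerifier mismatch dumps, one directory per training step:
dump_root_dir = output_dir + "/mismatches/step_" + str(0)

# Log files created by logger.prepare_for_logs when dump_logs is enabled:
logs_path = os.path.join(output_dir, "logs")

print(tensorboard_log_dir, dump_root_dir, logs_path)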
