Skip to content

Commit 126a7ab

Browse files
committed
Yet another update for metrics: allow computing metrics on a random subset (with a fixed seed of 42)
1 parent 8fba20c commit 126a7ab

File tree

2 files changed

+75
-20
lines changed

2 files changed

+75
-20
lines changed

compute_metrics.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import hydra
77
import jsonlines
8+
import numpy as np
89
import wandb
910
from hydra.utils import to_absolute_path
1011
from omegaconf import OmegaConf
@@ -16,6 +17,7 @@
1617
logger = logging.getLogger("datasets")
1718
logger.setLevel(logging.ERROR)
1819
random.seed(42)
20+
np.random.seed(42)
1921

2022

2123
def load_predictions(cfg: MetricsConfig) -> str:
@@ -104,7 +106,8 @@ def main(cfg: MetricsConfig):
104106
job_type="metrics" if not cfg.filter.use_filtering else "filter_metrics",
105107
tags=(["new_prefix_logic"] if cfg.include_short else [])
106108
+ (["only_filtered" if cfg.filter.fit_filters else "only_unfiltered"] if cfg.filter.use_filtering else [])
107-
+ (["subset"] if cfg.filter.use_pos_in_file_filtering else []),
109+
+ (["specific_subset"] if cfg.filter.use_pos_in_file_filtering else [])
110+
+ ([f"random_subset_{cfg.filter.subset_num_examples}"] if cfg.filter.use_subset else []),
108111
) # type: ignore[assignment]
109112
cfg.preds_path = load_predictions(cfg)
110113
elif cfg.preds_path:
@@ -125,11 +128,35 @@ def main(cfg: MetricsConfig):
125128

126129
# default: simply compute the metrics for all the examples
127130
if not cfg.filter.use_filtering:
128-
with jsonlines.open(cfg.preds_path, "r") as reader:
129-
for line in tqdm(reader, desc="Computing metrics"):
130-
add_single_example(
131-
line, full_metrics=full_metrics, prefix_metrics=prefix_metrics, include_short=cfg.include_short
132-
)
131+
# or for a subset of N examples
132+
if cfg.filter.use_subset:
133+
assert (
134+
cfg.filter.subset_num_examples is not None
135+
), "Configured to use subset, but the desired number of examples is None."
136+
logging.info(f"Will consider random subset of {cfg.filter.subset_num_examples} examples.")
137+
138+
with jsonlines.open(cfg.preds_path, "r") as reader:
139+
num_examples = sum(1 for _ in reader)
140+
subset_ids = set(np.random.choice(num_examples, size=cfg.filter.subset_num_examples, replace=False))
141+
142+
with jsonlines.open(cfg.preds_path, "r") as reader:
143+
for i, line in tqdm(
144+
enumerate(reader),
145+
desc=f"Computing metrics on a random subset of {cfg.filter.subset_num_examples} examples",
146+
):
147+
if i in subset_ids:
148+
add_single_example(
149+
line,
150+
full_metrics=full_metrics,
151+
prefix_metrics=prefix_metrics,
152+
include_short=cfg.include_short,
153+
)
154+
else:
155+
with jsonlines.open(cfg.preds_path, "r") as reader:
156+
for line in tqdm(reader, desc="Computing metrics"):
157+
add_single_example(
158+
line, full_metrics=full_metrics, prefix_metrics=prefix_metrics, include_short=cfg.include_short
159+
)
133160

134161
# or define filters configuration to control what subset will be considered
135162
# option 1: boolean filters
@@ -160,24 +187,49 @@ def include_example(filters_line: Dict[str, str]) -> bool:
160187
logging.warning(
161188
f"Total number of examples: {num_total}, will consider {num_included} examples ({num_included / num_total * 100 :.2f}%)."
162189
)
163-
164-
with jsonlines.open(cfg.preds_path, "r") as reader:
190+
# or for a subset of N examples
191+
if cfg.filter.use_subset:
192+
assert (
193+
cfg.filter.subset_num_examples is not None
194+
), "Configured to use subset, but the desired number of examples is None."
195+
assert (
196+
cfg.filter.subset_num_examples >= num_included
197+
), "Configured to use subset, but the desired number of examples is larger than the total sample."
198+
199+
logging.info(f"Will consider random subset of {cfg.filter.subset_num_examples} examples.")
165200
with jsonlines.open(cfg.filter.path, "r") as filters_reader:
166-
for i, (input_line, filters_line) in tqdm(
167-
enumerate(zip(reader, filters_reader)), desc="Computing metrics with filters"
201+
included_ids = [i for i, filters_line in enumerate(filters_reader) if include_example(filters_line)]
202+
subset_ids = set(np.random.choice(included_ids, size=cfg.filter.subset_num_examples, replace=False))
203+
204+
with jsonlines.open(cfg.preds_path, "r") as reader:
205+
for i, line in tqdm(
206+
enumerate(reader),
207+
desc=f"Computing metrics with filters on a random subset of {cfg.filter.subset_num_examples} examples",
168208
):
169-
if include_example(filters_line):
209+
if i in subset_ids:
170210
add_single_example(
171-
input_line,
211+
line,
172212
full_metrics=full_metrics,
173213
prefix_metrics=prefix_metrics,
174214
include_short=cfg.include_short,
175215
)
216+
else:
217+
with jsonlines.open(cfg.preds_path, "r") as reader:
218+
with jsonlines.open(cfg.filter.path, "r") as filters_reader:
219+
for i, (input_line, filters_line) in tqdm(
220+
enumerate(zip(reader, filters_reader)), desc="Computing metrics with filters"
221+
):
222+
if include_example(filters_line):
223+
add_single_example(
224+
input_line,
225+
full_metrics=full_metrics,
226+
prefix_metrics=prefix_metrics,
227+
include_short=cfg.include_short,
228+
)
176229

177230
# option 2: pos in file-filtering (only include examples that are present in a given file, controlled by `pos_in_file` column)
178231
else:
179-
logging.info("Will compute metrics on a given subset.")
180-
232+
logging.info("Will compute metrics on a specific given subset.")
181233
with jsonlines.open(cfg.filter.path, "r") as filters_reader:
182234
ids_to_include = set(line["pos_in_file"] for line in filters_reader)
183235

@@ -189,7 +241,7 @@ def include_example(filters_line: Dict[str, str]) -> bool:
189241
)
190242

191243
with jsonlines.open(cfg.preds_path, "r") as reader:
192-
for i, input_line in tqdm(enumerate(reader), desc="Computing metrics on a given subset"):
244+
for i, input_line in tqdm(enumerate(reader), desc="Computing metrics on a specific given subset"):
193245
if i in ids_to_include:
194246
add_single_example(
195247
input_line,

conf/metrics_config.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,27 @@ class FilterConfig:
1111
Configuration for additional data filtering when calculating metrics.
1212
1313
Attributes:
14-
use_filtering: True to use additional data filtering, False otherwise.
1514
path: Path to file with filters metadata for a test set.
16-
use_pos_in_file_filtering: True to use `pos_in_file` column and only consider lines present in a given file,
17-
False to use boolean filters logic.
15+
use_filtering: True to use additional data filtering, False otherwise.
1816
filters_to_include: List of column names to consider. Each column should be boolean.
1917
logic: A logic to follow when multiple columns are given (`and` for logical and, `or` for logical or).
2018
fit_filters: If True, will consider examples that fit given columns with given logic.
2119
If False, will consider examples that DON'T FIT given columns with given logic.
20+
use_pos_in_file_filtering: True to use `pos_in_file` column and only consider lines present in a given file,
21+
False to use boolean filters logic.
22+
2223
"""
2324

24-
use_filtering: bool = False
2525
path: str = "raw_data/multilang/downsample/filters/test.jsonl"
26-
use_pos_in_file_filtering: bool = False
26+
use_filtering: bool = False
2727
filters_to_include: List[str] = field(
2828
default_factory=lambda: ["is_vdo", "one_sentence_newline", "message_30_tokens", "diff_100_tokens"]
2929
)
3030
logic: str = "and"
3131
fit_filters: bool = True
32+
use_pos_in_file_filtering: bool = False
33+
use_subset: bool = False
34+
subset_num_examples: Optional[int] = None
3235

3336

3437
@dataclass

0 commit comments

Comments
 (0)