
Commit a298159

Fix BigQuery feature
1 parent f817224 commit a298159

7 files changed (+28 −7 lines)


benchmarks/benchmark_db_utils.py

Lines changed: 5 additions & 4 deletions
@@ -141,7 +141,7 @@ def write_run(
 
   from benchmark_db_writer import bq_writer_utils
   from benchmark_db_writer import dataclass_bigquery_writer
-  from benchmark_db_writer.run_summary_writer import sample_run_summary_writer
+  from benchmark_db_writer.run_summary_writer import run_summary_writer
   from benchmark_db_writer.schema.workload_benchmark_v2 import workload_benchmark_v2_schema
 
   def get_db_client(
@@ -168,9 +168,9 @@ def get_db_client(
   print(options.model_id)
 
   if (
-      sample_run_summary_writer.validate_model_id(options.model_id, options.is_test)
-      and sample_run_summary_writer.validate_hardware_id(options.hardware_id, options.is_test)
-      and sample_run_summary_writer.validate_software_id(options.software_id, options.is_test)
+      run_summary_writer.validate_model_id(options.model_id, options.is_test)
+      and run_summary_writer.validate_hardware_id(options.hardware_id, options.is_test)
+      and run_summary_writer.validate_software_id(options.software_id, options.is_test)
   ):
     summary = workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema(
         run_id=f"run-{uuid.uuid4()}",
@@ -179,6 +179,7 @@ def get_db_client(
         hardware_id=options.hardware_id,
         hardware_num_chips=number_of_chips,
         hardware_num_nodes=number_of_nodes,
+        hardware_num_slices=options.hardware_num_slices,
         result_success=run_success,
         configs_framework=framework_config_in_json,
         configs_env=env_variables,

benchmarks/globals.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 import os.path
 
 # This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath`
-MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "MaxText")
+MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", "src/MaxText")
 
 # This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc.
 MAXTEXT_REPO_ROOT = os.environ.get(
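
Checkouts that keep the package at the repo root can still point the tooling at the old location via the environment variable; a minimal sketch (the value is just the pre-change default):

```python
import os

# Opt out of the new "src/MaxText" default for a flat checkout; must run
# before benchmarks/globals.py is imported.
os.environ["MAXTEXT_PKG_DIR"] = "MaxText"
```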

benchmarks/maxtext_xpk_runner.py

Lines changed: 6 additions & 1 deletion
@@ -158,6 +158,9 @@ def __post_init__(self):
     else:
       self.num_devices_per_slice = int(self.device_type.split("-")[1]) / 2
       self.topology = ""
+    self.hardware_id = self.device_type.split("-")[0]
+    if self.hardware_id == "v5litepod":
+      self.hardware_id = "v5e"
 
 
 def wait_for_xpk_workload_completion(cluster_config: XpkClusterConfig, workload_name, xpk_path) -> int:
@@ -341,6 +344,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict:
       "model_id": wl_config.model.model_type,
       "hardware_id": wl_config.hardware_id,
       "software_id": "jax_maxtext",
+      "hardware_num_slices": wl_config.num_slices,
       "number_of_chips": wl_config.num_devices_per_slice * wl_config.num_slices,
       "container_image_name": wl_config.base_docker_image,
       "global_batch_size": per_device_batch_size * wl_config.num_devices_per_slice * wl_config.num_slices,
@@ -445,7 +449,8 @@ def build_user_command(
           f"base_output_directory={wl_config.base_output_directory}",
           f"{vertex_tensorboard}",
           f"{run_name_command}",
-          f"{enable_metrics_cmd}" f"{upload_hlo_dump}",
+          f"{enable_metrics_cmd}",
+          f"{upload_hlo_dump}",
       ]
   )
   return command
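
The new `hardware_id` derivation in `__post_init__` is easy to check in isolation; a minimal sketch, assuming `device_type` strings of the form `<generation>-<chip count>` (e.g. "v5litepod-16"):

```python
def derive_hardware_id(device_type: str) -> str:
  """Mirrors the __post_init__ logic above: take the device_type prefix,
  normalizing "v5litepod" to its product name "v5e"."""
  hardware_id = device_type.split("-")[0]
  if hardware_id == "v5litepod":
    hardware_id = "v5e"
  return hardware_id

assert derive_hardware_id("v5litepod-16") == "v5e"
assert derive_hardware_id("v5p-128") == "v5p"
```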

benchmarks/recipes/runner_utils.py

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ def generate_and_run_workloads(user_config, num_slices_list, num_steps, priority
       xpk_path=user_config.xpk_path,
       num_steps=num_steps,
       priority=priority,
+      generate_metrics_and_upload_to_big_query=user_config.bq_enable,
+      db_project=user_config.bq_db_project,
+      db_dataset=user_config.bq_db_dataset,
   )
 
   # Generate XPK command

benchmarks/recipes/user_configs.py

Lines changed: 5 additions & 0 deletions
@@ -70,6 +70,11 @@ class UserConfig:
   selected_model_names: list[str] = dataclasses.field(default_factory=lambda: ["llama3_1_8b_8192"])
   num_slices_list: list[int] = dataclasses.field(default_factory=lambda: [2])
 
+  # BigQuery configuration
+  bq_enable: bool = False
+  bq_db_project: str = ""
+  bq_db_dataset: str = ""
+
   # other configuration
   xpk_path: str = "~/xpk"
   max_restarts: int = 0
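
Putting the two recipe changes together, a hypothetical sketch of a recipe opting in to BigQuery upload. The field names come from the diff above and the `generate_and_run_workloads` parameters from its hunk header; the project, dataset, `num_steps`, and `priority` values are placeholders, and the remaining `UserConfig` fields are assumed to keep their defaults:

```python
from benchmarks.recipes import runner_utils
from benchmarks.recipes.user_configs import UserConfig

config = UserConfig(
    bq_enable=True,
    bq_db_project="my-gcp-project",  # placeholder
    bq_db_dataset="benchmark_runs",  # placeholder
)

# Passes bq_enable/bq_db_project/bq_db_dataset through to the workload
# config, per the runner_utils hunk above.
runner_utils.generate_and_run_workloads(
    config, config.num_slices_list, num_steps=20, priority="medium"
)
```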

benchmarks/upload_metrics_to_bq.py

Lines changed: 6 additions & 0 deletions
@@ -186,6 +186,12 @@ def add_parser_arguments(parser: argparse.ArgumentParser):
       default=True,
       help="Whether to use the testing project or production project",
   )
+  parser.add_argument(
+      "--hardware_num_slices",
+      type=int,
+      required=False,
+      help="hardware slice number",
+  )
 
 
 def download_metrics_file_locally(metrics_gcs_file: str, local_file: str) -> int:
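
Since the new flag is optional, existing invocations keep working. A minimal standalone check of the parser behavior, mirroring just the added argument rather than the script's full parser:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hardware_num_slices",
    type=int,
    required=False,
    help="hardware slice number",
)

args = parser.parse_args(["--hardware_num_slices", "2"])
print(args.hardware_num_slices)  # 2 (an int); None when the flag is omitted
```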

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -1,10 +1,11 @@
 absl-py
 aqtp
 array-record
+benchmark_db_writer@git+https://github.com/CIeNET-International/aotc.git@c0bef62eac87c99152ff2e9fd48da1f7d9f3cc04#subdirectory=src/aotc/benchmark_db_writer
 cloud-accelerator-diagnostics
 cloud-tpu-diagnostics
 datasets
-flax
+flax==0.11.1
 gcsfs
 google-api-python-client
 google-cloud-aiplatform
