@@ -78,6 +78,7 @@ class PathwaysConfig:
7878 server_flags : str = ''
7979 proxy_flags : str = ''
8080 worker_flags : str = ''
81+ headless : bool = False
8182
8283
8384# TODO(@vbarr): Split out parameters related to XPK workload and a General workload
@@ -446,7 +447,7 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
446447
447448 # Get proxy and xla flag string from model config
448449 proxy_flags_string = pw_config .proxy_flags
449- xla_flags_string = wl_config .model .xla_flags
450+ xla_flags_string = wl_config .model .xla_flags if not pw_config . headless else ''
450451
451452 # Split both proxy_flags_string and xla_flags_string into lists of flags
452453 proxy_flags_list = proxy_flags_string .strip ().split ()
@@ -457,8 +458,8 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
457458
458459 # Remove the flags that are specified to be removed.
459460 if (
460- wl_config .model .pathways_xla_flag_options
461- and xla_flags .REMOVE in wl_config .model .pathways_xla_flag_options
461+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
462+ and xla_flags .REMOVE in wl_config .model .pathways_xla_flag_options )
462463 ):
463464 flags_to_remove = wl_config .model .pathways_xla_flag_options [
464465 xla_flags .REMOVE
@@ -471,8 +472,8 @@ def _get_pathways_proxy_flags(wl_config: WorkloadConfig):
471472
472473 # Add the flags that are specified to be added.
473474 if (
474- wl_config .model .pathways_xla_flag_options
475- and xla_flags .ADD_PROXY in wl_config .model .pathways_xla_flag_options
475+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
476+ and xla_flags .ADD_PROXY in wl_config .model .pathways_xla_flag_options )
476477 ):
477478 flags_to_add = wl_config .model .pathways_xla_flag_options [
478479 xla_flags .ADD_PROXY
@@ -500,8 +501,8 @@ def _get_pathways_worker_flags(wl_config: WorkloadConfig):
500501
501502 # Add the flags that are specified to be added.
502503 if (
503- wl_config .model .pathways_xla_flag_options
504- and xla_flags .ADD_WORKER in wl_config .model .pathways_xla_flag_options
504+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
505+ and xla_flags .ADD_WORKER in wl_config .model .pathways_xla_flag_options )
505506 ):
506507 flags_to_add = wl_config .model .pathways_xla_flag_options [
507508 xla_flags .ADD_WORKER
@@ -523,8 +524,8 @@ def _get_pathways_server_flags(wl_config: WorkloadConfig):
523524
524525 # Add the flags that are specified to be added.
525526 if (
526- wl_config .model .pathways_xla_flag_options
527- and xla_flags .ADD_SERVER in wl_config .model .pathways_xla_flag_options
527+ not pw_config . headless and ( wl_config .model .pathways_xla_flag_options
528+ and xla_flags .ADD_SERVER in wl_config .model .pathways_xla_flag_options )
528529 ):
529530 flags_to_add = wl_config .model .pathways_xla_flag_options [
530531 xla_flags .ADD_SERVER
@@ -569,6 +570,7 @@ def _get_pathways_specific_flags(wl_config: WorkloadConfig):
569570 f' --custom-pathways-server-args="{ server_flags } " '
570571 f' --custom-pathways-proxy-server-args="{ proxy_flags } " '
571572 f' --custom-pathways-worker-args="{ worker_flags } " '
573+ f' { "--headless" if pw_config .headless else "" } '
572574 )
573575 return pathways_specific_flags
574576
@@ -582,6 +584,7 @@ def generate_xpk_workload_cmd(
582584 """Generates a command to run a maxtext model on XPK."""
583585
584586 is_pathways_enabled = wl_config .pathways_config is not None
587+ is_pathways_headless_enabled = wl_config .pathways_config and wl_config .pathways_config .headless
585588
586589 time .localtime ()
587590 length_of_random_str = 3
@@ -614,10 +617,12 @@ def generate_xpk_workload_cmd(
614617 wl_config .run_name ,
615618 'metrics' )
616619
617- user_command = build_user_command (
618- name = name ,
619- wl_config = wl_config
620- )
620+ user_command = ''
621+ if not is_pathways_headless_enabled :
622+ user_command = build_user_command (
623+ name = name ,
624+ wl_config = wl_config
625+ )
621626
622627 additional_flags = ''
623628 if not is_pathways_enabled and wl_config .libtpu_type == LibTpuType .CUSTOM :
@@ -641,7 +646,7 @@ def generate_xpk_workload_cmd(
641646 docker_image_flag = f'--docker-image="{ wl_config .base_docker_image } "'
642647
643648 upload_metrics_to_bq_cmd = ""
644- if wl_config .generate_metrics_and_upload_to_big_query :
649+ if wl_config .generate_metrics_and_upload_to_big_query and not is_pathways_headless_enabled :
645650 # TODO (optionally) make it so that this upload step is done on local device instead of within the workload.
646651 args = _build_args_from_config (wl_config )
647652 args_str = ""
0 commit comments