-
Notifications
You must be signed in to change notification settings - Fork 31
Refactor inference.py for LLM and RoBERTa support #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
6e5c0d2
5c0d9ec
d7730c8
819b147
13c0917
5895831
8b1d37e
238b05d
38fef7a
3b94a36
effb27b
600ba67
031abde
8657a11
0d042c6
3f13729
4e731d8
fd70377
7a5c9df
3ad7050
92d05ef
011ec33
5292128
17be9a7
6780770
0437c50
dfd6758
f7c458e
3434641
a543198
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| import argparse | ||
| import os | ||
| import torch | ||
|
|
||
| # ============================================================== | ||
| # Common utilities | ||
|
|
@@ -57,6 +59,8 @@ def aiu_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False): | |
| def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False): | ||
| if local_rank < 0: | ||
| local_rank = rank | ||
|
|
||
| # FIXME: local_size not in use ? | ||
| if local_size < 0: | ||
| local_size = world_size | ||
|
|
||
|
|
@@ -67,3 +71,60 @@ def aiu_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False | |
| dprint(f"Detected running via torchrun") | ||
|
|
||
| aiu_setup(rank, world_size) | ||
|
|
||
|
|
||
| # ============================================================== | ||
| # Environment variables utilities | ||
| # ============================================================== | ||
| def set_aiu_env_vars(args: argparse.Namespace) -> None: | ||
| """Set necessary environment variables for AIU""" | ||
|
|
||
| _target_cache_size = max( | ||
| int(args.max_new_tokens * 2), | ||
| int(args.min_pad_length * 2.5), | ||
| int(args.fixed_prompt_length * 2.5), | ||
| ) | ||
| _prompt_size = max(int(args.min_pad_length), int(args.fixed_prompt_length)) | ||
| if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"): | ||
| if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit: | ||
| _prev = torch._dynamo.config.accumulated_cache_size_limit | ||
| torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size | ||
| dprint( | ||
| "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit " | ||
| f"from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} " | ||
| f"to accommodate prompt size of {_prompt_size} and decode tokens of " | ||
| f"{args.max_new_tokens}" | ||
| ) | ||
|
|
||
| if _target_cache_size > torch._dynamo.config.cache_size_limit: | ||
| _prev = torch._dynamo.config.cache_size_limit | ||
| torch._dynamo.config.cache_size_limit = _target_cache_size | ||
| dprint( | ||
| f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to " | ||
| f"{torch._dynamo.config.cache_size_limit} to accommodate prompt size of " | ||
| f"{_prompt_size} and decode tokens of {args.max_new_tokens}" | ||
| ) | ||
|
||
|
|
||
| if not args.compile_dynamic: | ||
| torch._dynamo.config.assume_static_by_default = True | ||
| torch._dynamo.config.dynamic_shapes = False | ||
|
||
| torch._dynamo.config.automatic_dynamic_shapes = False | ||
|
|
||
| # This should be set outside!!! | ||
| os.environ.setdefault("SENCORES", "32") | ||
| os.environ.setdefault("SENCORELETS", "2") | ||
| os.environ.setdefault("DATA_PREC", "fp16") | ||
| os.environ.setdefault("FLEX_OVERWRITE_NMB_FRAME", "1") | ||
|
||
| os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "true") | ||
|
|
||
| os.environ.setdefault("COMPILATION_MODE", "offline_decoder") | ||
|
||
|
|
||
| if args.device_type == "aiu-senulator": | ||
| os.environ["FLEX_COMPUTE"] = "SENULATOR" | ||
| os.environ["FLEX_DEVICE"] = "MOCK" | ||
| else: | ||
| if "AIU_WORLD_RANK_0" not in os.environ: | ||
| print("must set AIU_WORLD_RANK_0") | ||
| exit() | ||
| os.environ.setdefault("FLEX_COMPUTE", "SENTIENT") | ||
| os.environ.setdefault("FLEX_DEVICE", "VFIO") | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe it's no longer needed? if you can't find any reference to it feel free to delete