Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ LLADA_BASE_PATH=GSAI-ML/LLaDA-8B-Base
LLADA_INST_PATH=GSAI-ML/LLaDA-8B-Instruct
LLADA_1_5_PATH=GSAI-ML/LLaDA-1.5
DREAM_BASE_PATH=Dream-org/Dream-v0-Base-7B
DREAM_INST_PATH=Dream-org/Dream-v0-Instruct-7B
DREAM_INST_PATH=Dream-org/Dream-v0-Instruct-7B
SDAR_8B_CHAT_PATH=JetLM/SDAR-8B-Chat
8 changes: 8 additions & 0 deletions configs/gen_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ def get_generation_args(task: str, model: str, cache: str | None = None):
match model:
case "dream-base" | "dream-inst":
top_p = 0.9
case model if model.startswith("sdar"):
# SDAR block diffusion defaults (see SDAR repo `generate.py`)
block_length = 4
# keep `steps=gen_length` so that per-block denoising steps can be derived as:
# denoising_steps = steps // (gen_length // block_length) == block_length
temperature = 1.0
top_p = 0.95
top_k = 50

return GenerationArgs(
gen_length=gen_length,
Expand Down
22 changes: 22 additions & 0 deletions configs/generation/sdar.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# SDAR block-diffusion decoding.
# Most defaults (gen_length/block_length/steps/top_p, etc.) are filled by `configs/gen_args.py`
strategy: sdar

# SDAR remasking / unmasking strategies (from SDAR repo `generate.py`)
remasking_strategy: low_confidence_dynamic
confidence_threshold: 0.85
eb_threshold: 0.35

alg: "maskgit_plus"
gen_length: null
block_length: null
steps: null
temperature: 1.0
top_p: 0.95
top_k: 50

# Stop when the first EOS is generated; remaining masks (if any) are replaced with EOS.
stop_until_eot: true

output_probs: false

10 changes: 10 additions & 0 deletions configs/model/sdar-8b-chat.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
defaults:
- sdar-common
- _self_

name: sdar-8b-chat
path: ${oc.env:SDAR_8B_CHAT_PATH}

# SDAR is a chat model; let lm-eval apply the chat template.
apply_chat_template: true

7 changes: 7 additions & 0 deletions configs/model/sdar-common.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
generation:
# JetLM/SDAR-8B-Chat config.json:
# bos_token_id/eos_token_id: 151643, mask_token_id: 151669
mask_token_id: 151669
eot_token_id: 151643
pad_token_id: 151643

4 changes: 3 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def main(cfg: DictConfig) -> None:
use_cache=(
os.path.join(output_dir, "response") if cfg.use_eval_cache else None
),
apply_chat_template=cfg.model.name.endswith("inst"),
apply_chat_template=cfg.model.get(
"apply_chat_template", cfg.model.name.endswith("inst")
),
**overwrite_eval_task(cfg),
)

Expand Down
Loading