This repository was archived by the owner on Nov 12, 2025. It is now read-only.
Merged
Changes from all commits
17 commits
fbb5a30  This pull request sets up the initial GitHub repository configuration… (EO-Robotics, Sep 11, 2025)
735ee5a  Refactor model input handling for multimodal data, including image an… (DelinQu, Sep 12, 2025)
1390b23  Refactor model input handling for multimodal data, including image an… (DelinQu, Sep 12, 2025)
5e1fef4  Merge main into eo1-dev: keep .github/settings.yml deleted as intende… (DelinQu, Sep 12, 2025)
76f0223  Merge branch 'eo1-dev' of https://github.com/EO-Robotics/EO-1 into eo… (DelinQu, Sep 12, 2025)
abd4aa9  Refactor training scripts and configuration files for improved clarit… (DelinQu, Sep 12, 2025)
efd0527  Update .gitignore to include new output paths, modify dataset configu… (DelinQu, Sep 14, 2025)
b241ddd  Refactor import order in test_vlm.py for improved readability and con… (DelinQu, Sep 14, 2025)
5ae0a74  Merge branch 'main' into eo1-dev (DelinQu, Sep 14, 2025)
27d1222  Update pre-commit configuration to exclude processing_eo1.py from ban… (DelinQu, Sep 14, 2025)
b5d359b  Merge remote-tracking branch 'origin/main' into eo1-dev (DelinQu, Sep 15, 2025)
7e46c90  Refactor EO1 configuration and processing classes for improved struct… (DelinQu, Sep 20, 2025)
d2ef83b  Update .gitignore to exclude hf_save_pretrained.py and enhance README… (DelinQu, Sep 20, 2025)
177234c  Merge remote-tracking branch 'origin/main' into eo1-dev (DelinQu, Sep 20, 2025)
440eb92  Refactor training scripts to remove env.sh sourcing and activate cond… (DelinQu, Sep 22, 2025)
82096b5  Fix unified generation in modeling_eo1.py. (DelinQu, Sep 24, 2025)
3594979  Update .gitignore to include 'eo/model_dev' directory, ensuring prope… (DelinQu, Sep 24, 2025)
3 changes: 3 additions & 0 deletions .gitignore
@@ -237,5 +237,8 @@ demo_data/demos25

demo_data/libero_spatial_no_noops_1.0.0_lerobot
experiments/test

tools/hf_save_pretrained.py
tools/hf_save_pretrained.py
dev/
eo/model_dev
20 changes: 12 additions & 8 deletions eo/model/modeling_eo1.py
@@ -338,6 +338,7 @@ def forward(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
states=states,
)
else:
outputs = self.vlm_backbone.model(
@@ -434,21 +435,25 @@ def sample_actions(

# pass prefix, update kvcache
seq_len = input_ids.shape[-1]
chunk_size = self.config.action_chunk_size
suffix_len = -1 # exclude <|action_end|>
prefix_len = seq_len - self.config.action_chunk_size - 1
prefix_len = seq_len - chunk_size - 1

cache_seq_len = attention_mask.shape[-1]
cache_prefix_len = cache_seq_len - chunk_size - 1

outputs = self.vlm_backbone.model(
position_ids=position_ids[..., :prefix_len],
attention_mask=attention_mask[:, :prefix_len],
attention_mask=attention_mask[:, :cache_prefix_len],
past_key_values=past_key_values,
inputs_embeds=inputs_embeds[:, :prefix_len],
use_cache=True,
cache_position=cache_position[:-prefix_len] if cache_position is not None else None,
cache_position=cache_position[:prefix_len] if cache_position is not None else None,
)

# denoising
device = states.device
actions_shape = (states.shape[0], self.config.action_chunk_size, self.config.max_action_dim)
actions_shape = (states.shape[0], chunk_size, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)

x_t = noise.type(self.action_in_proj.weight.dtype)
@@ -461,7 +466,7 @@
action_time_embs = self.embed_suffix(time, x_t)
inputs_embeds[action_mask] = action_time_embs.to(inputs_embeds.dtype)

past_key_values.crop(prefix_len)
past_key_values.crop(cache_prefix_len)

outputs = self.vlm_backbone.model(
position_ids=position_ids[..., prefix_len:suffix_len],
@@ -471,7 +476,7 @@
use_cache=True,
cache_position=cache_position[prefix_len:suffix_len] if cache_position is not None else None,
)
action_time_embs = outputs.last_hidden_state[:, : self.config.action_chunk_size]
action_time_embs = outputs.last_hidden_state[:, :chunk_size]
action_time_embs = action_time_embs.type(self.action_out_proj.dtype)
v_t = self.action_out_proj(action_time_embs)

@@ -480,8 +485,7 @@

# last step
if time < -dt * 3 / 2:
suffix_len = seq_len

suffix_len = cache_seq_len
outputs.last_hidden_state = torch.cat([past_hidden_state, outputs.last_hidden_state], dim=1)
return x_t, outputs

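Note on the sample_actions changes above: the fix separates the length of the freshly embedded window (input_ids / inputs_embeds) from the length of the full attention mask, which also covers tokens already sitting in the KV cache from earlier turns of unified generation. A rough sketch of the length arithmetic with made-up numbers (values are hypothetical; only the formulas mirror the diff):

# Hypothetical lengths; only the arithmetic mirrors the diff above.
seq_len = 50             # tokens in the current input_ids / inputs_embeds window
cache_seq_len = 120      # columns in attention_mask: cached history + current window
chunk_size = 8           # config.action_chunk_size

prefix_len = seq_len - chunk_size - 1              # 41: new positions before the action chunk
cache_prefix_len = cache_seq_len - chunk_size - 1  # 111: all cached positions before the chunk

# Prefix pass: the new embeddings are sliced with prefix_len, but the attention
# mask (which spans the whole cache) is sliced with cache_prefix_len, and
# cache_position[:prefix_len] replaces the earlier cache_position[:-prefix_len].
# Each denoising step then calls past_key_values.crop(cache_prefix_len), trimming
# the cache back to everything before the action chunk, and the final step sets
# suffix_len = cache_seq_len instead of seq_len.
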
165 changes: 3 additions & 162 deletions eo/model/modeling_qwen2_5_vl.py
@@ -26,20 +26,14 @@

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.generation.utils import (
GenerateNonBeamOutput,
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList,
)
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
@@ -1499,9 +1493,10 @@ def prepare_inputs_for_generation(
text_positions = model_inputs["position_ids"][None, ...]
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)

if cache_position[0] != 0:
if cache_position[0] != 0 and cache_position.shape[-1] == 1:
model_inputs["pixel_values"] = None
model_inputs["pixel_values_videos"] = None
model_inputs["states"] = None

return model_inputs
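
With the added shape check, vision (and state) inputs are dropped only on single-token decode steps; a later-turn prefill, whose cache_position starts past zero but spans many positions, still forwards a newly attached image. A small illustration under that reading (the helper is hypothetical, not part of the model):

import torch

def keep_pixel_values(cache_position: torch.Tensor) -> bool:
    # Mirrors the new condition: only a 1-token decode step clears pixel_values/states.
    is_decode_step = cache_position[0].item() != 0 and cache_position.shape[-1] == 1
    return not is_decode_step

print(keep_pixel_values(torch.arange(0, 40)))   # True: initial prefill
print(keep_pixel_values(torch.arange(40, 75)))  # True: second-turn prefill with a new image
print(keep_pixel_values(torch.tensor([75])))    # False: decode step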

@@ -1651,160 +1646,6 @@ def _expand_dict_for_generation(dict_to_expand):

return input_ids, model_kwargs

def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"], # noqa: F821
**model_kwargs,
) -> GenerateNonBeamOutput | torch.LongTensor:
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
import os

os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2":
# only raise warning if the user passed an explicit compile-config
if (
generation_config.compile_config is not None
and generation_config.compile_config.fullgraph
):
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False
model_forward = self.get_compiled_call(generation_config.compile_config)

if generation_config.prefill_chunk_size is not None:
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
is_prefill = False
else:
is_prefill = True

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update(
{"output_hidden_states": output_hidden_states} if output_hidden_states else {}
)

if is_prefill:
outputs = self(**model_inputs, return_dict=True)
is_prefill = False
else:
outputs = model_forward(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue

# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].to(
copy=True, dtype=torch.float32, device=input_ids.device
)

# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)

# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (outputs.attentions,)
if output_hidden_states:
decoder_hidden_states += (outputs.hidden_states,)
actions = outputs.get("actions", None)

# token selection
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)

# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1

del outputs

if streamer is not None:
streamer.end()

if return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
actions=actions,
)
else:
return input_ids


# custom model output
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
sequences: torch.LongTensor
scores: tuple[torch.FloatTensor] | None = None
logits: tuple[torch.FloatTensor] | None = None
attentions: tuple[tuple[torch.FloatTensor]] | None = None
hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
past_key_values: tuple[tuple[tuple[torch.FloatTensor]]] | None = None
actions: torch.FloatTensor | None = None


__all__ = [
"Qwen2_5_VLForConditionalGeneration",
3 changes: 2 additions & 1 deletion eo/train/pipeline_config.py
@@ -64,6 +64,7 @@ class TrainPipelineConfig(TrainingArguments):
freeze_vision_tower: bool = field(default=False)
freeze_llm: bool = field(default=False)
freeze_merger: bool = field(default=False)
freeze_lm_head: bool = field(default=False)
attn_implementation: str = field(default="sdpa") # sdpa, flash_attention_2, flash_attention_3

lora_enable: bool = False
@@ -97,7 +98,7 @@ def __post_init__(self):
self.freeze_llm = True
warnings.warn("`freeze_llm` is set to True when `lora_enable`.", stacklevel=2)

if not self.lora_enable:
if not self.lora_enable and self.vision_lora:
self.vision_lora = False
warnings.warn("`vision_lora` is set to False when `lora_enable` is False.", stacklevel=2)

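The added `and self.vision_lora` stops the warning from firing on every non-LoRA run; it now triggers only when vision LoRA was actually requested while LoRA is disabled. A standalone sketch of the guard (the helper and its arguments are hypothetical):

import warnings

def resolve_vision_lora(lora_enable: bool, vision_lora: bool) -> bool:
    # Downgrade (and warn) only when vision_lora was explicitly requested without LoRA.
    if not lora_enable and vision_lora:
        vision_lora = False
        warnings.warn("`vision_lora` is set to False when `lora_enable` is False.", stacklevel=2)
    return vision_lora

assert resolve_vision_lora(False, False) is False  # silent: nothing was requested
assert resolve_vision_lora(False, True) is False   # warns once, then disables
assert resolve_vision_lora(True, True) is True     # LoRA enabled, vision LoRA kept
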
2 changes: 1 addition & 1 deletion eo/train/train_utils.py
@@ -28,7 +28,7 @@ def configure_vision_tower(vlm, training_args, compute_dtype, device):
def configure_llm(vlm, training_args):
"""Configure the LLM."""
lm_head = vlm.lm_head.parameters()
set_requires_grad(lm_head, not training_args.freeze_llm)
set_requires_grad(lm_head, not training_args.freeze_lm_head)

llm_params = vlm.model.parameters()
set_requires_grad(llm_params, not training_args.freeze_llm)
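Together with the new `freeze_lm_head` flag, the output head and the decoder blocks can now be frozen independently; a minimal sketch of the resulting combinations (the stub config below is hypothetical):

from dataclasses import dataclass

@dataclass
class StubArgs:  # hypothetical stand-in for TrainPipelineConfig
    freeze_llm: bool
    freeze_lm_head: bool

def trainable(args: StubArgs) -> dict:
    # lm_head now follows freeze_lm_head; the decoder blocks still follow freeze_llm.
    return {"lm_head": not args.freeze_lm_head, "llm": not args.freeze_llm}

print(trainable(StubArgs(freeze_llm=True, freeze_lm_head=False)))  # {'lm_head': True, 'llm': False}
print(trainable(StubArgs(freeze_llm=False, freeze_lm_head=True)))  # {'lm_head': False, 'llm': True}
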
2 changes: 1 addition & 1 deletion experiments/1_demo/train.sh
@@ -18,7 +18,7 @@ epoch=50
model_name_or_path=
run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}

. scripts/env.sh

conda activate eo

accelerate launch $ACCELERATE_ARGS scripts/train.py \
2 changes: 1 addition & 1 deletion experiments/2_libero/train.sh
@@ -18,7 +18,7 @@ epoch=50
model_name_or_path=
run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}

. scripts/env.sh

conda activate eo

accelerate launch $ACCELERATE_ARGS scripts/train.py \
2 changes: 1 addition & 1 deletion experiments/3_simpler/simpler_env/eval_simpler.sh
@@ -1,4 +1,4 @@
. scripts/env.sh


dist_tasks=(
bridge.sh
2 changes: 1 addition & 1 deletion experiments/3_simpler/train_bridge.sh
@@ -17,7 +17,7 @@ epoch=20
model_name_or_path=
run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}

. scripts/env.sh

conda activate eo

accelerate launch $ACCELERATE_ARGS scripts/train.py \
2 changes: 1 addition & 1 deletion experiments/3_simpler/train_fractal.sh
@@ -17,7 +17,7 @@ epoch=10
model_name_or_path=
run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}

. scripts/env.sh

conda activate eo

accelerate launch $ACCELERATE_ARGS scripts/train.py \
20 changes: 18 additions & 2 deletions scripts/test_vlm.py → tests/test_vlm.py
@@ -27,18 +27,34 @@

times = 0
past_key_values = None
past_pixel_values_n = 0
past_grid_thw_n = 0

while True:
if times > 0:
prompt = input("Enter your prompt: ")
if prompt == "q":
exit(0)
messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})

messages.append(
{
"role": "user",
"content": [
{"type": "image", "image": "demo_data/refcoco/images/COCO_train2014_000000580957_2.jpg"},
{"type": "text", "text": prompt},
],
}
)
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to("cuda")

if "pixel_values" in inputs:
inputs["pixel_values"] = inputs["pixel_values"][past_pixel_values_n:]
inputs["image_grid_thw"] = inputs["image_grid_thw"][past_grid_thw_n:]

past_pixel_values_n += inputs["pixel_values"].shape[0]
past_grid_thw_n += inputs["image_grid_thw"].shape[0]

input_length = inputs["input_ids"].shape[1]
outputs = model.generate(
**inputs, max_new_tokens=1024, past_key_values=past_key_values, return_dict_in_generate=True
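In the reworked test above, every turn re-applies the chat template to the whole conversation, so the processor re-emits pixel patches for images that earlier turns already pushed through the model; the two counters skip those. A standalone sketch of the bookkeeping with dummy tensors (shapes and patch counts are hypothetical):

import torch

past_pixel_values_n = 0
past_grid_thw_n = 0

def strip_cached_images(inputs: dict) -> dict:
    # Drop vision features that a previous turn already encoded into the KV cache.
    global past_pixel_values_n, past_grid_thw_n
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"][past_pixel_values_n:]
        inputs["image_grid_thw"] = inputs["image_grid_thw"][past_grid_thw_n:]
        past_pixel_values_n += inputs["pixel_values"].shape[0]
        past_grid_thw_n += inputs["image_grid_thw"].shape[0]
    return inputs

# Turn 1: one image worth of 1200 patches; turn 2 re-emits it plus a second image.
turn1 = {"pixel_values": torch.zeros(1200, 1176), "image_grid_thw": torch.tensor([[1, 30, 40]])}
print(strip_cached_images(turn1)["pixel_values"].shape[0])  # 1200: everything is new

turn2 = {"pixel_values": torch.zeros(2400, 1176), "image_grid_thw": torch.tensor([[1, 30, 40], [1, 30, 40]])}
print(strip_cached_images(turn2)["pixel_values"].shape[0])  # 1200: only the new image's patches remain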