diff --git a/.gitignore b/.gitignore
index 8ba59e4..7ce5191 100644
--- a/.gitignore
+++ b/.gitignore
@@ -237,5 +237,8 @@ demo_data/demos25
 demo_data/libero_spatial_no_noops_1.0.0_lerobot
 experiments/test
+
+tools/hf_save_pretrained.py
 
 tools/hf_save_pretrained.py
 dev/
+eo/model_dev
diff --git a/eo/model/modeling_eo1.py b/eo/model/modeling_eo1.py
index 062d144..3ed810b 100644
--- a/eo/model/modeling_eo1.py
+++ b/eo/model/modeling_eo1.py
@@ -338,6 +338,7 @@ def forward(
                 past_key_values=past_key_values,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
+                states=states,
             )
         else:
             outputs = self.vlm_backbone.model(
@@ -434,21 +435,25 @@ def sample_actions(
 
         # pass prefix, update kvcache
         seq_len = input_ids.shape[-1]
+        chunk_size = self.config.action_chunk_size
         suffix_len = -1  # exclude <|action_end|>
-        prefix_len = seq_len - self.config.action_chunk_size - 1
+        prefix_len = seq_len - chunk_size - 1
+
+        cache_seq_len = attention_mask.shape[-1]
+        cache_prefix_len = cache_seq_len - chunk_size - 1
 
         outputs = self.vlm_backbone.model(
             position_ids=position_ids[..., :prefix_len],
-            attention_mask=attention_mask[:, :prefix_len],
+            attention_mask=attention_mask[:, :cache_prefix_len],
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds[:, :prefix_len],
             use_cache=True,
-            cache_position=cache_position[:-prefix_len] if cache_position is not None else None,
+            cache_position=cache_position[:prefix_len] if cache_position is not None else None,
        )
 
         # denoising
         device = states.device
-        actions_shape = (states.shape[0], self.config.action_chunk_size, self.config.max_action_dim)
+        actions_shape = (states.shape[0], chunk_size, self.config.max_action_dim)
         noise = self.sample_noise(actions_shape, device)
         x_t = noise.type(self.action_in_proj.weight.dtype)
 
@@ -461,7 +466,7 @@ def sample_actions(
             action_time_embs = self.embed_suffix(time, x_t)
             inputs_embeds[action_mask] = action_time_embs.to(inputs_embeds.dtype)
 
-            past_key_values.crop(prefix_len)
+            past_key_values.crop(cache_prefix_len)
 
             outputs = self.vlm_backbone.model(
                 position_ids=position_ids[..., prefix_len:suffix_len],
@@ -471,7 +476,7 @@ def sample_actions(
                 use_cache=True,
                 cache_position=cache_position[prefix_len:suffix_len] if cache_position is not None else None,
             )
-            action_time_embs = outputs.last_hidden_state[:, : self.config.action_chunk_size]
+            action_time_embs = outputs.last_hidden_state[:, :chunk_size]
             action_time_embs = action_time_embs.type(self.action_out_proj.dtype)
             v_t = self.action_out_proj(action_time_embs)
 
@@ -480,8 +485,7 @@ def sample_actions(
 
             # last step
             if time < -dt * 3 / 2:
-                suffix_len = seq_len
-
+                suffix_len = cache_seq_len
                 outputs.last_hidden_state = torch.cat([past_hidden_state, outputs.last_hidden_state], dim=1)
 
         return x_t, outputs
diff --git a/eo/model/modeling_qwen2_5_vl.py b/eo/model/modeling_qwen2_5_vl.py
index 9223b72..73febd4 100644
--- a/eo/model/modeling_qwen2_5_vl.py
+++ b/eo/model/modeling_qwen2_5_vl.py
@@ -26,7 +26,7 @@
 
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -34,12 +34,6 @@
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
-from transformers.generation.utils import (
-    GenerateNonBeamOutput,
-    GenerationConfig,
-    LogitsProcessorList,
-    StoppingCriteriaList,
-)
 from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import GradientCheckpointingLayer
@@ -1499,9 +1493,10 @@ def prepare_inputs_for_generation(
             text_positions = model_inputs["position_ids"][None, ...]
             model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
 
-        if cache_position[0] != 0:
+        if cache_position[0] != 0 and cache_position.shape[-1] == 1:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_values_videos"] = None
+            model_inputs["states"] = None
 
         return model_inputs
 
@@ -1651,160 +1646,6 @@ def _expand_dict_for_generation(dict_to_expand):
 
         return input_ids, model_kwargs
 
-    def _sample(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: LogitsProcessorList,
-        stopping_criteria: StoppingCriteriaList,
-        generation_config: GenerationConfig,
-        synced_gpus: bool,
-        streamer: Optional["BaseStreamer"],  # noqa: F821
-        **model_kwargs,
-    ) -> GenerateNonBeamOutput | torch.LongTensor:
-        pad_token_id = generation_config._pad_token_tensor
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
-        do_sample = generation_config.do_sample
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape[:2]
-        this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
-
-        model_forward = self.__call__
-        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
-        if compile_forward:
-            import os
-
-            os.environ["TOKENIZERS_PARALLELISM"] = "0"
-            # If we use FA2 and a static cache, we cannot compile with fullgraph
-            if self.config._attn_implementation == "flash_attention_2":
-                # only raise warning if the user passed an explicit compile-config
-                if (
-                    generation_config.compile_config is not None
-                    and generation_config.compile_config.fullgraph
-                ):
-                    logger.warning_once(
-                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
-                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
-                    )
-                    generation_config.compile_config.fullgraph = False
-            model_forward = self.get_compiled_call(generation_config.compile_config)
-
-        if generation_config.prefill_chunk_size is not None:
-            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
-            is_prefill = False
-        else:
-            is_prefill = True
-
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # prepare variable output controls (note: some models won't accept all output controls)
-            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-            model_inputs.update(
-                {"output_hidden_states": output_hidden_states} if output_hidden_states else {}
-            )
-
-            if is_prefill:
-                outputs = self(**model_inputs, return_dict=True)
-                is_prefill = False
-            else:
-                outputs = model_forward(**model_inputs, return_dict=True)
-
-            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            if synced_gpus and this_peer_finished:
-                continue
-
-            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
-            # (the clone itself is always small)
-            next_token_logits = outputs.logits[:, -1, :].to(
-                copy=True, dtype=torch.float32, device=input_ids.device
-            )
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (outputs.attentions,)
-                if output_hidden_states:
-                    decoder_hidden_states += (outputs.hidden_states,)
-            actions = outputs.get("actions", None)
-
-            # token selection
-            if do_sample:
-                probs = nn.functional.softmax(next_token_scores, dim=-1)
-                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
-                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-            else:
-                next_tokens = torch.argmax(next_token_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            this_peer_finished = unfinished_sequences.max() == 0
-            cur_len += 1
-
-            del outputs
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            return GenerateDecoderOnlyOutput(
-                sequences=input_ids,
-                scores=scores,
-                logits=raw_logits,
-                attentions=decoder_attentions,
-                hidden_states=decoder_hidden_states,
-                past_key_values=model_kwargs.get("past_key_values"),
-                actions=actions,
-            )
-        else:
-            return input_ids
-
-
-# custom model output
-@dataclass
-class GenerateDecoderOnlyOutput(ModelOutput):
-    sequences: torch.LongTensor
-    scores: tuple[torch.FloatTensor] | None = None
-    logits: tuple[torch.FloatTensor] | None = None
-    attentions: tuple[tuple[torch.FloatTensor]] | None = None
-    hidden_states: tuple[tuple[torch.FloatTensor]] | None = None
-    past_key_values: tuple[tuple[tuple[torch.FloatTensor]]] | None = None
-    actions: torch.FloatTensor | None = None
-
 
 __all__ = [
     "Qwen2_5_VLForConditionalGeneration",
diff --git a/eo/train/pipeline_config.py b/eo/train/pipeline_config.py
index badccc2..4fad576 100644
--- a/eo/train/pipeline_config.py
+++ b/eo/train/pipeline_config.py
@@ -64,6 +64,7 @@ class TrainPipelineConfig(TrainingArguments):
     freeze_vision_tower: bool = field(default=False)
     freeze_llm: bool = field(default=False)
     freeze_merger: bool = field(default=False)
+    freeze_lm_head: bool = field(default=False)
 
     attn_implementation: str = field(default="sdpa")  # sdpa, flash_attention_2, flash_attention_3
     lora_enable: bool = False
@@ -97,7 +98,7 @@ def __post_init__(self):
             self.freeze_llm = True
             warnings.warn("`freeze_llm` is set to True when `lora_enable`.", stacklevel=2)
 
-        if not self.lora_enable:
+        if not self.lora_enable and self.vision_lora:
             self.vision_lora = False
             warnings.warn("`vision_lora` is set to False when `lora_enable` is False.", stacklevel=2)
 
diff --git a/eo/train/train_utils.py b/eo/train/train_utils.py
index 848c57d..dc31e2e 100644
--- a/eo/train/train_utils.py
+++ b/eo/train/train_utils.py
@@ -28,7 +28,7 @@ def configure_vision_tower(vlm, training_args, compute_dtype, device):
 def configure_llm(vlm, training_args):
     """Configure the LLM."""
     lm_head = vlm.lm_head.parameters()
-    set_requires_grad(lm_head, not training_args.freeze_llm)
+    set_requires_grad(lm_head, not training_args.freeze_lm_head)
 
     llm_params = vlm.model.parameters()
     set_requires_grad(llm_params, not training_args.freeze_llm)
diff --git a/experiments/1_demo/train.sh b/experiments/1_demo/train.sh
index 86e71b4..162f43f 100644
--- a/experiments/1_demo/train.sh
+++ b/experiments/1_demo/train.sh
@@ -18,7 +18,7 @@ epoch=50
 model_name_or_path=
 run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}
 
-. scripts/env.sh
+
 conda activate eo
 
 accelerate launch $ACCELERATE_ARGS scripts/train.py \
diff --git a/experiments/2_libero/train.sh b/experiments/2_libero/train.sh
index 9df9f3f..96af8b8 100644
--- a/experiments/2_libero/train.sh
+++ b/experiments/2_libero/train.sh
@@ -18,7 +18,7 @@ epoch=50
 model_name_or_path=
 run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}
 
-. scripts/env.sh
+
 conda activate eo
 
 accelerate launch $ACCELERATE_ARGS scripts/train.py \
diff --git a/experiments/3_simpler/simpler_env/eval_simpler.sh b/experiments/3_simpler/simpler_env/eval_simpler.sh
index 5408ed6..778390b 100644
--- a/experiments/3_simpler/simpler_env/eval_simpler.sh
+++ b/experiments/3_simpler/simpler_env/eval_simpler.sh
@@ -1,4 +1,4 @@
-. scripts/env.sh
+
 
 dist_tasks=(
     bridge.sh
diff --git a/experiments/3_simpler/train_bridge.sh b/experiments/3_simpler/train_bridge.sh
index 8e97e15..9644f1e 100644
--- a/experiments/3_simpler/train_bridge.sh
+++ b/experiments/3_simpler/train_bridge.sh
@@ -17,7 +17,7 @@ epoch=20
 model_name_or_path=
 run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}
 
-. scripts/env.sh
+
 conda activate eo
 
 accelerate launch $ACCELERATE_ARGS scripts/train.py \
diff --git a/experiments/3_simpler/train_fractal.sh b/experiments/3_simpler/train_fractal.sh
index 09b1a84..20963cc 100644
--- a/experiments/3_simpler/train_fractal.sh
+++ b/experiments/3_simpler/train_fractal.sh
@@ -17,7 +17,7 @@ epoch=10
 model_name_or_path=
 run_name=${dataset_name}_ck${chunk_size}_gpu${GPUS}_lr${lr}_vlr${vlr}_mlr${mlr}_bs${PER_DEVICE_BATCH_SIZE}
 
-. scripts/env.sh
+
 conda activate eo
 
 accelerate launch $ACCELERATE_ARGS scripts/train.py \
diff --git a/scripts/test_vlm.py b/tests/test_vlm.py
similarity index 74%
rename from scripts/test_vlm.py
rename to tests/test_vlm.py
index 69b843d..df43ede 100644
--- a/scripts/test_vlm.py
+++ b/tests/test_vlm.py
@@ -27,18 +27,34 @@
 
 times = 0
 past_key_values = None
+past_pixel_values_n = 0
+past_grid_thw_n = 0
 
 while True:
     if times > 0:
         prompt = input("Enter your prompt: ")
         if prompt == "q":
             exit(0)
-        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
-
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": "demo_data/refcoco/images/COCO_train2014_000000580957_2.jpg"},
+                    {"type": "text", "text": prompt},
+                ],
+            }
+        )
     inputs = processor.apply_chat_template(
         messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
     ).to("cuda")
 
+    if "pixel_values" in inputs:
+        inputs["pixel_values"] = inputs["pixel_values"][past_pixel_values_n:]
+        inputs["image_grid_thw"] = inputs["image_grid_thw"][past_grid_thw_n:]
+
+        past_pixel_values_n += inputs["pixel_values"].shape[0]
+        past_grid_thw_n += inputs["image_grid_thw"].shape[0]
+
     input_length = inputs["input_ids"].shape[1]
     outputs = model.generate(
         **inputs, max_new_tokens=1024, past_key_values=past_key_values, return_dict_in_generate=True
diff --git a/tools/test_hf_model.py b/tools/test_hf_model.py
deleted file mode 100644
index d6bb120..0000000
--- a/tools/test_hf_model.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from transformers import AutoProcessor
-
-from eo.model.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
-
-"""set model name or path"""
-model_name_or_path = "../pretrained/Qwen2.5-VL-3B-Instruct"  # or EO-3B
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    model_name_or_path,
-    device_map="auto",
-    trust_remote_code=True,
-    # attn_implementation="flash_attention_2",
-)
-
-processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
-
-messages = [
-    {
-        "role": "user",
-        "content": [
-            {"type": "image", "image": "demo_data/refcoco/images/COCO_train2014_000000168643_2.jpg"},
-            {
-                "type": "text",
-                "text": "If the yellow robot gripper follows the yellow trajectory, what will happen? Choices: A. Robot puts the soda on the wooden steps. B. Robot moves the soda in front of the wooden steps. C. Robot moves the soda to the very top of the wooden steps. D. Robot picks up the soda can and moves it up. Please answer directly with only the letter of the correct option and nothing else.",
-            },
-        ],
-    },
-]
-
-times = 0
-past_key_values = None
-
-while True:
-    if times > 0:
-        prompt = input("Enter your prompt: ")
-        if prompt == "q":
-            exit(0)
-        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
-    inputs = processor.apply_chat_template(
-        messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
-    ).to("cuda")
-
-    input_length = inputs["input_ids"].shape[1]
-    outputs = model.generate(
-        **inputs, max_new_tokens=1024, past_key_values=past_key_values, return_dict_in_generate=True
-    )
-
-    past_key_values = outputs.past_key_values
-    generated_ids = outputs.sequences
-
-    completion = processor.decode(generated_ids[0, input_length:], skip_special_tokens=False)
-    print(completion)
-
-    messages.append({"role": "assistant", "content": [{"type": "text", "text": completion}]})
-    times += 1