Skip to content

Commit

Permalink
api export done
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Feb 4, 2025
1 parent 0ec30f7 commit d2c6821
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
8 changes: 4 additions & 4 deletions optimum/exporters/neuron/model_configs/traced_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
T5EncoderForSeq2SeqLMWrapper,
T5EncoderWrapper,
UnetNeuronWrapper,
CLIPVisionWithProjectionNeuronWrapper,
CLIPVisionModelNeuronWrapper,
)


Expand Down Expand Up @@ -246,10 +246,10 @@ class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
VISION_CONFIG = "vision_config"


@register_in_tasks_manager("clip-vision-with-projection", *["feature-extraction"], library_name="diffusers")
class CLIPVisionWithProjectionModelNeuronConfig(VisionNeuronConfig):
@register_in_tasks_manager("clip-vision-model", *["feature-extraction"], library_name="diffusers")
class CLIPVisionModelNeuronConfig(VisionNeuronConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
CUSTOM_MODEL_WRAPPER = CLIPVisionWithProjectionNeuronWrapper
CUSTOM_MODEL_WRAPPER = CLIPVisionModelNeuronWrapper

@property
def inputs(self) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def forward(self, input_ids, attention_mask):
return out_tuple["token_embeddings"], out_tuple["sentence_embedding"]


class CLIPVisionWithProjectionNeuronWrapper(torch.nn.Module):
class CLIPVisionModelNeuronWrapper(torch.nn.Module):
def __init__(self, model, input_names: List[str]):
super().__init__()
self.model = model
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def get_diffusion_models_for_export(
model=image_encoder,
exporter="neuron",
task="feature-extraction",
model_type="clip-vision-with-projection",
model_type="clip-vision-model",
library_name=library_name,
)
image_encoder_neuron_config = image_encoder_config_constructor(
Expand Down
8 changes: 7 additions & 1 deletion optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ def load_model(
"vae_encoder": vae_encoder_path,
"vae_decoder": vae_decoder_path,
"controlnet": controlnet_paths,
"image_encoder": image_encoder_path,
}

def _load_models_to_neuron(submodels, models_on_both_cores=None, models_on_a_single_core=None):
Expand Down Expand Up @@ -530,7 +531,7 @@ def _load_models_to_neuron(submodels, models_on_both_cores=None, models_on_a_sin

def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None):
check_if_weights_replacable(self.configs, weights)
model_names = ["text_encoder", "text_encoder_2", "unet", "transformer", "vae_decoder", "vae_encoder"]
model_names = ["text_encoder", "text_encoder_2", "unet", "transformer", "vae_decoder", "vae_encoder", "image_encoder"]
for name in model_names:
model = getattr(self, name, None)
weight = getattr(weights, name, None)
Expand Down Expand Up @@ -565,6 +566,7 @@ def _save_pretrained(
vae_encoder_file_name: str = NEURON_FILE_NAME,
vae_decoder_file_name: str = NEURON_FILE_NAME,
controlnet_file_name: str = NEURON_FILE_NAME,
image_encoder_file_name: str = NEURON_FILE_NAME,
):
"""
Saves the model to the serialized format optimized for Neuron devices.
Expand All @@ -589,6 +591,7 @@ def _remove_submodel_if_non_exist(model_names):
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_TRANSFORMER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
DIFFUSION_MODEL_IMAGE_ENCODER_NAME,
]
)

Expand Down Expand Up @@ -617,6 +620,9 @@ def _remove_submodel_if_non_exist(model_names):
DIFFUSION_MODEL_VAE_DECODER_NAME: save_directory
/ DIFFUSION_MODEL_VAE_DECODER_NAME
/ vae_decoder_file_name,
DIFFUSION_MODEL_IMAGE_ENCODER_NAME: save_directory
/ DIFFUSION_MODEL_IMAGE_ENCODER_NAME
/ image_encoder_file_name,
}
dst_paths[DIFFUSION_MODEL_CONTROLNET_NAME] = [
save_directory / (DIFFUSION_MODEL_CONTROLNET_NAME + f"_{str(idx)}") / controlnet_file_name
Expand Down

0 comments on commit d2c6821

Please sign in to comment.