diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index 6f5a9dc86..16a56e219 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -321,7 +321,8 @@ class UNetNeuronConfig(VisionNeuronConfig): MODEL_TYPE = "unet" CUSTOM_MODEL_WRAPPER = UnetNeuronWrapper NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( - image_size="sample_size", + height="height", + width="width", num_channels="in_channels", hidden_size="cross_attention_dim", vocab_size="norm_num_groups", @@ -353,14 +354,6 @@ def outputs(self) -> List[str]: return ["sample"] def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs): - # For neuron, we use static shape for compiling the unet. Unlike `optimum`, we use the given `height` and `width` instead of the `sample_size`. - # TODO: Modify optimum.utils.DummyVisionInputGenerator to enable unequal height and width (it prioritize `image_size` to custom h/w now) - if self.height == self.width: - self._normalized_config.image_size = self.height - else: - raise ValueError( - "You need to input the same value for `self.height({self.height})` and `self.width({self.width})`." - ) dummy_inputs = super().generate_dummy_inputs(**kwargs) dummy_inputs["timestep"] = dummy_inputs["timestep"].float() dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0] @@ -395,7 +388,6 @@ class VaeEncoderNeuronConfig(VisionNeuronConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( num_channels="in_channels", - image_size="sample_size", allow_new=True, ) @@ -408,14 +400,6 @@ def outputs(self) -> List[str]: return ["latent_sample"] def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs): - # For neuron, we use static shape for compiling the unet. Unlike `optimum`, we use the given `height` and `width` instead of the `sample_size`. - # TODO: Modify optimum.utils.DummyVisionInputGenerator to enable unequal height and width (it prioritize `image_size` to custom h/w now) - if self.height == self.width: - self._normalized_config.image_size = self.height - else: - raise ValueError( - "You need to input the same value for `self.height({self.height})` and `self.width({self.width})`." - ) dummy_inputs = super().generate_dummy_inputs(**kwargs) if return_tuple is True: