From d82df49224cc22f237b27941321b0bae1ec0e56b Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Mon, 24 Jun 2024 11:01:34 -0400 Subject: [PATCH 1/6] Update Gemma on Dataflow example to pass max_length as arg --- dataflow/gemma/custom_model_gemma.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index fbf0b975057..ae8628658d3 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -77,10 +77,17 @@ def run_inference( # Loop each text string, and use a tuple to store the inference results. predictions = [] for one_text in batch: - result = model.generate(one_text, max_length=64) + result = model.generate(one_text, **inference_args) predictions.append(result) return utils._convert_to_result(batch, predictions, self._model_name) + def validate_inference_args(self, inference_args: Optional[dict[str, + Any]]): + if len(inference_args) > 1 or "max_length" not in inference_args.keys: + raise ValueError( + "invalid inference args, only valid arg is max_length, got", + inference_args) + class FormatOutput(beam.DoFn): def process(self, element, *args, **kwargs): @@ -123,8 +130,10 @@ def process(self, element, *args, **kwargs): beam.io.ReadFromPubSub(subscription=args.messages_subscription) | "Parse" >> beam.Map(lambda x: x.decode("utf-8")) | "RunInference-Gemma" >> RunInference( - GemmaModelHandler(args.model_path) - ) # Send the prompts to the model and get responses. + GemmaModelHandler(args.model_path), + inference_args={ + "max_length", 64 + }) # Send the prompts to the model and get responses. | "Format Output" >> beam.ParDo(FormatOutput()) # Format the output. | "Publish Result" >> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic)) From 0b115e9ae21fc44768c35cedba571cb6ac8e0670 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Mon, 24 Jun 2024 11:03:47 -0400 Subject: [PATCH 2/6] fix default, change passed arg --- dataflow/gemma/custom_model_gemma.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index ae8628658d3..14b592a1efc 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -74,6 +74,8 @@ def run_inference( Returns: An Iterable of type PredictionResult. """ + if inference_args is None: + inference_args = {"max_length": 64} # Loop each text string, and use a tuple to store the inference results. predictions = [] for one_text in batch: @@ -132,7 +134,7 @@ def process(self, element, *args, **kwargs): | "RunInference-Gemma" >> RunInference( GemmaModelHandler(args.model_path), inference_args={ - "max_length", 64 + "max_length": 32 }) # Send the prompts to the model and get responses. | "Format Output" >> beam.ParDo(FormatOutput()) # Format the output. | "Publish Result" >> From 6c5f1571e0f8c909cb588f1d3e9da93995a564ea Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Mon, 24 Jun 2024 11:48:38 -0400 Subject: [PATCH 3/6] fix fn invocation --- dataflow/gemma/custom_model_gemma.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index 14b592a1efc..04ad032450c 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -85,7 +85,8 @@ def run_inference( def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - if len(inference_args) > 1 or "max_length" not in inference_args.keys: + if len(inference_args) > 1 or "max_length" not in inference_args.keys( + ): raise ValueError( "invalid inference args, only valid arg is max_length, got", inference_args) From 8cf46e8fcce2cea9c8d771eb24045995f53365b1 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 25 Jun 2024 10:51:46 -0400 Subject: [PATCH 4/6] update requirements.txt --- dataflow/gemma/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataflow/gemma/requirements.txt b/dataflow/gemma/requirements.txt index 76fc60632ee..8d205dfc541 100644 --- a/dataflow/gemma/requirements.txt +++ b/dataflow/gemma/requirements.txt @@ -1,4 +1,5 @@ apache_beam[gcp]==2.54.0 protobuf==4.25.0 keras_nlp==0.8.2 -keras==3.0.5 \ No newline at end of file +keras==3.0.5 +tensorflow==2.16.1 \ No newline at end of file From da1ff5d686cbbb4eacc17b7837461047ef47b2e2 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 25 Jun 2024 16:10:18 -0400 Subject: [PATCH 5/6] Fix input arg validation --- dataflow/gemma/custom_model_gemma.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index 04ad032450c..b90e7021e73 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -85,11 +85,12 @@ def run_inference( def validate_inference_args(self, inference_args: Optional[dict[str, Any]]): - if len(inference_args) > 1 or "max_length" not in inference_args.keys( - ): - raise ValueError( - "invalid inference args, only valid arg is max_length, got", - inference_args) + if inference_args: + if len(inference_args + ) > 1 or "max_length" not in inference_args.keys(): + raise ValueError( + "invalid inference args, only valid arg is max_length, got", + inference_args) class FormatOutput(beam.DoFn): @@ -133,10 +134,9 @@ def process(self, element, *args, **kwargs): beam.io.ReadFromPubSub(subscription=args.messages_subscription) | "Parse" >> beam.Map(lambda x: x.decode("utf-8")) | "RunInference-Gemma" >> RunInference( - GemmaModelHandler(args.model_path), - inference_args={ - "max_length": 32 - }) # Send the prompts to the model and get responses. + GemmaModelHandler(args.model_path), + inference_args={"max_length": 32} + ) # Send the prompts to the model and get responses. | "Format Output" >> beam.ParDo(FormatOutput()) # Format the output. | "Publish Result" >> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic)) From ffaf3ef2e96ed524544890e126461a08fe327a59 Mon Sep 17 00:00:00 2001 From: Jack McCluskey Date: Tue, 25 Jun 2024 16:23:20 -0400 Subject: [PATCH 6/6] linting, formatting --- dataflow/gemma/custom_model_gemma.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index b90e7021e73..7b7168f1e57 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -134,9 +134,10 @@ def process(self, element, *args, **kwargs): beam.io.ReadFromPubSub(subscription=args.messages_subscription) | "Parse" >> beam.Map(lambda x: x.decode("utf-8")) | "RunInference-Gemma" >> RunInference( - GemmaModelHandler(args.model_path), - inference_args={"max_length": 32} - ) # Send the prompts to the model and get responses. + GemmaModelHandler(args.model_path), + inference_args={ + "max_length": 32 + }) # Send the prompts to the model and get responses. | "Format Output" >> beam.ParDo(FormatOutput()) # Format the output. | "Publish Result" >> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic))