Skip to content

Update Gemma on Dataflow example to pass max_length as arg #11897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions dataflow/gemma/custom_model_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,24 @@ def run_inference(
Returns:
An Iterable of type PredictionResult.
"""
if inference_args is None:
inference_args = {"max_length": 64}
# Loop each text string, and use a tuple to store the inference results.
predictions = []
for one_text in batch:
result = model.generate(one_text, max_length=64)
result = model.generate(one_text, **inference_args)
predictions.append(result)
return utils._convert_to_result(batch, predictions, self._model_name)

def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
    """Validate user-supplied inference arguments for this model handler.

    Only the single key ``"max_length"`` is accepted. An empty or ``None``
    mapping is treated as "no args" and passes validation.

    Args:
        inference_args: Optional mapping of keyword args forwarded to
            ``model.generate``.

    Raises:
        ValueError: If the mapping contains any key other than
            ``"max_length"``, or more than one key.
    """
    # Nothing to check when no args were supplied.
    if not inference_args:
        return
    has_extra_keys = len(inference_args) > 1
    lacks_max_length = "max_length" not in inference_args
    if has_extra_keys or lacks_max_length:
        raise ValueError(
            "invalid inference args, only valid arg is max_length, got",
            inference_args)


class FormatOutput(beam.DoFn):
def process(self, element, *args, **kwargs):
Expand Down Expand Up @@ -123,8 +134,10 @@ def process(self, element, *args, **kwargs):
beam.io.ReadFromPubSub(subscription=args.messages_subscription)
| "Parse" >> beam.Map(lambda x: x.decode("utf-8"))
| "RunInference-Gemma" >> RunInference(
GemmaModelHandler(args.model_path)
) # Send the prompts to the model and get responses.
GemmaModelHandler(args.model_path),
inference_args={
"max_length": 32
}) # Send the prompts to the model and get responses.
| "Format Output" >> beam.ParDo(FormatOutput()) # Format the output.
| "Publish Result" >>
beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic))
Expand Down
3 changes: 2 additions & 1 deletion dataflow/gemma/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
apache_beam[gcp]==2.54.0
protobuf==4.25.0
keras_nlp==0.8.2
keras==3.0.5
keras==3.0.5
tensorflow==2.16.1