Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/pythonpublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ jobs:
cache-from: type=gha
cache-to: type=gha,mode=max

build-and-push-external-plugin-service-images:
build-and-push-flyteagent-images:
runs-on: ubuntu-latest
needs: deploy
steps:
Expand All @@ -161,12 +161,12 @@ jobs:
registry: ghcr.io
username: "${{ secrets.FLYTE_BOT_USERNAME }}"
password: "${{ secrets.FLYTE_BOT_PAT }}"
- name: Prepare External Plugin Service Image Names
id: external-plugin-service-names
- name: Prepare Flyte Agent Image Names
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Fylte

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks

id: flyteagent-names
uses: docker/metadata-action@v3
with:
images: |
ghcr.io/${{ github.repository_owner }}/external-plugin-service
ghcr.io/${{ github.repository_owner }}/flyteagent
tags: |
latest
${{ github.sha }}
Expand All @@ -177,10 +177,10 @@ jobs:
context: "."
platforms: linux/arm64, linux/amd64
push: ${{ github.event_name == 'release' }}
tags: ${{ steps.external-plugin-service-names.outputs.tags }}
tags: ${{ steps.flyteagent-names.outputs.tags }}
build-args: |
VERSION=${{ needs.deploy.outputs.version }}
file: ./Dockerfile.external-plugin-service
file: ./Dockerfile.agent
cache-from: type=gha
cache-to: type=gha,mode=max

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ flask==2.2.3
# via mlflow
flatbuffers==23.1.21
# via tensorflow
flyteidl==1.5.6
flyteidl==1.5.9
# via flytekit
fonttools==4.38.0
# via matplotlib
Expand Down
6 changes: 3 additions & 3 deletions flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import click
import grpc
from flyteidl.service.external_plugin_service_pb2_grpc import add_ExternalPluginServiceServicer_to_server
from flyteidl.service.agent_service_pb2_grpc import add_AgentServiceServicer_to_server

from flytekit.extend.backend.external_plugin_service import BackendPluginServer
from flytekit.extend.backend.agent_service import AgentService

_serve_help = """Start a grpc server for the external plugin service."""

Expand Down Expand Up @@ -39,7 +39,7 @@ def serve(_: click.Context, port, worker, timeout):
"""
click.secho("Starting the external plugin service...", fg="blue")
server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker))
add_ExternalPluginServiceServicer_to_server(BackendPluginServer(), server)
add_AgentServiceServicer_to_server(AgentService(), server)

server.add_insecure_port(f"[::]:{port}")
server.start()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import grpc
from flyteidl.service.external_plugin_service_pb2 import (
from flyteidl.service.agent_service_pb2 import (
PERMANENT_FAILURE,
TaskCreateRequest,
TaskCreateResponse,
Expand All @@ -8,45 +8,45 @@
TaskGetRequest,
TaskGetResponse,
)
from flyteidl.service.external_plugin_service_pb2_grpc import ExternalPluginServiceServicer
from flyteidl.service.agent_service_pb2_grpc import AgentServiceServicer

from flytekit import logger
from flytekit.extend.backend.base_plugin import BackendPluginRegistry
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


class BackendPluginServer(ExternalPluginServiceServicer):
class AgentService(AgentServiceServicer):
def CreateTask(self, request: TaskCreateRequest, context: grpc.ServicerContext) -> TaskCreateResponse:
try:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
plugin = BackendPluginRegistry.get_plugin(context, tmp.type)
if plugin is None:
agent = AgentRegistry.get_agent(context, tmp.type)
if agent is None:
return TaskCreateResponse()
return plugin.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp)
return agent.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp)
except Exception as e:
logger.error(f"failed to create task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to create task with error {e}")

def GetTask(self, request: TaskGetRequest, context: grpc.ServicerContext) -> TaskGetResponse:
try:
plugin = BackendPluginRegistry.get_plugin(context, request.task_type)
if plugin is None:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return TaskGetResponse(state=PERMANENT_FAILURE)
return plugin.get(context=context, job_id=request.job_id)
return agent.get(context=context, job_id=request.job_id)
except Exception as e:
logger.error(f"failed to get task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to get task with error {e}")

def DeleteTask(self, request: TaskDeleteRequest, context: grpc.ServicerContext) -> TaskDeleteResponse:
try:
plugin = BackendPluginRegistry.get_plugin(context, request.task_type)
if plugin is None:
agent = AgentRegistry.get_agent(context, request.task_type)
if agent is None:
return TaskDeleteResponse()
return plugin.delete(context=context, job_id=request.job_id)
return agent.delete(context=context, job_id=request.job_id)
except Exception as e:
logger.error(f"failed to delete task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import grpc
from flyteidl.core.tasks_pb2 import TaskTemplate
from flyteidl.service.external_plugin_service_pb2 import (
from flyteidl.service.agent_service_pb2 import (
RETRYABLE_FAILURE,
RUNNING,
SUCCEEDED,
Expand All @@ -17,15 +17,15 @@
from flytekit.models.literals import LiteralMap


class BackendPluginBase(ABC):
class AgentBase(ABC):
"""
This is the base class for all backend plugins. It defines the interface that all plugins must implement.
The external plugins service will be run either locally or in a pod, and will be responsible for
invoking backend plugins. The propeller will communicate with the external plugins service
This is the base class for all agents. It defines the interface that all agents must implement.
The agent service will be run either locally or in a pod, and will be responsible for
invoking agents. The propeller will communicate with the agent service
to create tasks, get the status of tasks, and delete tasks.

All the backend plugins should be registered in the BackendPluginRegistry. External plugins service
will look up the plugin based on the task type. Every task type can only have one plugin.
All the agents should be registered in the AgentRegistry. Agent Service
will look up the agent based on the task type. Every task type can only have one agent.
"""

def __init__(self, task_type: str):
Expand All @@ -34,7 +34,7 @@ def __init__(self, task_type: str):
@property
def task_type(self) -> str:
"""
task_type is the name of the task type that this plugin supports.
task_type is the name of the task type that this agent supports.
"""
return self._task_type

Expand Down Expand Up @@ -68,34 +68,34 @@ def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteRespon
pass


class BackendPluginRegistry(object):
class AgentRegistry(object):
"""
This is the registry for all backend plugins. The external plugins service will look up the plugin
This is the registry for all agents. The agent service will look up the agent
based on the task type.
"""

_REGISTRY: typing.Dict[str, BackendPluginBase] = {}
_REGISTRY: typing.Dict[str, AgentBase] = {}

@staticmethod
def register(plugin: BackendPluginBase):
if plugin.task_type in BackendPluginRegistry._REGISTRY:
raise ValueError(f"Duplicate plugin for task type {plugin.task_type}")
BackendPluginRegistry._REGISTRY[plugin.task_type] = plugin
logger.info(f"Registering backend plugin for task type {plugin.task_type}")
def register(agent: AgentBase):
if agent.task_type in AgentRegistry._REGISTRY:
raise ValueError(f"Duplicate agent for task type {agent.task_type}")
AgentRegistry._REGISTRY[agent.task_type] = agent
logger.info(f"Registering an agent for task type {agent.task_type}")

@staticmethod
def get_plugin(context: grpc.ServicerContext, task_type: str) -> typing.Optional[BackendPluginBase]:
if task_type not in BackendPluginRegistry._REGISTRY:
logger.error(f"Cannot find backend plugin for task type [{task_type}]")
def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]:
if task_type not in AgentRegistry._REGISTRY:
logger.error(f"Cannot find agent for task type [{task_type}]")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Cannot find backend plugin for task type [{task_type}]")
context.set_details(f"Cannot find the agent for task type [{task_type}]")
return None
return BackendPluginRegistry._REGISTRY[task_type]
return AgentRegistry._REGISTRY[task_type]


def convert_to_flyte_state(state: str) -> State:
"""
Convert the state from the backend plugin to the state in flyte.
Convert the state from the agent to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

BigQueryConfig
BigQueryTask
BigQueryAgent
"""

from .backend_plugin import BigQueryPlugin
from .agent import BigQueryAgent
from .task import BigQueryConfig, BigQueryTask
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,12 @@
from typing import Dict, Optional

import grpc
from flyteidl.service.external_plugin_service_pb2 import (
SUCCEEDED,
TaskCreateResponse,
TaskDeleteResponse,
TaskGetResponse,
)
from flyteidl.service.agent_service_pb2 import SUCCEEDED, TaskCreateResponse, TaskDeleteResponse, TaskGetResponse
from google.cloud import bigquery

from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry, convert_to_flyte_state
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
from flytekit.models import literals
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
Expand All @@ -30,7 +25,7 @@
}


class BigQueryPlugin(BackendPluginBase):
class BigQueryAgent(AgentBase):
def __init__(self):
super().__init__(task_type="bigquery_query_job_task")

Expand Down Expand Up @@ -91,4 +86,4 @@ def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteRespon
return TaskDeleteResponse()


BackendPluginRegistry.register(BigQueryPlugin())
AgentRegistry.register(BigQueryAgent())
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from unittest.mock import MagicMock

import grpc
from flyteidl.service.external_plugin_service_pb2 import SUCCEEDED
from flyteidl.service.agent_service_pb2 import SUCCEEDED

import flytekit.models.interface as interface_models
from flytekit.extend.backend.base_plugin import BackendPluginRegistry
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import ResourceType
Expand All @@ -15,7 +15,7 @@

@mock.patch("google.cloud.bigquery.job.QueryJob")
@mock.patch("google.cloud.bigquery.Client")
def test_bigquery_plugin(mock_client, mock_query_job):
def test_bigquery_agent(mock_client, mock_query_job):
job_id = "dummy_id"
mock_instance = mock_client.return_value
mock_query_job_instance = mock_query_job.return_value
Expand All @@ -39,7 +39,7 @@ def __init__(self):
mock_instance.cancel_job.return_value = MockJob()

ctx = MagicMock(spec=grpc.ServicerContext)
p = BackendPluginRegistry.get_plugin(ctx, "bigquery_query_job_task")
p = AgentRegistry.get_agent(ctx, "bigquery_query_job_task")

task_id = Identifier(
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
"flyteidl>=1.5.6",
"flyteidl==1.5.9",
Comment thread
pingsutw marked this conversation as resolved.
Outdated
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unittest.mock import MagicMock

import grpc
from flyteidl.service.external_plugin_service_pb2 import (
from flyteidl.service.agent_service_pb2 import (
PERMANENT_FAILURE,
SUCCEEDED,
TaskCreateRequest,
Expand All @@ -15,8 +15,8 @@
)

import flytekit.models.interface as interface_models
from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry
from flytekit.extend.backend.external_plugin_service import BackendPluginServer
from flytekit.extend.backend.agent_service import AgentService
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import Identifier, ResourceType
from flytekit.models.literals import LiteralMap
Expand All @@ -25,7 +25,7 @@
dummy_id = "dummy_id"


class DummyPlugin(BackendPluginBase):
class DummyAgent(AgentBase):
def __init__(self):
super().__init__(task_type="dummy")

Expand All @@ -45,7 +45,7 @@ def delete(self, context: grpc.ServicerContext, job_id) -> TaskDeleteResponse:
return TaskDeleteResponse()


BackendPluginRegistry.register(DummyPlugin())
AgentRegistry.register(DummyAgent())

task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version")
task_metadata = task.TaskMetadata(
Expand Down Expand Up @@ -82,24 +82,24 @@ def delete(self, context: grpc.ServicerContext, job_id) -> TaskDeleteResponse:
)


def test_dummy_plugin():
def test_dummy_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
p = BackendPluginRegistry.get_plugin(ctx, "dummy")
p = AgentRegistry.get_agent(ctx, "dummy")
assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == dummy_id
assert p.get(ctx, dummy_id).state == SUCCEEDED
assert p.delete(ctx, dummy_id) == TaskDeleteResponse()


def test_backend_plugin_server():
server = BackendPluginServer()
def test_agent_server():
service = AgentService()
ctx = MagicMock(spec=grpc.ServicerContext)
request = TaskCreateRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl()
)

assert server.CreateTask(request, ctx).job_id == dummy_id
assert server.GetTask(TaskGetRequest(task_type="dummy", job_id=dummy_id), ctx).state == SUCCEEDED
assert server.DeleteTask(TaskDeleteRequest(task_type="dummy", job_id=dummy_id), ctx) == TaskDeleteResponse()
assert service.CreateTask(request, ctx).job_id == dummy_id
assert service.GetTask(TaskGetRequest(task_type="dummy", job_id=dummy_id), ctx).state == SUCCEEDED
assert service.DeleteTask(TaskDeleteRequest(task_type="dummy", job_id=dummy_id), ctx) == TaskDeleteResponse()

res = server.GetTask(TaskGetRequest(task_type="fake", job_id=dummy_id), ctx)
res = service.GetTask(TaskGetRequest(task_type="fake", job_id=dummy_id), ctx)
assert res.state == PERMANENT_FAILURE