Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create model for async AQL tasks #79

Merged
merged 14 commits into from
Sep 14, 2021
85 changes: 85 additions & 0 deletions multinet/api/migrations/0010_aqlquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Generated by Django 3.2.6 on 2021-09-13 13:31

from django.conf import settings
import django.contrib.postgres.fields
from django.db import migrations, models
import django.db.models.deletion
import django_extensions.db.fields


class Migration(migrations.Migration):

dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('api', '0009_alter_upload_options'),
]

operations = [
migrations.CreateModel(
name='AqlQuery',
fields=[
(
'id',
models.AutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name='ID'
),
),
(
'created',
django_extensions.db.fields.CreationDateTimeField(
auto_now_add=True, verbose_name='created'
),
),
(
'modified',
django_extensions.db.fields.ModificationDateTimeField(
auto_now=True, verbose_name='modified'
),
),
(
'error_messages',
django.contrib.postgres.fields.ArrayField(
base_field=models.CharField(max_length=500),
blank=True,
null=True,
size=None,
),
),
(
'status',
models.CharField(
choices=[
('PENDING', 'Pending'),
('STARTED', 'Started'),
('FAILED', 'Failed'),
('FINISHED', 'Finished'),
],
default='PENDING',
max_length=10,
),
),
('query', models.TextField()),
('results', models.JSONField(blank=True, null=True)),
(
'user',
models.ForeignKey(
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name='aqlquerys',
to=settings.AUTH_USER_MODEL,
),
),
(
'workspace',
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name='aqlquerys',
to='api.workspace',
),
),
],
options={
'abstract': False,
},
),
]
3 changes: 2 additions & 1 deletion multinet/api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .network import Network
from .table import Table
from .tasks import Task, Upload
from .tasks import AqlQuery, Task, Upload
from .workspace import Workspace, WorkspaceRole, WorkspaceRoleChoice

__all__ = [
'AqlQuery',
'Network',
'Table',
'Task',
Expand Down
7 changes: 7 additions & 0 deletions multinet/api/models/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,10 @@ class DataType(models.TextChoices):

blob = S3FileField()
data_type = models.CharField(max_length=20, choices=DataType.choices)


class AqlQuery(Task):
"""An object to track AQL queries."""

query = models.TextField()
results = models.JSONField(blank=True, null=True)
3 changes: 3 additions & 0 deletions multinet/api/tasks/aql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .aql_query import AqlQueryTask, execute_query

__all__ = ['AqlQueryTask', 'execute_query']
27 changes: 27 additions & 0 deletions multinet/api/tasks/aql/aql_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from arango.cursor import Cursor

# from arango.exceptions import AQLQueryExecuteError, ArangoServerError
from celery import shared_task

from multinet.api.models import AqlQuery, Workspace
from multinet.api.tasks import MultinetCeleryTask
from multinet.api.utils.arango import ArangoQuery


class AqlQueryTask(MultinetCeleryTask):
task_model = AqlQuery


@shared_task(base=AqlQueryTask)
def execute_query(task_id: int) -> None:
query_task: AqlQuery = AqlQuery.objects.select_related('workspace').get(id=task_id)
workspace: Workspace = query_task.workspace
query_str = query_task.query
# Run the query on Arango DB
database = workspace.get_arango_db()
query = ArangoQuery(database, query_str, time_limit_secs=60)
cursor: Cursor = query.execute()

# Store the results on the task object
query_task.results = list(cursor)
query_task.save()
184 changes: 184 additions & 0 deletions multinet/api/tests/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from django.contrib.auth.models import User
from faker import Faker
import pytest
from rest_framework.response import Response
from rest_framework.test import APIClient

from multinet.api.models.tasks import AqlQuery
from multinet.api.models.workspace import Workspace, WorkspaceRole, WorkspaceRoleChoice
from multinet.api.tests.conftest import populated_table
from multinet.api.tests.fuzzy import INTEGER_ID_RE, TIMESTAMP_RE, workspace_re


@pytest.fixture
def valid_query(workspace: Workspace, user: User, authenticated_api_client: APIClient):
"""Create a fixture representing the response of a POST request for AQL queries."""
workspace.set_user_permission(user, WorkspaceRoleChoice.READER)
node_table = populated_table(workspace, False)
query_str = f'FOR document IN {node_table.name} RETURN document'
r: Response = authenticated_api_client.post(
f'/api/workspaces/{workspace.name}/queries/', {'query': query_str}, format='json'
)
WorkspaceRole.objects.filter(workspace=workspace, user=user).delete()
return {'response': r, 'query': query_str, 'nodes': list(node_table.get_rows())}


@pytest.fixture
def mutating_query(workspace: Workspace, user: User, authenticated_api_client: APIClient):
"""Create a fixture for a mutating AQL query that will have an error message post processing."""
workspace.set_user_permission(user, WorkspaceRoleChoice.READER)
node_table = populated_table(workspace, False)
fake = Faker()
query_str = f"INSERT {{ 'name': {fake.pystr()} }} INTO {node_table.name}"
r: Response = authenticated_api_client.post(
f'/api/workspaces/{workspace.name}/queries/', {'query': query_str}, format='json'
)
WorkspaceRole.objects.filter(workspace=workspace, user=user).delete()
return {'response': r, 'query': query_str, 'nodes': list(node_table.get_rows())}


@pytest.mark.django_db
def test_query_rest_create(workspace: Workspace, user: User, valid_query):
r = valid_query['response']
assert r.status_code == 200
assert r.json() == {
'id': INTEGER_ID_RE,
'workspace': workspace_re(workspace),
'query': valid_query['query'],
'user': user.username,
'error_messages': None,
'status': AqlQuery.Status.PENDING,
'created': TIMESTAMP_RE,
'modified': TIMESTAMP_RE,
}


@pytest.mark.django_db
def test_query_rest_create_mutating(workspace: Workspace, user: User, mutating_query):
r = mutating_query['response']

# even though the query is not read-only, the task object should be created
assert r.status_code == 200
assert r.json() == {
'id': INTEGER_ID_RE,
'workspace': workspace_re(workspace),
'query': mutating_query['query'],
'user': user.username,
'error_messages': None,
'status': AqlQuery.Status.PENDING,
'created': TIMESTAMP_RE,
'modified': TIMESTAMP_RE,
}


@pytest.mark.django_db
@pytest.mark.parametrize(
'permission,is_owner,status_code,success',
[
(None, False, 404, False),
(WorkspaceRoleChoice.READER, False, 200, True),
(WorkspaceRoleChoice.WRITER, False, 200, True),
(WorkspaceRoleChoice.MAINTAINER, False, 200, True),
(None, True, 200, True),
],
)
def test_query_rest_retrieve(
workspace: Workspace,
user: User,
authenticated_api_client: APIClient,
valid_query,
permission: WorkspaceRoleChoice,
is_owner: bool,
status_code: int,
success: bool,
):
if permission is not None:
workspace.set_user_permission(user, permission)
elif is_owner:
workspace.set_owner(user)

query_info = valid_query['response'].json()
query_id = query_info['id']
r: Response = authenticated_api_client.get(
f'/api/workspaces/{workspace.name}/queries/{query_id}/'
)
assert r.status_code == status_code
if success:
r_json = r.json()
assert r_json['status'] == AqlQuery.Status.FINISHED


@pytest.mark.django_db
def test_query_rest_retrieve_mutating(
workspace: Workspace, user: User, authenticated_api_client: APIClient, mutating_query
):
workspace.set_user_permission(user, WorkspaceRoleChoice.READER)

query_info = mutating_query['response'].json()
query_id = query_info['id']
r: Response = authenticated_api_client.get(
f'/api/workspaces/{workspace.name}/queries/{query_id}/'
)
assert r.status_code == 200
r_json = r.json()
assert len(r_json['error_messages']) > 0
assert r_json['status'] == AqlQuery.Status.FAILED


@pytest.mark.django_db
@pytest.mark.parametrize(
'permission,is_owner,status_code,success',
[
(None, False, 404, False),
(WorkspaceRoleChoice.READER, False, 200, True),
(WorkspaceRoleChoice.WRITER, False, 200, True),
(WorkspaceRoleChoice.MAINTAINER, False, 200, True),
(None, True, 200, True),
],
)
def test_query_rest_retrieve_results(
workspace: Workspace,
user: User,
authenticated_api_client: APIClient,
valid_query,
permission: WorkspaceRoleChoice,
is_owner: bool,
status_code: int,
success: bool,
):
if permission is not None:
workspace.set_user_permission(user, permission)
elif is_owner:
workspace.set_owner(user)

query_info = valid_query['response'].json()
query_id = query_info['id']
r: Response = authenticated_api_client.get(
f'/api/workspaces/{workspace.name}/queries/{query_id}/results/'
)
assert r.status_code == status_code
if success:
r_json = r.json()
assert r_json['id'] == query_id
assert r_json['workspace'] == str(workspace)
assert r_json['user'] == str(user)

results = r_json['results']
expected_results = valid_query['nodes']
assert len(results) == len(expected_results)
for row in results:
assert row in expected_results


@pytest.mark.django_db
def test_query_rest_retrieve_results_mutating(
workspace: Workspace, user: User, authenticated_api_client: APIClient, mutating_query
):
workspace.set_user_permission(user, WorkspaceRoleChoice.READER)
query_info = mutating_query['response'].json()
query_id = query_info['id']
r: Response = authenticated_api_client.get(
f'/api/workspaces/{workspace.name}/queries/{query_id}/results/'
)
assert r.status_code == 400
assert r.data == 'The given query could not be executed, and has no results'
2 changes: 2 additions & 0 deletions multinet/api/views/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .network import NetworkViewSet
from .query import AqlQueryViewSet
from .table import TableViewSet
from .upload import UploadViewSet
from .users import users_me_view, users_search_view
Expand All @@ -11,4 +12,5 @@
'TableViewSet',
'UploadViewSet',
'WorkspaceViewSet',
'AqlQueryViewSet',
]
58 changes: 58 additions & 0 deletions multinet/api/views/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from django.shortcuts import get_object_or_404
from drf_yasg.utils import swagger_auto_schema
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticatedOrReadOnly
from rest_framework.response import Response
from rest_framework.viewsets import ReadOnlyModelViewSet

from multinet.api.auth.decorators import require_workspace_permission
from multinet.api.models import AqlQuery, Workspace, WorkspaceRoleChoice
from multinet.api.tasks.aql import execute_query

from .common import WorkspaceChildMixin
from .serializers import AqlQueryResultsSerializer, AqlQuerySerializer, AqlQueryTaskSerializer


class AqlQueryViewSet(WorkspaceChildMixin, ReadOnlyModelViewSet):
queryset = AqlQuery.objects.all().select_related('workspace')
permission_classes = [IsAuthenticatedOrReadOnly]
serializer_class = AqlQueryTaskSerializer
swagger_tags = ['queries']

@swagger_auto_schema(
request_body=AqlQuerySerializer(), responses={200: AqlQueryTaskSerializer()}
)
@require_workspace_permission(WorkspaceRoleChoice.READER)
def create(self, request, parent_lookup_workspace__name: str):
"""Create an AQL query task."""
workspace: Workspace = get_object_or_404(Workspace, name=parent_lookup_workspace__name)
serializer = AqlQuerySerializer(data=request.data)
serializer.is_valid(raise_exception=True)
query_str = serializer.validated_data['query']

query: AqlQuery = AqlQuery.objects.create(
workspace=workspace, user=request.user, query=query_str
)

execute_query.delay(task_id=query.pk)

return Response(AqlQueryTaskSerializer(query).data, status=status.HTTP_200_OK)

@swagger_auto_schema(responses={200: AqlQueryResultsSerializer()})
@action(detail=True, url_path='results')
@require_workspace_permission(WorkspaceRoleChoice.READER)
def results(self, request, parent_lookup_workspace__name: str, pk):
workspace: Workspace = get_object_or_404(Workspace, name=parent_lookup_workspace__name)
aql_task: AqlQuery = get_object_or_404(AqlQuery, workspace=workspace, pk=pk)
if aql_task.status == AqlQuery.Status.FINISHED:
return Response(AqlQueryResultsSerializer(aql_task).data, status=status.HTTP_200_OK)
elif aql_task.status in [AqlQuery.Status.STARTED, AqlQuery.Status.PENDING]:
return Response(
'The given query has not finished executing', status=status.HTTP_400_BAD_REQUEST
)
elif aql_task.status == AqlQuery.Status.FAILED:
return Response(
'The given query could not be executed, and has no results',
status=status.HTTP_400_BAD_REQUEST,
)
Loading