Skip to content

Add Task model for generic tasks #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions multinet/api/migrations/0009_alter_upload_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generated by Django 3.2.5 on 2021-09-01 16:27

from django.db import migrations


class Migration(migrations.Migration):
    """Reset the model options recorded in migration state for the `upload` model."""

    # This migration is ordered after the previous migration of the `api` app.
    dependencies = [('api', '0008_alter_workspacerole_role')]

    # An empty `options` dict leaves `upload` with no explicit Meta options in the
    # migration state. NOTE(review): presumably needed because `Upload` now inherits
    # from the abstract `Task` model -- confirm against the model change.
    operations = [
        migrations.AlterModelOptions(name='upload', options={}),
    ]
12 changes: 10 additions & 2 deletions multinet/api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from .network import Network
from .table import Table
from .upload import Upload
from .tasks import Task, Upload
from .workspace import Workspace, WorkspaceRole, WorkspaceRoleChoice

__all__ = ['Network', 'Table', 'Upload', 'Workspace', 'WorkspaceRole', 'WorkspaceRoleChoice']
__all__ = [
'Network',
'Table',
'Task',
'Upload',
'Workspace',
'WorkspaceRole',
'WorkspaceRoleChoice',
]
34 changes: 20 additions & 14 deletions multinet/api/models/upload.py → multinet/api/models/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,32 @@
from .workspace import Workspace


class Upload(TimeStampedModel):
"""A generic upload object."""
class Task(TimeStampedModel):
"""A generic task object."""

class DataType(models.TextChoices):
CSV = 'CSV'
D3_JSON = 'D3_JSON'
NESTED_JSON = 'NESTED_JSON'
NEWICK = 'NEWICK'
class Meta:
abstract = True

class UploadStatus(models.TextChoices):
class Status(models.TextChoices):
PENDING = 'PENDING'
STARTED = 'STARTED'
FAILED = 'FAILED'
FINISHED = 'FINISHED'

workspace = models.ForeignKey(Workspace, related_name='%(class)ss', on_delete=models.CASCADE)
user = models.ForeignKey(User, related_name='%(class)ss', null=True, on_delete=models.SET_NULL)
error_messages = ArrayField(models.CharField(max_length=500), null=True, blank=True)
status = models.CharField(max_length=10, choices=Status.choices, default=Status.PENDING)


class Upload(Task):
"""An object to track uploads."""

class DataType(models.TextChoices):
CSV = 'CSV'
D3_JSON = 'D3_JSON'
NESTED_JSON = 'NESTED_JSON'
NEWICK = 'NEWICK'

blob = S3FileField()
workspace = models.ForeignKey(Workspace, related_name='uploads', on_delete=models.CASCADE)
user = models.ForeignKey(User, related_name='uploads', null=True, on_delete=models.SET_NULL)
data_type = models.CharField(max_length=20, choices=DataType.choices)
error_messages = ArrayField(models.CharField(max_length=500), null=True, blank=True)
status = models.CharField(
max_length=10, choices=UploadStatus.choices, default=UploadStatus.PENDING
)
61 changes: 61 additions & 0 deletions multinet/api/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import celery
from celery.utils.log import get_task_logger

from multinet.api.models import Task

logger = get_task_logger(__name__)


class MultinetCeleryTask(celery.Task):
    """
    A base class for multinet celery tasks.

    This class should not be instantiated directly. Instead, task classes should inherit from this
    class and override the `task_model` field with the model to associate the tasks with.

    The associated model row is moved through the `Task.Status` values: STARTED when the task
    is called, then FINISHED from `on_success` or FAILED from `on_failure`.

    NOTE: This task assumes that all arguments are passed using kwargs.
    If an argument is passed positionally, this task will fail.
    """

    # The concrete model whose rows track this task's status. Subclasses must override it;
    # `__init__` refuses to construct the task while it is still None.
    task_model = None

    @classmethod
    def start_task(cls, task_id: int):
        """Mark the `task_model` row with id `task_id` as STARTED and save it."""
        logger.info(f'Begin processing of {cls.task_model.__name__.lower()} {task_id}')
        task: Task = cls.task_model.objects.get(id=task_id)
        task.status = Task.Status.STARTED
        task.save()

    @staticmethod
    def fail_task_with_message(task: Task, message: str):
        """
        Mark `task` as FAILED, record `message` in its `error_messages`, and save it.

        `message` is coerced with `str()` so that callers handing over an exception object
        (as `on_failure` receives) still store text. The text is truncated to the
        `max_length` of the `error_messages` element field, so that an overlong message
        cannot make this save fail and leave the task stuck in a non-final status.
        """
        text = str(message)

        # Read the limit from the model field rather than hard-coding it, so this stays
        # in sync with the `error_messages` definition on the task model.
        max_length = task._meta.get_field('error_messages').base_field.max_length
        if max_length is not None:
            text = text[:max_length]

        task.status = Task.Status.FAILED
        if task.error_messages is None:
            task.error_messages = [text]
        else:
            task.error_messages.append(text)

        task.save()

    @staticmethod
    def complete_task(task: Task):
        """Mark `task` as FINISHED and save it."""
        task.status = Task.Status.FINISHED
        task.save()

    def __init__(self) -> None:
        """Refuse construction unless a subclass has set `task_model`."""
        if self.task_model is None:
            raise NotImplementedError('task_model cannot be None')

        super().__init__()

    def __call__(self, *args, **kwargs):
        """
        Wrap the inherited `__call__` method to set task status.

        Raises `KeyError` if `task_id` was not passed as a keyword argument.
        """
        self.start_task(kwargs['task_id'])
        return self.run(*args, **kwargs)

    def on_failure(self, exc, celery_task_id, args, kwargs, einfo):
        """Record the raised exception on the task row and mark it FAILED."""
        task: Task = self.task_model.objects.get(id=kwargs['task_id'])
        # Pass text, not the exception object: `fail_task_with_message` stores strings.
        self.fail_task_with_message(task, str(exc))

    def on_success(self, retval, celery_task_id, args, kwargs):
        """Mark the task row FINISHED."""
        task: Task = self.task_model.objects.get(id=kwargs['task_id'])
        self.complete_task(task)
4 changes: 0 additions & 4 deletions multinet/api/tasks/process/__init__.py

This file was deleted.

5 changes: 5 additions & 0 deletions multinet/api/tasks/upload/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .common import ProcessUploadTask
from .csv import process_csv
from .d3_json import process_d3_json

__all__ = ['ProcessUploadTask', 'process_csv', 'process_d3_json']
6 changes: 6 additions & 0 deletions multinet/api/tasks/upload/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from multinet.api.models import Upload
from multinet.api.tasks import MultinetCeleryTask


class ProcessUploadTask(MultinetCeleryTask):
    """Base class for upload-processing celery tasks; binds `task_model` to `Upload`."""

    task_model = Upload
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from multinet.api.models import Table, Upload

from .utils import ColumnTypeEnum, ProcessUploadTask, processor_dict
from .common import ProcessUploadTask
from .utils import ColumnTypeEnum, processor_dict

logger = get_task_logger(__name__)

Expand Down Expand Up @@ -36,9 +37,9 @@ def process_row(row: Dict[str, Any], cols: Dict[str, ColumnTypeEnum]) -> Dict:

@shared_task(base=ProcessUploadTask)
def process_csv(
upload_id: int, table_name: str, edge: bool, columns: Dict[str, ColumnTypeEnum]
task_id: int, table_name: str, edge: bool, columns: Dict[str, ColumnTypeEnum]
) -> None:
upload: Upload = Upload.objects.get(id=upload_id)
upload: Upload = Upload.objects.get(id=task_id)

# Download data from S3/MinIO
with upload.blob as blob_file:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from multinet.api.models import Network, Table, Upload

from .utils import ProcessUploadTask
from .common import ProcessUploadTask

logger = get_task_logger(__name__)

Expand All @@ -28,12 +28,12 @@ def d3_link_to_arango_doc(link: Dict, node_table_name: str) -> Dict:

@shared_task(base=ProcessUploadTask)
def process_d3_json(
upload_id: int,
task_id: int,
network_name: str,
node_table_name: str,
edge_table_name: str,
) -> None:
upload: Upload = Upload.objects.get(id=upload_id)
upload: Upload = Upload.objects.get(id=task_id)

# Download data from S3/MinIO
with upload.blob as blob_file:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
import json
from typing import Optional, Union

import celery
from celery.utils.log import get_task_logger
from dateutil import parser as dateutilparser

from multinet.api.models import Upload

logger = get_task_logger(__name__)


Expand All @@ -24,50 +21,6 @@ def values(cls):
return list(map(lambda c: c.value, cls))


class ProcessUploadTask(celery.Task):
    """
    A celery task for upload processing.

    The upload row is marked STARTED when the task is called, then FINISHED from
    `on_success` or FAILED from `on_failure`.

    NOTE: This task assumes that all arguments are passed using kwargs.
    If an argument is passed positionally, this task will fail.
    """

    @staticmethod
    def start_upload(upload_id: int):
        """Mark the upload with id `upload_id` as STARTED and save it."""
        logger.info(f'Begin processing of upload {upload_id}')
        upload: Upload = Upload.objects.get(id=upload_id)
        upload.status = Upload.UploadStatus.STARTED
        upload.save()

    @staticmethod
    def fail_upload_with_message(upload: Upload, message: str):
        """
        Mark `upload` as FAILED, record `message` in its `error_messages`, and save it.

        `message` is coerced with `str()` so that an exception object handed over by
        `on_failure` is stored as text rather than as the exception instance.
        """
        text = str(message)

        upload.status = Upload.UploadStatus.FAILED
        if upload.error_messages is None:
            upload.error_messages = [text]
        else:
            upload.error_messages.append(text)

        upload.save()

    @staticmethod
    def complete_upload(upload: Upload):
        """Mark `upload` as FINISHED and save it."""
        upload.status = Upload.UploadStatus.FINISHED
        upload.save()

    def __call__(self, *args, **kwargs):
        """
        Wrap the inherited `__call__` method to set upload status.

        Raises `KeyError` if `upload_id` was not passed as a keyword argument.
        """
        self.start_upload(kwargs['upload_id'])
        return self.run(*args, **kwargs)

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """Record the raised exception on the upload row and mark it FAILED."""
        upload: Upload = Upload.objects.get(id=kwargs['upload_id'])
        # Pass text, not the exception object: the helper stores strings.
        self.fail_upload_with_message(upload, str(exc))

    def on_success(self, retval, task_id, args, kwargs):
        """Mark the upload row FINISHED."""
        upload: Upload = Upload.objects.get(id=kwargs['upload_id'])
        self.complete_upload(upload)


def str_to_bool(entry: str) -> bool:
"""Try to determine base format of boolean so it can be converted properly."""

Expand Down
2 changes: 1 addition & 1 deletion multinet/api/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import factory.fuzzy

from multinet.api.models import Network, Table, Workspace
from multinet.api.models.upload import Upload
from multinet.api.models.tasks import Upload


class UserFactory(factory.django.DjangoModelFactory):
Expand Down
2 changes: 1 addition & 1 deletion multinet/api/tests/test_upload_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from rest_framework.response import Response
from rest_framework.test import APIClient

from multinet.api.models.upload import Upload
from multinet.api.models.tasks import Upload
from multinet.api.models.workspace import Workspace, WorkspaceRoleChoice
from multinet.api.tests.factories import UploadFactory
from multinet.api.tests.fuzzy import TIMESTAMP_RE, workspace_re
Expand Down
8 changes: 4 additions & 4 deletions multinet/api/tests/test_upload_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import pytest
from rest_framework.response import Response

from multinet.api.models.upload import Upload
from multinet.api.models.tasks import Upload
from multinet.api.models.workspace import Workspace, WorkspaceRole, WorkspaceRoleChoice
from multinet.api.tasks.process.utils import str_to_number
from multinet.api.tasks.upload.utils import str_to_number
from multinet.api.tests.fuzzy import (
INTEGER_ID_RE,
TIMESTAMP_RE,
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_create_upload_model_csv(workspace: Workspace, user: User, airports_csv)
'user': user.username,
'data_type': Upload.DataType.CSV,
'error_messages': None,
'status': Upload.UploadStatus.PENDING,
'status': Upload.Status.PENDING,
'created': TIMESTAMP_RE,
'modified': TIMESTAMP_RE,
}
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_upload_valid_csv_task_response(

r_json = r.json()
assert r.status_code == 200
assert r_json['status'] == Upload.UploadStatus.FINISHED
assert r_json['status'] == Upload.Status.FINISHED
assert r_json['error_messages'] is None

# Check that table is created
Expand Down
6 changes: 3 additions & 3 deletions multinet/api/tests/test_upload_d3_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
WorkspaceRole,
WorkspaceRoleChoice,
)
from multinet.api.tasks.process.d3_json import d3_link_to_arango_doc, d3_node_to_arango_doc
from multinet.api.tasks.upload.d3_json import d3_link_to_arango_doc, d3_node_to_arango_doc
from multinet.api.tests.fuzzy import (
INTEGER_ID_RE,
TIMESTAMP_RE,
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_create_upload_model(workspace: Workspace, user: User, miserables_json):
'user': user.username,
'data_type': Upload.DataType.D3_JSON,
'error_messages': None,
'status': Upload.UploadStatus.PENDING,
'status': Upload.Status.PENDING,
'created': TIMESTAMP_RE,
'modified': TIMESTAMP_RE,
}
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_valid_d3_json_task_response(

r_json = r.json()
assert r.status_code == 200
assert r_json['status'] == Upload.UploadStatus.FINISHED
assert r_json['status'] == Upload.Status.FINISHED
assert r_json['error_messages'] is None

# Check that tables are created
Expand Down
2 changes: 1 addition & 1 deletion multinet/api/views/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from rest_framework import serializers

from multinet.api.models import Network, Table, Upload, Workspace
from multinet.api.tasks.process.utils import ColumnTypeEnum
from multinet.api.tasks.upload.utils import ColumnTypeEnum


# The default ModelSerializer for User fails if the user already exists
Expand Down
6 changes: 3 additions & 3 deletions multinet/api/views/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from multinet.api.auth.decorators import require_workspace_permission
from multinet.api.models import Network, Table, Upload, Workspace, WorkspaceRoleChoice
from multinet.api.tasks.process import process_csv, process_d3_json
from multinet.api.tasks.upload import process_csv, process_d3_json

from .common import MultinetPagination, WorkspaceChildMixin
from .serializers import (
Expand Down Expand Up @@ -86,7 +86,7 @@ def upload_csv(self, request, parent_lookup_workspace__name: str):

# Dispatch task
process_csv.delay(
upload_id=upload.pk,
task_id=upload.pk,
table_name=table_name,
edge=serializer.validated_data['edge'],
columns=serializer.validated_data['columns'],
Expand Down Expand Up @@ -142,7 +142,7 @@ def upload_d3_json(self, request, parent_lookup_workspace__name: str):

# Dispatch task
process_d3_json.delay(
upload_id=upload.pk,
task_id=upload.pk,
network_name=network_name,
node_table_name=node_table_name,
edge_table_name=edge_table_name,
Expand Down