Skip to content
This repository was archived by the owner on Sep 28, 2022. It is now read-only.

Expand serialization_specs recursively #44

Merged
merged 4 commits into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions serialization_spec/serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Prefetch
from rest_framework.utils import model_meta
from rest_framework.fields import Field
Expand All @@ -6,6 +7,7 @@

from typing import List, Dict, Union
from collections import OrderedDict
import copy

"""
Parse a serialization spec such as:
Expand Down Expand Up @@ -201,18 +203,36 @@ def prefetch_related(request_user, queryset, model, prefixes, serialization_spec
return queryset


def get_serialization_spec(view_or_plugin, request_user=None):
if hasattr(view_or_plugin, 'get_serialization_spec'):
view_or_plugin.request_user = request_user
return view_or_plugin.get_serialization_spec()
return getattr(view_or_plugin, 'serialization_spec', None)


def expand_nested_specs(serialization_spec, request_user):
def get_serialization_spec(serialization_spec_plugin):
if hasattr(serialization_spec_plugin, 'get_serialization_spec'):
serialization_spec_plugin.request_user = request_user
return serialization_spec_plugin.get_serialization_spec()
return getattr(serialization_spec_plugin, 'serialization_spec', [])
expanded_serialization_spec = []

for each in serialization_spec:
if not isinstance(each, dict):
expanded_serialization_spec.append(each)
else:
expanded_dict = {}
for key, childspec in each.items():
if isinstance(childspec, SerializationSpecPlugin):
serialization_spec = get_serialization_spec(childspec, request_user)
if serialization_spec is not None:
plugin_copy = copy.deepcopy(childspec)
plugin_copy.serialization_spec = expand_nested_specs(plugin_copy.serialization_spec, request_user)
expanded_serialization_spec += plugin_copy.serialization_spec
expanded_dict[key] = plugin_copy
else:
expanded_dict[key] = childspec
else:
expanded_dict[key] = expand_nested_specs(childspec, request_user)
expanded_serialization_spec.append(expanded_dict)

return serialization_spec + sum([
get_serialization_spec(childspec)
for each in serialization_spec if isinstance(each, dict)
for key, childspec in each.items() if isinstance(childspec, SerializationSpecPlugin)
], [])
return expanded_serialization_spec


class NormalisedSpec:
Expand Down Expand Up @@ -278,8 +298,9 @@ def get_object(self):

def get_queryset(self):
queryset = self.queryset
if hasattr(self, 'get_serialization_spec'):
self.serialization_spec = self.get_serialization_spec()
self.serialization_spec = get_serialization_spec(self)
if self.serialization_spec is None:
raise ImproperlyConfigured('SerializationSpecMixin requires serialization_spec or get_serialization_spec')
expand_many2many_id_fields(queryset.model, self.serialization_spec)
serialization_spec = expand_nested_specs(self.serialization_spec, self.request.user)
serialization_spec = normalise_spec(serialization_spec)
Expand Down
Loading