diff --git a/drf_excel/renderers.py b/drf_excel/renderers.py index 56727e9..ba29313 100644 --- a/drf_excel/renderers.py +++ b/drf_excel/renderers.py @@ -260,7 +260,7 @@ def _flatten_serializer_keys( use_labels=False, ): """ - Iterate through serializer fields recursively when field is a nested serializer. + Iterate through serializer fields recursively when field is a nested serializer. Skip write_only fields. """ def _get_label(parent_label, label_sep, obj): @@ -278,7 +278,7 @@ def _get_label(parent_label, label_sep, obj): for k, v in _fields.items(): new_key = f"{parent_key}{key_sep}{k}" if parent_key else k # Skip headers we want to ignore - if new_key in self.ignore_headers: + if new_key in self.ignore_headers or getattr(v, "write_only", False): continue # Iterate through fields if field is a serializer. Check for labels and # append if `use_labels` is True. Fallback to using keys. diff --git a/tests/test_viewset_mixin.py b/tests/test_viewset_mixin.py index eb0395c..7799258 100644 --- a/tests/test_viewset_mixin.py +++ b/tests/test_viewset_mixin.py @@ -4,7 +4,7 @@ from rest_framework.test import APIClient from time_machine import TimeMachineFixture -from tests.testapp.models import AllFieldsModel, ExampleModel, Tag +from tests.testapp.models import AllFieldsModel, ExampleModel, SecretFieldModel, Tag pytestmark = pytest.mark.django_db @@ -93,3 +93,20 @@ def test_all_fields_viewset( True, "test, example", ] + + +def test_secret_field_viewset(api_client, workbook_reader): + SecretFieldModel.objects.create(title="foo", secret="bar") + + response = api_client.get("/secret-field/") + assert response.status_code == 200 + + wb = workbook_reader(response.content) + sheet = wb.worksheets[0] + rows = list(sheet.rows) + assert len(rows) == 2 + header, data = rows + + # Check that the secret field is not included in the header or data + assert [col.value for col in header] == ["title"] + assert [col.value for col in data] == ["foo"] diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 22fb3d2..493a04f 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -30,3 +30,11 @@ def __str__(self): def get_tag_names(self): return [tag.name for tag in self.tags.all()] + + +class SecretFieldModel(models.Model): + title = models.CharField(max_length=100) + secret = models.CharField(max_length=100) + + def __str__(self): + return self.title diff --git a/tests/testapp/serializers.py b/tests/testapp/serializers.py index e788d2f..e433822 100644 --- a/tests/testapp/serializers.py +++ b/tests/testapp/serializers.py @@ -1,6 +1,6 @@ from rest_framework import serializers -from .models import AllFieldsModel, ExampleModel +from .models import AllFieldsModel, ExampleModel, SecretFieldModel class ExampleSerializer(serializers.ModelSerializer): @@ -23,3 +23,13 @@ class Meta: "is_active", "tags", ) + + +class SecretFieldSerializer(serializers.ModelSerializer): + secret_external = serializers.CharField(write_only=True) + + class Meta: + model = SecretFieldModel + fields = ("title", "secret", "secret_external") + + extra_kwargs = {"secret": {"write_only": True}} diff --git a/tests/testapp/views.py b/tests/testapp/views.py index f435cd4..71b7038 100644 --- a/tests/testapp/views.py +++ b/tests/testapp/views.py @@ -3,8 +3,8 @@ from drf_excel.mixins import XLSXFileMixin from drf_excel.renderers import XLSXRenderer -from .models import AllFieldsModel, ExampleModel -from .serializers import AllFieldsSerializer, ExampleSerializer +from .models import AllFieldsModel, ExampleModel, SecretFieldModel +from .serializers import AllFieldsSerializer, ExampleSerializer, SecretFieldSerializer class ExampleViewSet(XLSXFileMixin, ReadOnlyModelViewSet): @@ -19,3 +19,10 @@ class AllFieldsViewSet(XLSXFileMixin, ReadOnlyModelViewSet): serializer_class = AllFieldsSerializer renderer_classes = (XLSXRenderer,) filename = "al_fileds.xlsx" + + +class SecretFieldViewSet(XLSXFileMixin, ReadOnlyModelViewSet): + queryset = SecretFieldModel.objects.all() + serializer_class = SecretFieldSerializer + renderer_classes = (XLSXRenderer,) + filename = "secret.xlsx" diff --git a/tests/urls.py b/tests/urls.py index 30a06d5..d875c66 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,9 +1,10 @@ from rest_framework import routers -from .testapp.views import AllFieldsViewSet, ExampleViewSet +from .testapp.views import AllFieldsViewSet, ExampleViewSet, SecretFieldViewSet router = routers.SimpleRouter() router.register(r"examples", ExampleViewSet) router.register(r"all-fields", AllFieldsViewSet) +router.register(r"secret-field", SecretFieldViewSet) urlpatterns = router.urls