diff --git a/netbox_custom_objects/api/serializers.py b/netbox_custom_objects/api/serializers.py index 0df23e4..a18f90e 100644 --- a/netbox_custom_objects/api/serializers.py +++ b/netbox_custom_objects/api/serializers.py @@ -8,6 +8,7 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError from rest_framework.reverse import reverse +from rest_framework.utils import model_meta from netbox_custom_objects import field_types from netbox_custom_objects.models import (CustomObject, CustomObjectType, @@ -253,6 +254,44 @@ def get_display(self, obj): """Get display representation of the object""" return str(obj) + # Stock DRF create() without raise_errors_on_nested_writes guard + def create(self, validated_data): + ModelClass = self.Meta.model + + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) + + instance = ModelClass._default_manager.create(**validated_data) + + if many_to_many: + for field_name, value in many_to_many.items(): + field = getattr(instance, field_name) + field.set(value) + + return instance + + # Stock DRF update() with custom field.set() for M2M + def update(self, instance, validated_data): + info = model_meta.get_field_info(instance) + + m2m_fields = [] + for attr, value in validated_data.items(): + if attr in info.relations and info.relations[attr].to_many: + m2m_fields.append((attr, value)) + else: + setattr(instance, attr, value) + + instance.save() + + for attr, value in m2m_fields: + field = getattr(instance, attr) + field.set(value, clear=True) + + return instance + # Create basic attributes for the serializer attrs = { "Meta": meta, @@ -261,6 +300,8 @@ def get_display(self, obj): "get_url": get_url, "display": serializers.SerializerMethodField(), "get_display": get_display, + "create": create, + "update": update, } for field in model_fields: