Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions src/novelai_python/sdk/ai/generate_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,6 @@ def set_mutual_exclusion(self, value: bool):
action: Union[str, Action] = Field(Action.GENERATE, description="Mode for img generate")
parameters: Union[Params]
model_config = ConfigDict(extra="ignore")

# forced params integration
def model_dump(self, *args, **kwargs):
"""
Overrides model_dump for own features
"""
data = super().model_dump(*args, **kwargs)
data["parameters"] = self.parameters.model_dump(*args, **kwargs)

return data

@override
def model_post_init(self, *args) -> None:
Expand Down
24 changes: 18 additions & 6 deletions src/novelai_python/sdk/ai/generate_image/params.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import base64
import random
from io import BytesIO
from collections import OrderedDict
from typing import Optional, List, Set, Union, Tuple

import cv2
import numpy as np
from PIL import Image
from loguru import logger
from pydantic import BaseModel, Field, model_validator, field_validator, model_serializer
from pydantic import BaseModel, Field, model_serializer, model_validator, field_validator

from novelai_python.sdk.ai._enum import Sampler, UCPresetTypeAlias, NoiseSchedule, ImageBytesTypeAlias, ControlNetModel, \
Model
Expand Down Expand Up @@ -128,11 +129,22 @@ def _serialize(self, handler):
Custom serializer to force include specific fields even when they are None
"""
data = handler(self)
# Just add None values for strong fields
for field in self.__strong_values__:
if field not in data:
data[field] = getattr(self, field, None)
return data

# Force include strong fields and ensure proper ordering
# Create ordered dict following JavaScript field order
ordered_fields = list(self.__class__.model_fields.keys())
ordered_data = OrderedDict()

for field_name in ordered_fields:
if field_name in data:
# Field exists in data, include it
ordered_data[field_name] = data[field_name]
elif field_name in self.__strong_values__:
# Since we iterate from `self.__class__.model_fields`,、
# the field is guaranteed to exist. The previous safety check was redundant.
ordered_data[field_name] = getattr(self, field_name, None)

return ordered_data
# endregion

@model_validator(mode="after")
Expand Down