Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 10 additions & 6 deletions playground/generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# @File : __init__.py.py

import asyncio
import base64
import os
import pathlib
import random
Expand Down Expand Up @@ -55,7 +56,7 @@ async def generate(
scene = prompt_generator.generate_scene_composition()
if prompt is None:
prompt = scene.pop(0) + ','.join([
'muelsyse (arknights) '
'muelsyse (arknights)'
])
character = [
Character(
Expand All @@ -74,11 +75,11 @@ async def generate(
# Generate
try:
agent = GenerateImageInfer.build_generate(
prompt=prompt,
width=832,
height=1216,
prompt=os.getenv("TEST_TAG", prompt),
width=1024,
height=1024,
model=model,
character_prompts=character,
character_prompts=None if os.getenv("TEST_TAG") else character,
sampler=Sampler.K_EULER_ANCESTRAL,
ucPreset=UCPreset.TYPE0,
# Recommended, using preset negative_prompt depends on selected model
Expand Down Expand Up @@ -110,7 +111,7 @@ async def direct_use():
:return:
"""
credential = ApiCredential(api_token=SecretStr("pst-5555"))
result = await GenerateImageInfer(
result: ImageGenerateResp = await GenerateImageInfer(
input="1girl",
model=Model.NAI_DIFFUSION_4_5_FULL,
parameters=Params(
Expand All @@ -127,6 +128,9 @@ async def direct_use():
)
).request(session=credential)
print(f"Meta: {result.meta}")
file = result.files[0]
with open(f"{pathlib.Path(__file__).stem}.png", "wb") as f:
f.write(file[1])


load_dotenv()
Expand Down
4 changes: 2 additions & 2 deletions playground/generate_image_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from novelai_python import APIError, LoginCredential
from novelai_python import GenerateImageInfer, ImageGenerateResp, ApiCredential
from novelai_python.sdk.ai._enum import Model
from novelai_python.sdk.ai.generate_image import Action, Sampler
from novelai_python.utils.useful import enum_to_list

Expand Down Expand Up @@ -42,11 +43,10 @@ async def generate(
image = base64.b64encode(f.read()).decode()
# image = f.read() # Or you can use the raw bytes
agent = GenerateImageInfer.build_img2img(
model=Model.NAI_DIFFUSION_4_5_FULL,
prompt=prompt,
sampler=Sampler.K_DPMPP_SDE,
image=image,
seed=123456789,
extra_noise_seed=123123123,
)
print(f"charge: {agent.calculate_cost(is_opus=True)} if you are vip3")
print(f"charge: {agent.calculate_cost(is_opus=False)} if you are not vip3")
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.7.10"
version = "0.7.11"
description = "NovelAI Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "coldlando@hotmail.com" },
Expand All @@ -13,7 +13,7 @@ dependencies = [
"httpx>=0.26.0",
"shortuuid>=1.0.11",
"Pillow>=10.2.0",
"curl-cffi>=0.9.0",
"curl-cffi>=0.11.3",
"fastapi>=0.109.0",
"uvicorn[standard]>=0.27.0.post1",
"numpy>=1.24.4",
Expand Down
11 changes: 3 additions & 8 deletions src/novelai_python/credential/ApiToken.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from loguru import logger
from pydantic import SecretStr, Field, field_validator

from ._base import CredentialBase, FAKE_UA
from ._base import CredentialBase


class ApiCredential(CredentialBase):
Expand All @@ -22,13 +22,8 @@ class ApiCredential(CredentialBase):

async def get_session(self, timeout: int = 180, update_headers: dict = None):
headers = {
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
"User-Agent": FAKE_UA.edge,
"Authorization": f"Bearer {self.api_token.get_secret_value()}",
"Content-Type": "application/json",
"Origin": "https://novelai.net",
"Referer": "https://novelai.net/",
"x-correlation-id": self.x_correlation_id,
"x-initiated-at": f"{arrow.utcnow().isoformat()}Z",
}
Expand All @@ -37,10 +32,10 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
assert isinstance(update_headers, dict), "update_headers must be a dict"
headers.update(update_headers)

return AsyncSession(timeout=timeout, headers=headers, impersonate="edge101")
return AsyncSession(timeout=timeout, headers=headers, impersonate="chrome136")

@field_validator('api_token')
def check_api_token(cls, v: SecretStr):
if not v.get_secret_value().startswith("pst"):
logger.warning("api token should start with `pst-`")
return v
return v
11 changes: 3 additions & 8 deletions src/novelai_python/credential/JwtToken.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from loguru import logger
from pydantic import SecretStr, Field, field_validator

from ._base import CredentialBase, FAKE_UA
from ._base import CredentialBase


class JwtCredential(CredentialBase):
Expand All @@ -22,13 +22,8 @@ class JwtCredential(CredentialBase):

async def get_session(self, timeout: int = 180, update_headers: dict = None):
headers = {
"Accept": "*/*",
"User-Agent": FAKE_UA.edge,
"Accept-Encoding": "gzip, deflate, br",
"Authorization": f"Bearer {self.jwt_token.get_secret_value()}",
"Content-Type": "application/json",
"Origin": "https://novelai.net",
"Referer": "https://novelai.net/",
"x-correlation-id": self.x_correlation_id,
"x-initiated-at": f"{arrow.utcnow().isoformat()}Z",
}
Expand All @@ -37,10 +32,10 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
assert isinstance(update_headers, dict), "update_headers must be a dict"
headers.update(update_headers)

return AsyncSession(timeout=timeout, headers=headers, impersonate="edge101")
return AsyncSession(timeout=timeout, headers=headers, impersonate="chrome136")

@field_validator('jwt_token')
def check_jwt_token(cls, v: SecretStr):
if not v.get_secret_value().startswith("ey"):
logger.warning("jwt_token should start with ey")
return v
return v
9 changes: 2 additions & 7 deletions src/novelai_python/credential/UserAuth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from curl_cffi.requests import AsyncSession
from pydantic import SecretStr, Field

from ._base import CredentialBase, FAKE_UA
from ._base import CredentialBase


class LoginCredential(CredentialBase):
Expand All @@ -25,13 +25,8 @@ class LoginCredential(CredentialBase):

async def get_session(self, timeout: int = 180, update_headers: dict = None):
headers = {
"Accept": "*/*",
"User-Agent": FAKE_UA.edge,
"Accept-Encoding": "gzip, deflate, br",
"Authorization": "Bearer ",
"Content-Type": "application/json",
"Origin": "https://novelai.net",
"Referer": "https://novelai.net/",
"x-correlation-id": self.x_correlation_id,
"x-initiated-at": f"{arrow.utcnow().isoformat()}Z",
}
Expand All @@ -49,4 +44,4 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
if update_headers:
headers.update(update_headers)

return AsyncSession(timeout=timeout, headers=headers, impersonate="edge101")
return AsyncSession(timeout=timeout, headers=headers, impersonate="chrome136")
7 changes: 6 additions & 1 deletion src/novelai_python/sdk/ai/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,17 @@ class SupportCondition:
img2imgInpainting: bool


def get_supported_params(model: Model):
def get_supported_params(model: ModelTypeAlias):
"""
Get supported parameters for a given model
:param model: Model
:return: SupportCondition
"""
if isinstance(model, str):
try:
model = Model(model)
except ValueError:
pass
if model in [
Model.STABLE_DIFFUSION,
Model.NAI_DIFFUSION,
Expand Down
23 changes: 0 additions & 23 deletions src/novelai_python/sdk/ai/augment_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,28 +174,6 @@ def build(cls,
prompt=prompt,
)

async def necessary_headers(self, request_data) -> dict:
"""
:param request_data:
:return:
"""
return {
"Host": urlparse(self.endpoint).netloc,
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate, br",
"Referer": "https://novelai.net/",
"Content-Type": "application/json",
"Origin": "https://novelai.net",
"Content-Length": str(len(json.dumps(request_data).encode("utf-8"))),
"Connection": "keep-alive",
"Sec-Fetch-Dest": "empty",
"Sec-Fetch-Mode": "cors",
"Sec-Fetch-Site": "same-site",
"Pragma": "no-cache",
"Cache-Control": "no-cache",
'priority': "u=1, i"
}

@retry(
wait=wait_random(min=1, max=3),
stop=stop_after_attempt(3),
Expand All @@ -216,7 +194,6 @@ async def request(self,
# Prepare request data
request_data = self.model_dump(mode="json", exclude_none=True)
async with session if isinstance(session, AsyncSession) else await session.get_session() as sess:
sess.headers.update(await self.necessary_headers(request_data))
if override_headers:
sess.headers.clear()
sess.headers.update(override_headers)
Expand Down
17 changes: 0 additions & 17 deletions src/novelai_python/sdk/ai/generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,6 @@ def endpoint(self, value):
def base_url(self):
return f"{self.endpoint.strip('/')}/ai/generate"

async def necessary_headers(self, request_data) -> dict:
"""
:param request_data: dict
:return: dict
"""
return {
"Host": urlparse(self.endpoint).netloc,
"accept": "*/*",
"accept-language": "zh-CN,zh;q=0.9",
"cache-control": "no-cache",
"content-type": "application/json",
"pragma": "no-cache",
"Referer": "https://novelai.net/",
"Referrer-Policy": "strict-origin-when-cross-origin"
}

@model_validator(mode="after")
def normalize_model(self):
if self.model in [
Expand Down Expand Up @@ -361,7 +345,6 @@ async def request(self,
}
async with session if isinstance(session, AsyncSession) else await session.get_session() as sess:
# Header
sess.headers.update(await self.necessary_headers(request_data))
if override_headers:
sess.headers.clear()
sess.headers.update(override_headers)
Expand Down
Loading