Add Custom Model Support and Configurable API URL #32

Open · wants to merge 1 commit into base: master
src/noteGenerator.ts (37 additions, 29 deletions)

@@ -105,24 +105,30 @@ export function noteGenerator(
   }

   const buildMessages = async (node: CanvasNode) => {
-    const encoding = getEncoding(settings)
-
     const messages: openai.ChatCompletionRequestMessage[] = []
+    const isCustomModel = settings.apiModel === CHAT_MODELS.CUSTOMIZE.name
+    const encoding = isCustomModel ? null : getEncoding(settings)
     let tokenCount = 0

     // Note: We are not checking for system prompt longer than context window.
     // That scenario makes no sense, though.
     const systemPrompt = await getSystemPrompt(node)
     if (systemPrompt) {
-      tokenCount += encoding.encode(systemPrompt).length
+      if (!isCustomModel && encoding) {
+        tokenCount += encoding.encode(systemPrompt).length
+      }
+      messages.push({
+        role: 'system',
+        content: systemPrompt
+      })
     }

     const visit = async (node: CanvasNode, depth: number) => {
       if (settings.maxDepth && depth > settings.maxDepth) return false

       const nodeData = node.getData()
       let nodeText = (await readNodeContent(node))?.trim() || ''
-      const inputLimit = getTokenLimit(settings)
+      const inputLimit = isCustomModel ? Infinity : getTokenLimit(settings)

       let shouldContinue = true
       if (!nodeText) {
@@ -140,34 +146,36 @@ export function noteGenerator(
       } else {
         if (isSystemPromptNode(nodeText)) return true

-        let nodeTokens = encoding.encode(nodeText)
+        let nodeTokens = isCustomModel || !encoding ? null : encoding.encode(nodeText)
         let keptNodeTokens: number

-        if (tokenCount + nodeTokens.length > inputLimit) {
+        if (!isCustomModel && nodeTokens && tokenCount + nodeTokens.length > inputLimit) {
           // will exceed input limit

           shouldContinue = false

           // Leaving one token margin, just in case
-          const keepTokens = nodeTokens.slice(0, inputLimit - tokenCount - 1)
-          const truncateTextTo = encoding.decode(keepTokens).length
-          logDebug(
-            `Truncating node text from ${nodeText.length} to ${truncateTextTo} characters`
+          const keepTokens = Math.max(
+            0,
+            inputLimit - tokenCount - 1
           )
-          nodeText = nodeText.slice(0, truncateTextTo)
-          keptNodeTokens = keepTokens.length
+          if (encoding) {
+            const keepBytes = encoding
+              .decode(nodeTokens.slice(0, keepTokens))
+              .length
+            nodeText = nodeText.slice(0, keepBytes)
+          }
+          keptNodeTokens = keepTokens
         } else {
-          keptNodeTokens = nodeTokens.length
+          keptNodeTokens = nodeTokens?.length || 0
         }

-        tokenCount += keptNodeTokens
-
-        const role: openai.ChatCompletionRequestMessageRoleEnum =
-          nodeData.chat_role === 'assistant' ? 'assistant' : 'user'
+        if (!isCustomModel) {
+          tokenCount += keptNodeTokens
+        }

         messages.unshift({
           content: nodeText,
-          role
+          role: 'user'
         })
       }

@@ -177,13 +185,6 @@
     await visitNodeAndAncestors(node, visit)

     if (messages.length) {
-      if (systemPrompt) {
-        messages.unshift({
-          content: systemPrompt,
-          role: 'system'
-        })
-      }
-
       return { messages, tokenCount }
     } else {
       return { messages: [], tokenCount: 0 }
@@ -242,9 +243,10 @@ export function noteGenerator(
       settings.apiModel,
       messages,
       {
-        max_tokens: settings.maxResponseTokens || undefined,
-        temperature: settings.temperature
-      }
+        temperature: settings.temperature,
+        max_tokens: settings.maxResponseTokens || undefined
+      },
+      settings.customModelName
     )

     if (generated == null) {
@@ -287,13 +289,19 @@
 }

 function getEncoding(settings: ChatStreamSettings) {
+  if (settings.apiModel === CHAT_MODELS.CUSTOMIZE.name) {
+    return null
+  }
   const model: ChatModelSettings | undefined = chatModelByName(settings.apiModel)
   return encodingForModel(
     (model?.encodingFrom || model?.name || DEFAULT_SETTINGS.apiModel) as TiktokenModel
   )
 }

 function getTokenLimit(settings: ChatStreamSettings) {
+  if (settings.apiModel === CHAT_MODELS.CUSTOMIZE.name) {
+    return Infinity
+  }
   const model = chatModelByName(settings.apiModel) || CHAT_MODELS.GPT_35_TURBO_0125
   return settings.maxInputTokens
     ? Math.min(settings.maxInputTokens, model.tokenLimit)
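Taken together, the noteGenerator.ts changes follow one pattern: every tokenizer-dependent step degrades to a no-op when the "customize" model is selected, because no tiktoken encoding exists for an arbitrary model. Below is a minimal sketch of that pattern; the helper names countTokens and truncateToBudget are illustrative, not part of the diff.

// Sketch only: token accounting with an encoding that may be null.
type MaybeEncoding = {
  encode(text: string): number[]
  decode(tokens: number[]): string
} | null

function countTokens(encoding: MaybeEncoding, text: string): number {
  // Custom models have no known tokenizer, so counting is skipped (0).
  return encoding ? encoding.encode(text).length : 0
}

function truncateToBudget(encoding: MaybeEncoding, text: string, budget: number): string {
  if (!encoding) return text // unknown tokenizer: pass the text through untouched
  const tokens = encoding.encode(text)
  if (tokens.length <= budget) return text
  // Keep one token of margin, mirroring the diff's `inputLimit - tokenCount - 1`.
  return encoding.decode(tokens.slice(0, Math.max(0, budget - 1)))
}

The trade-off is that for a custom model the plugin sends the full thread and relies on the server to reject oversized requests, since inputLimit becomes Infinity and no truncation runs.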
src/openai/chatGPT.ts (15 additions, 3 deletions)

@@ -1,7 +1,11 @@
 import { request, RequestUrlParam } from 'obsidian'
 import { openai } from './chatGPT-types'

-export const OPENAI_COMPLETIONS_URL = `https://api.openai.com/v1/chat/completions`
+export const DEFAULT_OPENAI_COMPLETIONS_URL = `https://api.openai.com/v1/chat/completions`
+
+export const getOpenAICompletionsURL = (configuredUrl?: string): string => {
+  return configuredUrl || DEFAULT_OPENAI_COMPLETIONS_URL
+}

 export type ChatModelSettings = {
   name: string,
@@ -62,6 +66,10 @@ export const CHAT_MODELS = {
   GPT_4_32K_0613: {
     name: 'gpt-4-32k-0613',
     tokenLimit: 32768
+  },
+  CUSTOMIZE: {
+    name: 'customize',
+    tokenLimit: 0
   }
 }
@@ -89,15 +97,19 @@ export async function getChatGPTCompletion(
   messages: openai.CreateChatCompletionRequest['messages'],
   settings?: Partial<
     Omit<openai.CreateChatCompletionRequest, 'messages' | 'model'>
-  >
+  >,
+  customModelName?: string
 ): Promise<string | undefined> {
   const headers = {
     Authorization: `Bearer ${apiKey}`,
     'Content-Type': 'application/json'
   }
+  const actualModel = model === CHAT_MODELS.CUSTOMIZE.name && customModelName
+    ? customModelName
+    : (model ?? CHAT_MODELS.GPT_35_TURBO.name)
   const body: openai.CreateChatCompletionRequest = {
     messages,
-    model,
+    model: actualModel,
     ...settings
   }
   const requestParam: RequestUrlParam = {
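Two small resolution steps now sit on the request path: the completions URL falls back to the OpenAI default when none is configured, and the model name is swapped for customModelName only when the "customize" sentinel is selected. A sketch of how a caller might exercise both; the endpoint and model ids are assumptions for the example.

// Illustrative values; any OpenAI-compatible endpoint and model id would work.
const url = getOpenAICompletionsURL('http://localhost:11434/v1/chat/completions')
// -> 'http://localhost:11434/v1/chat/completions'
const fallbackUrl = getOpenAICompletionsURL(undefined)
// -> 'https://api.openai.com/v1/chat/completions'

// Mirrors the actualModel resolution inside getChatGPTCompletion:
const resolveModel = (model: string | undefined, customModelName?: string): string =>
  model === CHAT_MODELS.CUSTOMIZE.name && customModelName
    ? customModelName
    : (model ?? CHAT_MODELS.GPT_35_TURBO.name)

resolveModel('customize', 'llama3:8b') // -> 'llama3:8b' (illustrative id)
resolveModel('gpt-4-0613', 'llama3:8b') // -> 'gpt-4-0613' (custom name ignored)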
src/settings/ChatStreamSettings.ts (8 additions, 2 deletions)

@@ -1,4 +1,4 @@
-import { CHAT_MODELS, OPENAI_COMPLETIONS_URL } from 'src/openai/chatGPT'
+import { CHAT_MODELS, DEFAULT_OPENAI_COMPLETIONS_URL } from 'src/openai/chatGPT'

 export interface ChatStreamSettings {
   /**
@@ -16,6 +16,11 @@ export interface ChatStreamSettings {
    */
   apiModel: string

+  /**
+   * Custom model name when using customize option
+   */
+  customModelName: string
+
   /**
    * The temperature to use when generating responses (0-2). 0 means no randomness.
    */
@@ -57,8 +62,9 @@ Use step-by-step reasoning. Be brief.

 export const DEFAULT_SETTINGS: ChatStreamSettings = {
   apiKey: '',
-  apiUrl: OPENAI_COMPLETIONS_URL,
+  apiUrl: DEFAULT_OPENAI_COMPLETIONS_URL,
   apiModel: CHAT_MODELS.GPT_35_TURBO.name,
+  customModelName: '',
   temperature: 1,
   systemPrompt: DEFAULT_SYSTEM_PROMPT,
   debug: false,
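As a usage illustration, a settings object pointing the plugin at a self-hosted, OpenAI-compatible server might look like the following; the endpoint and model id are assumptions for the example, not values from the diff.

// Sketch: configuring the plugin for a custom model on a local endpoint.
const exampleSettings: ChatStreamSettings = {
  ...DEFAULT_SETTINGS,
  apiUrl: 'http://localhost:1234/v1/chat/completions', // illustrative local server
  apiModel: CHAT_MODELS.CUSTOMIZE.name, // the 'customize' sentinel
  customModelName: 'mistral-7b-instruct' // illustrative id sent as the request's `model`
}

Note that CUSTOMIZE.tokenLimit is 0 but is never consulted: getTokenLimit returns Infinity and getEncoding returns null for the customize model before either touches the model entry, so the field exists only to satisfy the ChatModelSettings shape.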
src/settings/SettingsTab.ts (18 additions, 0 deletions)

@@ -1,6 +1,7 @@
 import { App, PluginSettingTab, Setting } from 'obsidian'
 import { ChatStreamPlugin } from 'src/ChatStreamPlugin'
 import { getModels } from './ChatStreamSettings'
+import { CHAT_MODELS } from 'src/openai/chatGPT'

 export class SettingsTab extends PluginSettingTab {
   plugin: ChatStreamPlugin
@@ -26,9 +27,26 @@ export class SettingsTab extends PluginSettingTab {
         cb.onChange(async (value) => {
           this.plugin.settings.apiModel = value
           await this.plugin.saveSettings()
+          // Trigger refresh to show/hide custom model input
+          this.display()
         })
       })

+    if (this.plugin.settings.apiModel === CHAT_MODELS.CUSTOMIZE.name) {
+      new Setting(containerEl)
+        .setName('Custom Model Name')
+        .setDesc('Enter the name of your custom model')
+        .addText((text) => {
+          text
+            .setPlaceholder('e.g., gpt-4-1106-vision-preview')
+            .setValue(this.plugin.settings.customModelName)
+            .onChange(async (value) => {
+              this.plugin.settings.customModelName = value
+              await this.plugin.saveSettings()
+            })
+        })
+    }
+
     new Setting(containerEl)
       .setName('API key')
       .setDesc('The API key to use when making requests - Get from OpenAI')
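The conditional field depends on re-running display() whenever the model dropdown changes: Obsidian settings tabs are plain DOM builders, so the usual way to show or hide dependent controls is to clear the container and rebuild the whole tab. A minimal sketch of that pattern, simplified from the diff (dropdown options and the text input body are omitted; this is the shape of SettingsTab.display(), not the full method):

// Sketch: rebuild-on-change settings tab. display() must empty the container
// first, or each refresh would append duplicate controls.
display(): void {
  const { containerEl } = this
  containerEl.empty()

  new Setting(containerEl).setName('Model').addDropdown((cb) => {
    cb.onChange(async (value) => {
      this.plugin.settings.apiModel = value
      await this.plugin.saveSettings()
      this.display() // re-render so dependent fields track the selection
    })
  })

  if (this.plugin.settings.apiModel === CHAT_MODELS.CUSTOMIZE.name) {
    new Setting(containerEl).setName('Custom Model Name') // shown only for 'customize'
  }
}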