πŸ„β€β™‚οΈ refactor: Optimize Reasoning UI & Token Streaming (#5546)
* ✨ feat: Implement Show Thinking feature; refactor: testing thinking render optimizations

* ✨ feat: Refactor Thinking component styles and enhance Markdown rendering

* chore: add back removed code, revert type changes

* chore: Add back resetCounter effect to Markdown component for improved code block indexing

* chore: bump @librechat/agents and google langchain packages

* WIP: reasoning type updates

* WIP: first pass, reasoning content blocks

* chore: revert code

* chore: bump @librechat/agents

* refactor: optimize reasoning tag handling

* style: ul indent padding

* feat: add Reasoning component to handle reasoning display

* feat: first pass, content reasoning part styling

* refactor: add content placeholder for endpoints using new stream handler

* refactor: only cache messages when requesting stream audio

* fix: circular dep.

* fix: add default param

* refactor: tts, only request after message stream, fix chrome autoplay

* style: update label for submitting state and add localization for 'Thinking...'

* fix: improve global audio pause logic and reset active run ID

* fix: handle artifact edge cases

* fix: remove unnecessary console log from artifact update test

* feat: add support for continued message handling with new streaming method

---------

Co-authored-by: Marco Beretta <[email protected]>
danny-avila and berry-13 authored Jan 30, 2025
1 parent d60a149 commit 591a019
Showing 48 changed files with 1,792 additions and 727 deletions.
21 changes: 7 additions & 14 deletions api/app/clients/BaseClient.js
@@ -7,15 +7,12 @@ const {
EModelEndpoint,
ErrorTypes,
Constants,
CacheKeys,
Time,
} = require('librechat-data-provider');
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const { truncateToolCallOutputs } = require('./prompts');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
const { getLogStores } = require('~/cache');
const TextStream = require('./TextStream');
const { logger } = require('~/config');

@@ -54,6 +51,12 @@ class BaseClient {
this.outputTokensKey = 'completion_tokens';
/** @type {Set<string>} */
this.savedMessageIds = new Set();
/**
* Flag to determine if the client re-submitted the latest assistant message.
* @type {boolean | undefined} */
this.continued;
/** @type {TMessage[]} */
this.currentMessages = [];
}

setOptions() {
@@ -589,6 +592,7 @@
} else {
latestMessage.text = generation;
}
this.continued = true;
} else {
this.currentMessages.push(userMessage);
}
@@ -720,17 +724,6 @@

this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
this.savedMessageIds.add(responseMessage.messageId);
if (responseMessage.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
}
delete responseMessage.tokenCount;
return responseMessage;
}
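The block removed above wrote every completed response into the message cache unconditionally; per the commit message, caching is now reserved for requests that actually stream audio. A minimal sketch of that narrower pattern, reusing the same cache calls the removed code used (the `ttsRequested` flag is hypothetical, for illustration only):

```js
const { CacheKeys, Time } = require('librechat-data-provider');
const { getLogStores } = require('~/cache');

/**
 * Cache the final response text only when stream audio (TTS) was requested.
 * `ttsRequested` is an assumed flag, not part of this diff.
 */
function cacheForStreamAudio(responseMessage, ttsRequested) {
  if (!ttsRequested || !responseMessage.text) {
    return;
  }
  const messageCache = getLogStores(CacheKeys.MESSAGES);
  messageCache.set(
    responseMessage.messageId,
    { text: responseMessage.text, complete: true },
    Time.FIVE_MINUTES,
  );
}
```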
116 changes: 81 additions & 35 deletions api/app/clients/OpenAIClient.js
@@ -1,6 +1,7 @@
const OpenAI = require('openai');
const { OllamaClient } = require('./OllamaClient');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { SplitStreamHandler, GraphEvents } = require('@librechat/agents');
const {
Constants,
ImageDetail,
@@ -28,17 +29,17 @@ const {
createContextHandlers,
} = require('./prompts');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils');
const Tokenizer = require('~/server/services/Tokenizer');
const { spendTokens } = require('~/models/spendTokens');
const { isEnabled, sleep } = require('~/server/utils');
const { handleOpenAIErrors } = require('./tools/util');
const { createLLM, RunManager } = require('./llm');
const { logger, sendEvent } = require('~/config');
const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');

class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) {
@@ -65,6 +66,8 @@ class OpenAIClient extends BaseClient {
this.usage;
/** @type {boolean|undefined} */
this.isO1Model;
/** @type {SplitStreamHandler | undefined} */
this.streamHandler;
}

// TODO: PluginsClient calls this 3x, unneeded
@@ -1064,11 +1067,36 @@ ${convo}
});
}

getStreamText() {
if (!this.streamHandler) {
return '';
}

const reasoningTokens =
this.streamHandler.reasoningTokens.length > 0
? `:::thinking\n${this.streamHandler.reasoningTokens.join('')}\n:::\n`
: '';

return `${reasoningTokens}${this.streamHandler.tokens.join('')}`;
}

getMessageMapMethod() {
/**
* @param {TMessage} msg
*/
return (msg) => {
if (msg.text && msg.text.startsWith(':::thinking')) {
msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
}

return msg;
};
}

async chatCompletion({ payload, onProgress, abortController = null }) {
let error = null;
let intermediateReply = [];
const errorCallback = (err) => (error = err);
const intermediateReply = [];
const reasoningTokens = [];
try {
if (!abortController) {
abortController = new AbortController();
@@ -1266,6 +1294,19 @@
reasoningKey = 'reasoning';
}

this.streamHandler = new SplitStreamHandler({
reasoningKey,
accumulate: true,
runId: this.responseMessageId,
handlers: {
[GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event),
[GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event),
},
});

intermediateReply = this.streamHandler.tokens;

if (modelOptions.stream) {
streamPromise = new Promise((resolve) => {
streamResolve = resolve;
@@ -1292,41 +1333,36 @@
}

if (typeof finalMessage.content !== 'string' || finalMessage.content.trim() === '') {
finalChatCompletion.choices[0].message.content = intermediateReply.join('');
finalChatCompletion.choices[0].message.content = this.streamHandler.tokens.join('');
}
})
.on('finalMessage', (message) => {
if (message?.role !== 'assistant') {
stream.messages.push({ role: 'assistant', content: intermediateReply.join('') });
stream.messages.push({
role: 'assistant',
content: this.streamHandler.tokens.join(''),
});
UnexpectedRoleError = true;
}
});

let reasoningCompleted = false;
for await (const chunk of stream) {
if (chunk?.choices?.[0]?.delta?.[reasoningKey]) {
if (reasoningTokens.length === 0) {
const thinkingDirective = '<think>\n';
intermediateReply.push(thinkingDirective);
reasoningTokens.push(thinkingDirective);
onProgress(thinkingDirective);
}
const reasoning_content = chunk?.choices?.[0]?.delta?.[reasoningKey] || '';
intermediateReply.push(reasoning_content);
reasoningTokens.push(reasoning_content);
onProgress(reasoning_content);
}

const token = chunk?.choices?.[0]?.delta?.content || '';
if (!reasoningCompleted && reasoningTokens.length > 0 && token) {
reasoningCompleted = true;
const separatorTokens = '\n</think>\n';
reasoningTokens.push(separatorTokens);
onProgress(separatorTokens);
}
if (this.continued === true) {
const latestText = addSpaceIfNeeded(
this.currentMessages[this.currentMessages.length - 1]?.text ?? '',
);
this.streamHandler.handle({
choices: [
{
delta: {
content: latestText,
},
},
],
});
}

intermediateReply.push(token);
onProgress(token);
for await (const chunk of stream) {
this.streamHandler.handle(chunk);
if (abortController.signal.aborted) {
stream.controller.abort();
break;
@@ -1369,7 +1405,7 @@ ${convo}

if (!Array.isArray(choices) || choices.length === 0) {
logger.warn('[OpenAIClient] Chat completion response has no choices');
return intermediateReply.join('');
return this.streamHandler.tokens.join('');
}

const { message, finish_reason } = choices[0] ?? {};
@@ -1379,20 +1415,30 @@

if (!message) {
logger.warn('[OpenAIClient] Message is undefined in chatCompletion response');
return intermediateReply.join('');
return this.streamHandler.tokens.join('');
}

if (typeof message.content !== 'string' || message.content.trim() === '') {
const reply = intermediateReply.join('');
const reply = this.streamHandler.tokens.join('');
logger.debug(
'[OpenAIClient] chatCompletion: using intermediateReply due to empty message.content',
{ intermediateReply: reply },
);
return reply;
}

if (reasoningTokens.length > 0 && this.options.context !== 'title') {
return reasoningTokens.join('') + message.content;
if (
this.streamHandler.reasoningTokens.length > 0 &&
this.options.context !== 'title' &&
!message.content.startsWith('<think>')
) {
return this.getStreamText();
} else if (
this.streamHandler.reasoningTokens.length > 0 &&
this.options.context !== 'title' &&
message.content.startsWith('<think>')
) {
return message.content.replace('<think>', ':::thinking').replace('</think>', ':::');
}

return message.content;
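To make the new `:::thinking` convention concrete: `getStreamText` serializes accumulated reasoning into a directive block the UI renders separately, `getMessageMapMethod` strips that block before a message is re-sent as context, and the final-response path rewrites raw `<think>` tags into the same directive. A rough round-trip sketch using the same expressions as the diff (token values invented):

```js
// Token arrays as a SplitStreamHandler would accumulate them.
const reasoningTokens = ['First, ', 'verify the claim.'];
const tokens = ['The answer is 42.'];

// Serialize (getStreamText): wrap reasoning in a :::thinking directive.
const streamText =
  (reasoningTokens.length > 0
    ? `:::thinking\n${reasoningTokens.join('')}\n:::\n`
    : '') + tokens.join('');
// ":::thinking\nFirst, verify the claim.\n:::\nThe answer is 42."

// Strip (getMessageMapMethod): drop the directive before re-prompting.
const promptText = streamText.replace(/:::thinking.*?:::/gs, '').trim();
// "The answer is 42."

// Rewrite (final response): map provider <think> tags onto the directive.
const tagged = '<think>\nverify the claim\n</think>\nThe answer is 42.';
const rewritten = tagged
  .replace('<think>', ':::thinking')
  .replace('</think>', ':::');
// ":::thinking\nverify the claim\n:::\nThe answer is 42."
```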
13 changes: 0 additions & 13 deletions api/app/clients/PluginsClient.js
@@ -1,5 +1,4 @@
const OpenAIClient = require('./OpenAIClient');
const { CacheKeys, Time } = require('librechat-data-provider');
const { CallbackManager } = require('@langchain/core/callbacks/manager');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
@@ -11,7 +10,6 @@ const checkBalance = require('~/models/checkBalance');
const { isEnabled } = require('~/server/utils');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { getLogStores } = require('~/cache');
const { logger } = require('~/config');

class PluginsClient extends OpenAIClient {
@@ -256,17 +254,6 @@ class PluginsClient extends OpenAIClient {
}

this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
if (responseMessage.text) {
const messageCache = getLogStores(CacheKeys.MESSAGES);
messageCache.set(
responseMessage.messageId,
{
text: responseMessage.text,
complete: true,
},
Time.FIVE_MINUTES,
);
}
delete responseMessage.tokenCount;
return { ...responseMessage, ...result };
}
15 changes: 15 additions & 0 deletions api/config/index.js
@@ -16,7 +16,22 @@ async function getMCPManager() {
return mcpManager;
}

/**
* Sends message data in Server Sent Events format.
* @param {ServerResponse} res - The server response.
* @param {{ data: string | Record<string, unknown>, event?: string }} event - The message event.
* @param {string} [event.event] - The type of event.
* @param {string | Record<string, unknown>} event.data - The message payload to be sent.
*/
const sendEvent = (res, event) => {
if (typeof event.data === 'string' && event.data.length === 0) {
return;
}
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
};

module.exports = {
logger,
sendEvent,
getMCPManager,
};
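
A quick usage sketch of the new `sendEvent` helper and the SSE frame it writes (assumes `res` is a live `ServerResponse` with SSE headers already set; the event name is illustrative):

```js
const { sendEvent } = require('~/config');

// The whole event object is JSON-stringified into a single SSE frame:
sendEvent(res, { event: 'on_message_delta', data: { delta: 'Hello' } });
// event: message
// data: {"event":"on_message_delta","data":{"delta":"Hello"}}
// (a blank line terminates the frame)

// Empty-string payloads are skipped by the guard at the top:
sendEvent(res, { data: '' }); // writes nothing
```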
6 changes: 3 additions & 3 deletions api/package.json
@@ -41,10 +41,10 @@
"@keyv/redis": "^2.8.1",
"@langchain/community": "^0.3.14",
"@langchain/core": "^0.3.18",
"@langchain/google-genai": "^0.1.6",
"@langchain/google-vertexai": "^0.1.6",
"@langchain/google-genai": "^0.1.7",
"@langchain/google-vertexai": "^0.1.8",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^1.9.94",
"@librechat/agents": "^1.9.97",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.7.7",
"bcryptjs": "^2.4.3",
34 changes: 4 additions & 30 deletions api/server/controllers/AskController.js
@@ -1,8 +1,6 @@
const throttle = require('lodash/throttle');
const { getResponseSender, Constants, CacheKeys, Time } = require('librechat-data-provider');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { getLogStores } = require('~/cache');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');

@@ -57,41 +55,17 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {

try {
const { client } = await initializeClient({ req, res, endpointOption });
const messageCache = getLogStores(CacheKeys.MESSAGES);
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: throttle(
({ text: partialText }) => {
/*
const unfinished = endpointOption.endpoint === EModelEndpoint.google ? false : true;
messageCache.set(responseMessageId, {
messageId: responseMessageId,
sender,
conversationId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: partialText,
model: client.modelOptions.model,
unfinished,
error: false,
user,
}, Time.FIVE_MINUTES);
*/

messageCache.set(responseMessageId, partialText, Time.FIVE_MINUTES);
},
3000,
{ trailing: false },
),
});
const { onProgress: progressCallback, getPartialText } = createOnProgress();

getText = getPartialText;
getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText;

const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
text: getText(),
userMessage,
promptTokens,
});
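One detail worth noting in the controller change above: `getStreamText` reads instance state (`this.streamHandler`), so it must be bound before being stored as a standalone function; clients without the new handler fall back to the legacy partial-text accumulator. A small sketch of why the bind matters:

```js
// Class methods run in strict mode, so calling an unbound reference would
// leave `this` undefined and throw when getStreamText touches this.streamHandler.
const getText =
  client.getStreamText != null
    ? client.getStreamText.bind(client) // keep `this` pointing at the client
    : getPartialText; // legacy accumulator fallback

const text = getText(); // safe either way
```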