551 lines
16 KiB
TypeScript
551 lines
16 KiB
TypeScript
/**
|
|
* OpenRouter AI provider implementation
|
|
*/
|
|
|
|
import OpenAI from "openai";
|
|
import type {
|
|
ChatCompletionMessageParam,
|
|
ChatCompletionMessageToolCall,
|
|
ChatCompletionTool,
|
|
} from "openai/resources/chat/completions";
|
|
import { config } from "../../core/config";
|
|
import { createLogger } from "../../core/logger";
|
|
import type { AiProvider, AiResponse, AskOptions, AskWithToolsOptions, MessageStyle, TextStreamHandler } from "./types";
|
|
import { MEMORY_EXTRACTION_TOOLS, getToolsForContext, type ToolCall, type ToolContext } from "./tools";
|
|
import { executeTools } from "./tool-handlers";
|
|
|
|
// Module-scoped logger namespaced to this provider
const logger = createLogger("AI:OpenRouter");

// Style classification options — the only valid return values of classifyMessage()
const STYLE_OPTIONS: MessageStyle[] = ["story", "snarky", "insult", "explicit", "helpful"];

// Maximum tool call iterations to prevent infinite loops
// (each iteration is one model round-trip in askWithTools)
const MAX_TOOL_ITERATIONS = 5;
|
|
|
|
// A tool call assembled from streamed completion delta chunks.
// Mirrors the shape of the SDK's ChatCompletionMessageToolCall so it can be
// pushed back into the message transcript as an assistant tool_call.
interface StreamedToolCall {
  // Provider-assigned tool call id (arrives in one of the delta chunks)
  id: string;
  type: "function";
  function: {
    // Function (tool) name requested by the model
    name: string;
    // Raw JSON argument string, concatenated across delta chunks
    arguments: string;
  };
}
|
|
|
|
// Final accumulated result of one streamed chat completion.
interface StreamedCompletionResult {
  // Full assistant text accumulated across all content deltas
  text: string;
  // Tool calls requested during the stream, ordered by their stream index
  toolCalls: StreamedToolCall[];
}
|
|
|
|
export class OpenRouterProvider implements AiProvider {
|
|
private client: OpenAI;
|
|
|
|
constructor() {
|
|
this.client = new OpenAI({
|
|
baseURL: "https://openrouter.ai/api/v1",
|
|
apiKey: config.ai.openRouterApiKey,
|
|
defaultHeaders: {
|
|
"HTTP-Referer": "https://github.com/crunk-bun",
|
|
"X-Title": "Joel Discord Bot",
|
|
},
|
|
});
|
|
}
|
|
|
|
async health(): Promise<boolean> {
|
|
try {
|
|
// Simple health check - verify we can list models
|
|
await this.client.models.list();
|
|
return true;
|
|
} catch (error) {
|
|
logger.error("Health check failed", error);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
private getClassificationModelCandidates(): string[] {
|
|
const models = [
|
|
config.ai.classificationModel,
|
|
...config.ai.classificationFallbackModels,
|
|
config.ai.model,
|
|
];
|
|
|
|
return Array.from(new Set(models.map((model) => model.trim()).filter((model) => model.length > 0)));
|
|
}
|
|
|
|
private getErrorStatus(error: unknown): number | undefined {
|
|
const err = error as {
|
|
status?: number;
|
|
code?: number | string;
|
|
error?: { code?: number | string };
|
|
};
|
|
|
|
const code = typeof err.code === "number"
|
|
? err.code
|
|
: typeof err.error?.code === "number"
|
|
? err.error.code
|
|
: undefined;
|
|
|
|
return err.status ?? code;
|
|
}
|
|
|
|
async ask(options: AskOptions): Promise<AiResponse> {
|
|
const { prompt, systemPrompt, maxTokens, temperature, onTextStream } = options;
|
|
const model = config.ai.model;
|
|
|
|
try {
|
|
if (onTextStream) {
|
|
const streamed = await this.streamChatCompletion({
|
|
model,
|
|
messages: [
|
|
{ role: "system", content: systemPrompt },
|
|
{ role: "user", content: prompt },
|
|
],
|
|
max_tokens: maxTokens ?? config.ai.maxTokens,
|
|
temperature: temperature ?? config.ai.temperature,
|
|
}, onTextStream);
|
|
|
|
return { text: streamed.text };
|
|
}
|
|
|
|
const completion = await this.client.chat.completions.create({
|
|
model,
|
|
messages: [
|
|
{ role: "system", content: systemPrompt },
|
|
{ role: "user", content: prompt },
|
|
],
|
|
max_tokens: maxTokens ?? config.ai.maxTokens,
|
|
temperature: temperature ?? config.ai.temperature,
|
|
});
|
|
|
|
const text = completion.choices[0]?.message?.content ?? "";
|
|
return { text };
|
|
} catch (error: unknown) {
|
|
logger.error("Failed to generate response (ask)", {
|
|
method: "ask",
|
|
model,
|
|
promptLength: prompt.length,
|
|
...(error instanceof Error ? {} : { rawError: error }),
|
|
});
|
|
logger.error("API error details", error);
|
|
throw error;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Generate a response with tool calling support
|
|
* The AI can call tools (like looking up memories) during response generation
|
|
*/
|
|
async askWithTools(options: AskWithToolsOptions): Promise<AiResponse> {
|
|
const { prompt, systemPrompt, context, maxTokens, temperature, onTextStream } = options;
|
|
|
|
const messages: ChatCompletionMessageParam[] = [
|
|
{ role: "system", content: systemPrompt },
|
|
{ role: "user", content: prompt },
|
|
];
|
|
|
|
// Get the appropriate tools for this context (includes optional tools like GIF search)
|
|
const tools = getToolsForContext(context);
|
|
|
|
let iterations = 0;
|
|
|
|
while (iterations < MAX_TOOL_ITERATIONS) {
|
|
iterations++;
|
|
|
|
try {
|
|
if (onTextStream) {
|
|
const streamed = await this.streamChatCompletion({
|
|
model: config.ai.model,
|
|
messages,
|
|
tools,
|
|
tool_choice: "auto",
|
|
max_tokens: maxTokens ?? config.ai.maxTokens,
|
|
temperature: temperature ?? config.ai.temperature,
|
|
}, onTextStream);
|
|
|
|
if (streamed.toolCalls.length > 0) {
|
|
logger.debug("AI requested tool calls", {
|
|
count: streamed.toolCalls.length,
|
|
tools: streamed.toolCalls.map((tc) => tc.function.name),
|
|
});
|
|
|
|
messages.push({
|
|
role: "assistant",
|
|
content: streamed.text || null,
|
|
tool_calls: streamed.toolCalls,
|
|
});
|
|
|
|
await onTextStream("");
|
|
|
|
const toolCalls = this.parseToolCalls(streamed.toolCalls);
|
|
const results = await executeTools(toolCalls, context);
|
|
|
|
for (let i = 0; i < toolCalls.length; i++) {
|
|
messages.push({
|
|
role: "tool",
|
|
tool_call_id: toolCalls[i].id,
|
|
content: results[i].result,
|
|
});
|
|
}
|
|
|
|
continue;
|
|
}
|
|
|
|
logger.debug("AI response generated", {
|
|
iterations,
|
|
textLength: streamed.text.length,
|
|
streamed: true,
|
|
});
|
|
|
|
return { text: streamed.text };
|
|
}
|
|
|
|
const completion = await this.client.chat.completions.create({
|
|
model: config.ai.model,
|
|
messages,
|
|
tools,
|
|
tool_choice: "auto",
|
|
max_tokens: maxTokens ?? config.ai.maxTokens,
|
|
temperature: temperature ?? config.ai.temperature,
|
|
});
|
|
|
|
const choice = completion.choices[0];
|
|
const message = choice?.message;
|
|
|
|
if (!message) {
|
|
logger.warn("No message in completion");
|
|
return { text: "" };
|
|
}
|
|
|
|
// Check if the AI wants to call tools
|
|
if (message.tool_calls && message.tool_calls.length > 0) {
|
|
logger.debug("AI requested tool calls", {
|
|
count: message.tool_calls.length,
|
|
tools: message.tool_calls.map(tc => tc.function.name)
|
|
});
|
|
|
|
// Add the assistant's message with tool calls
|
|
messages.push(message);
|
|
|
|
// Parse and execute tool calls
|
|
const toolCalls: ToolCall[] = message.tool_calls.map((tc) => ({
|
|
id: tc.id,
|
|
name: tc.function.name,
|
|
arguments: JSON.parse(tc.function.arguments || "{}"),
|
|
}));
|
|
|
|
const results = await executeTools(toolCalls, context);
|
|
|
|
// Add tool results as messages
|
|
for (let i = 0; i < toolCalls.length; i++) {
|
|
messages.push({
|
|
role: "tool",
|
|
tool_call_id: toolCalls[i].id,
|
|
content: results[i].result,
|
|
});
|
|
}
|
|
|
|
// Continue the loop to get the AI's response after tool execution
|
|
continue;
|
|
}
|
|
|
|
// No tool calls - we have a final response
|
|
const text = message.content ?? "";
|
|
logger.debug("AI response generated", {
|
|
iterations,
|
|
textLength: text.length
|
|
});
|
|
|
|
return { text };
|
|
} catch (error: unknown) {
|
|
logger.error("Failed to generate response with tools (askWithTools)", {
|
|
method: "askWithTools",
|
|
model: config.ai.model,
|
|
iteration: iterations,
|
|
messageCount: messages.length,
|
|
toolCount: tools.length,
|
|
...(error instanceof Error ? {} : { rawError: error }),
|
|
});
|
|
logger.error("API error details", error);
|
|
throw error;
|
|
}
|
|
}
|
|
|
|
logger.warn("Max tool iterations reached");
|
|
return { text: "I got stuck in a loop thinking about that..." };
|
|
}
|
|
|
|
private async streamChatCompletion(
|
|
params: {
|
|
model: string;
|
|
messages: ChatCompletionMessageParam[];
|
|
tools?: ChatCompletionTool[];
|
|
tool_choice?: "auto" | "none";
|
|
max_tokens: number;
|
|
temperature: number;
|
|
},
|
|
onTextStream: TextStreamHandler,
|
|
): Promise<StreamedCompletionResult> {
|
|
const stream = await this.client.chat.completions.create({
|
|
...params,
|
|
stream: true,
|
|
});
|
|
|
|
let text = "";
|
|
const toolCalls = new Map<number, StreamedToolCall>();
|
|
|
|
for await (const chunk of stream) {
|
|
const choice = chunk.choices[0];
|
|
if (!choice) {
|
|
continue;
|
|
}
|
|
|
|
const delta = choice.delta;
|
|
const content = delta.content ?? "";
|
|
if (content) {
|
|
text += content;
|
|
await onTextStream(text);
|
|
}
|
|
|
|
for (const toolCallDelta of delta.tool_calls ?? []) {
|
|
const current = toolCalls.get(toolCallDelta.index) ?? {
|
|
id: "",
|
|
type: "function" as const,
|
|
function: {
|
|
name: "",
|
|
arguments: "",
|
|
},
|
|
};
|
|
|
|
if (toolCallDelta.id) {
|
|
current.id = toolCallDelta.id;
|
|
}
|
|
|
|
if (toolCallDelta.function?.name) {
|
|
current.function.name = toolCallDelta.function.name;
|
|
}
|
|
|
|
if (toolCallDelta.function?.arguments) {
|
|
current.function.arguments += toolCallDelta.function.arguments;
|
|
}
|
|
|
|
toolCalls.set(toolCallDelta.index, current);
|
|
}
|
|
}
|
|
|
|
return {
|
|
text,
|
|
toolCalls: Array.from(toolCalls.entries())
|
|
.sort((a, b) => a[0] - b[0])
|
|
.map(([, toolCall]) => toolCall),
|
|
};
|
|
}
|
|
|
|
private parseToolCalls(toolCalls: ChatCompletionMessageToolCall[]): ToolCall[] {
|
|
return toolCalls.map((toolCall) => {
|
|
try {
|
|
return {
|
|
id: toolCall.id,
|
|
name: toolCall.function.name,
|
|
arguments: JSON.parse(toolCall.function.arguments || "{}"),
|
|
};
|
|
} catch (error) {
|
|
logger.error("Failed to parse streamed tool call arguments", {
|
|
toolName: toolCall.function.name,
|
|
toolCallId: toolCall.id,
|
|
arguments: toolCall.function.arguments,
|
|
error,
|
|
});
|
|
throw error;
|
|
}
|
|
});
|
|
}
|
|
|
|
/**
|
|
* Analyze a message to extract memorable information
|
|
*/
|
|
async extractMemories(
|
|
message: string,
|
|
authorName: string,
|
|
context: ToolContext
|
|
): Promise<void> {
|
|
const systemPrompt = `You are analyzing a Discord message to determine if it contains any memorable or useful information about the user "${authorName}".
|
|
|
|
Look for:
|
|
- Personal information (name, age, location, job, hobbies)
|
|
- Preferences (likes, dislikes, favorites)
|
|
- Embarrassing admissions or confessions
|
|
- Strong opinions or hot takes
|
|
- Achievements or accomplishments
|
|
- Relationships or social information
|
|
- Recurring patterns or habits
|
|
|
|
If you find something worth remembering, use the extract_memory tool. Only extract genuinely interesting or useful information - don't save trivial things.
|
|
|
|
The user's Discord ID is: ${context.userId}`;
|
|
|
|
try {
|
|
const completion = await this.client.chat.completions.create({
|
|
model: config.ai.model, // Use main model - needs tool support
|
|
messages: [
|
|
{ role: "system", content: systemPrompt },
|
|
{ role: "user", content: `Analyze this message for memorable content:\n\n"${message}"` },
|
|
],
|
|
tools: MEMORY_EXTRACTION_TOOLS,
|
|
tool_choice: "auto",
|
|
max_tokens: 200,
|
|
temperature: 0.3,
|
|
});
|
|
|
|
const toolCalls = completion.choices[0]?.message?.tool_calls;
|
|
|
|
if (toolCalls && toolCalls.length > 0) {
|
|
const parsedCalls: ToolCall[] = toolCalls.map((tc) => ({
|
|
id: tc.id,
|
|
name: tc.function.name,
|
|
arguments: JSON.parse(tc.function.arguments || "{}"),
|
|
}));
|
|
|
|
await executeTools(parsedCalls, context);
|
|
logger.debug("Memory extraction complete", {
|
|
extracted: parsedCalls.length,
|
|
authorName
|
|
});
|
|
}
|
|
} catch (error) {
|
|
// Don't throw - memory extraction is non-critical
|
|
logger.error("Memory extraction failed", {
|
|
method: "extractMemories",
|
|
model: config.ai.model,
|
|
authorName,
|
|
messageLength: message.length,
|
|
});
|
|
logger.error("Memory extraction error details", error);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Classify a message to determine the appropriate response style
|
|
*/
|
|
async classifyMessage(message: string): Promise<MessageStyle> {
|
|
const models = this.getClassificationModelCandidates();
|
|
|
|
for (let i = 0; i < models.length; i++) {
|
|
const model = models[i];
|
|
const hasNextModel = i < models.length - 1;
|
|
|
|
try {
|
|
const classification = await this.client.chat.completions.create({
|
|
model,
|
|
messages: [
|
|
{
|
|
role: "user",
|
|
content: `Classify this message into exactly one category. Only respond with the category name, nothing else.
|
|
|
|
Message: "${message}"
|
|
|
|
Categories:
|
|
- story: User wants a story, narrative, or creative writing
|
|
- snarky: User is being sarcastic or deserves a witty comeback
|
|
- insult: User is being rude or hostile, respond with brutal roasts (non-sexual)
|
|
- explicit: User wants adult/NSFW content
|
|
- helpful: User has a genuine question or needs actual help
|
|
|
|
Category:`,
|
|
},
|
|
],
|
|
max_tokens: 10,
|
|
temperature: 0.1,
|
|
});
|
|
|
|
const result = classification.choices[0]?.message?.content?.toLowerCase().trim() as MessageStyle;
|
|
|
|
if (STYLE_OPTIONS.includes(result)) {
|
|
logger.debug("Message classified", { style: result, model });
|
|
return result;
|
|
}
|
|
|
|
logger.debug("Classification returned invalid style, defaulting to snarky", { result, model });
|
|
return "snarky";
|
|
} catch (error) {
|
|
if (hasNextModel) {
|
|
logger.warn("Classification model failed, trying fallback", {
|
|
method: "classifyMessage",
|
|
model,
|
|
nextModel: models[i + 1],
|
|
status: this.getErrorStatus(error),
|
|
});
|
|
continue;
|
|
}
|
|
|
|
logger.error("Failed to classify message", {
|
|
method: "classifyMessage",
|
|
model,
|
|
messageLength: message.length,
|
|
attempts: models.length,
|
|
});
|
|
logger.error("Classification error details", error);
|
|
}
|
|
}
|
|
|
|
return "snarky";
|
|
}
|
|
|
|
/**
|
|
* Cheap binary classifier to detect if a message is directed at Joel
|
|
*/
|
|
async classifyJoelDirected(message: string): Promise<boolean> {
|
|
const models = this.getClassificationModelCandidates();
|
|
|
|
for (let i = 0; i < models.length; i++) {
|
|
const model = models[i];
|
|
const hasNextModel = i < models.length - 1;
|
|
|
|
try {
|
|
const classification = await this.client.chat.completions.create({
|
|
model,
|
|
messages: [
|
|
{
|
|
role: "user",
|
|
content: `Determine if this Discord message is directed at Joel (the bot), or talking about Joel in a way Joel should respond.
|
|
|
|
Only respond with one token: YES or NO.
|
|
|
|
Guidance:
|
|
- YES if the user is asking Joel a question, requesting Joel to do something, replying conversationally to Joel, or maybe discussing Joel as a participant.
|
|
- NO if it's general chat between humans, statements that do not involve Joel.
|
|
|
|
Message: "${message}"
|
|
|
|
Answer:`,
|
|
},
|
|
],
|
|
max_tokens: 3,
|
|
temperature: 0,
|
|
});
|
|
|
|
const result = classification.choices[0]?.message?.content?.trim().toUpperCase();
|
|
return result?.startsWith("YES") ?? false;
|
|
} catch (error) {
|
|
if (hasNextModel) {
|
|
logger.warn("Directed classification model failed, trying fallback", {
|
|
method: "classifyJoelDirected",
|
|
model,
|
|
nextModel: models[i + 1],
|
|
status: this.getErrorStatus(error),
|
|
});
|
|
continue;
|
|
}
|
|
|
|
logger.error("Failed to classify directed message", {
|
|
method: "classifyJoelDirected",
|
|
model,
|
|
messageLength: message.length,
|
|
attempts: models.length,
|
|
});
|
|
logger.error("Directed classification error details", error);
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
}
|