Files
joel/src/services/ai/openrouter.ts
2026-03-12 22:13:12 +01:00

551 lines
16 KiB
TypeScript

/**
* OpenRouter AI provider implementation
*/
import OpenAI from "openai";
import type {
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionTool,
} from "openai/resources/chat/completions";
import { config } from "../../core/config";
import { createLogger } from "../../core/logger";
import type { AiProvider, AiResponse, AskOptions, AskWithToolsOptions, MessageStyle, TextStreamHandler } from "./types";
import { MEMORY_EXTRACTION_TOOLS, getToolsForContext, type ToolCall, type ToolContext } from "./tools";
import { executeTools } from "./tool-handlers";
// Module-scoped logger for this provider
const logger = createLogger("AI:OpenRouter");
// Style classification options — the styles classifyMessage() will accept from the model
const STYLE_OPTIONS: MessageStyle[] = ["story", "snarky", "insult", "explicit", "helpful"];
// Maximum tool call iterations to prevent infinite loops in askWithTools()
const MAX_TOOL_ITERATIONS = 5;
/**
 * A tool call accumulated from streaming deltas.
 * Mirrors the shape of an OpenAI chat-completion tool call so it can be
 * pushed back into the conversation as part of an assistant message.
 */
interface StreamedToolCall {
  id: string;
  type: "function";
  function: {
    name: string;
    // Raw JSON string built up chunk-by-chunk; parsed later by parseToolCalls()
    arguments: string;
  };
}
/**
 * Final result of a streamed chat completion: the fully accumulated text plus
 * any tool calls the model requested (ordered by their stream index).
 */
interface StreamedCompletionResult {
  text: string;
  toolCalls: StreamedToolCall[];
}
/**
 * OpenRouter-backed implementation of {@link AiProvider}.
 *
 * OpenRouter speaks the OpenAI wire protocol, so this class drives the
 * official OpenAI SDK pointed at OpenRouter's base URL. Capabilities:
 * - `ask`              — plain chat completion, optionally streamed
 * - `askWithTools`     — tool-calling loop (bounded by MAX_TOOL_ITERATIONS)
 * - `extractMemories`  — best-effort background analysis of user messages
 * - `classifyMessage` / `classifyJoelDirected` — cheap classifiers that fall
 *   back through the configured classification model candidates
 */
export class OpenRouterProvider implements AiProvider {
  private client: OpenAI;

  constructor() {
    this.client = new OpenAI({
      baseURL: "https://openrouter.ai/api/v1",
      apiKey: config.ai.openRouterApiKey,
      // OpenRouter app-attribution headers (identify the calling application)
      defaultHeaders: {
        "HTTP-Referer": "https://github.com/crunk-bun",
        "X-Title": "Joel Discord Bot",
      },
    });
  }

  /**
   * Liveness probe: succeeds iff the models endpoint responds.
   * Never throws — failures are logged and reported as `false`.
   */
  async health(): Promise<boolean> {
    try {
      // Simple health check - verify we can list models
      await this.client.models.list();
      return true;
    } catch (error) {
      logger.error("Health check failed", error);
      return false;
    }
  }

  /**
   * Ordered, de-duplicated list of models to try for classification calls:
   * the primary classification model, its configured fallbacks, then the
   * main chat model as a last resort. Blank entries are dropped.
   */
  private getClassificationModelCandidates(): string[] {
    const models = [
      config.ai.classificationModel,
      ...config.ai.classificationFallbackModels,
      config.ai.model,
    ];
    return Array.from(new Set(models.map((model) => model.trim()).filter((model) => model.length > 0)));
  }

  /**
   * Best-effort extraction of an HTTP-like status code from an unknown error.
   * Checks `status` first, then a numeric `code`, then a numeric `error.code`.
   */
  private getErrorStatus(error: unknown): number | undefined {
    const err = error as {
      status?: number;
      code?: number | string;
      error?: { code?: number | string };
    };
    const code = typeof err.code === "number"
      ? err.code
      : typeof err.error?.code === "number"
        ? err.error.code
        : undefined;
    return err.status ?? code;
  }

  /**
   * Generate a single chat completion with no tool support.
   * When `onTextStream` is provided, the completion is streamed and the
   * handler receives the full accumulated text after each content delta.
   *
   * @throws re-throws any API error after logging context.
   */
  async ask(options: AskOptions): Promise<AiResponse> {
    const { prompt, systemPrompt, maxTokens, temperature, onTextStream } = options;
    const model = config.ai.model;
    const messages: ChatCompletionMessageParam[] = [
      { role: "system", content: systemPrompt },
      { role: "user", content: prompt },
    ];
    try {
      if (onTextStream) {
        const streamed = await this.streamChatCompletion({
          model,
          messages,
          max_tokens: maxTokens ?? config.ai.maxTokens,
          temperature: temperature ?? config.ai.temperature,
        }, onTextStream);
        return { text: streamed.text };
      }
      const completion = await this.client.chat.completions.create({
        model,
        messages,
        max_tokens: maxTokens ?? config.ai.maxTokens,
        temperature: temperature ?? config.ai.temperature,
      });
      return { text: completion.choices[0]?.message?.content ?? "" };
    } catch (error: unknown) {
      // Structured context first; raw error objects only when they are not
      // Error instances (those serialize poorly in structured fields).
      logger.error("Failed to generate response (ask)", {
        method: "ask",
        model,
        promptLength: prompt.length,
        ...(error instanceof Error ? {} : { rawError: error }),
      });
      logger.error("API error details", error);
      throw error;
    }
  }

  /**
   * Generate a response with tool calling support
   * The AI can call tools (like looking up memories) during response generation
   *
   * Loops completion → tool execution → completion until the model produces a
   * plain-text answer, or MAX_TOOL_ITERATIONS is hit (returns a canned reply).
   *
   * @throws re-throws any API error after logging context.
   */
  async askWithTools(options: AskWithToolsOptions): Promise<AiResponse> {
    const { prompt, systemPrompt, context, maxTokens, temperature, onTextStream } = options;
    const messages: ChatCompletionMessageParam[] = [
      { role: "system", content: systemPrompt },
      { role: "user", content: prompt },
    ];
    // Get the appropriate tools for this context (includes optional tools like GIF search)
    const tools = getToolsForContext(context);
    let iterations = 0;
    while (iterations < MAX_TOOL_ITERATIONS) {
      iterations++;
      try {
        if (onTextStream) {
          const streamed = await this.streamChatCompletion({
            model: config.ai.model,
            messages,
            tools,
            tool_choice: "auto",
            max_tokens: maxTokens ?? config.ai.maxTokens,
            temperature: temperature ?? config.ai.temperature,
          }, onTextStream);
          if (streamed.toolCalls.length > 0) {
            logger.debug("AI requested tool calls", {
              count: streamed.toolCalls.length,
              tools: streamed.toolCalls.map((tc) => tc.function.name),
            });
            messages.push({
              role: "assistant",
              content: streamed.text || null,
              tool_calls: streamed.toolCalls,
            });
            // Reset the streamed display before the follow-up completion streams
            await onTextStream("");
            await this.runToolCalls(streamed.toolCalls, context, messages);
            continue;
          }
          logger.debug("AI response generated", {
            iterations,
            textLength: streamed.text.length,
            streamed: true,
          });
          return { text: streamed.text };
        }
        const completion = await this.client.chat.completions.create({
          model: config.ai.model,
          messages,
          tools,
          tool_choice: "auto",
          max_tokens: maxTokens ?? config.ai.maxTokens,
          temperature: temperature ?? config.ai.temperature,
        });
        const message = completion.choices[0]?.message;
        if (!message) {
          logger.warn("No message in completion");
          return { text: "" };
        }
        // Check if the AI wants to call tools
        if (message.tool_calls && message.tool_calls.length > 0) {
          logger.debug("AI requested tool calls", {
            count: message.tool_calls.length,
            tools: message.tool_calls.map(tc => tc.function.name)
          });
          // Add the assistant's message with tool calls, then execute them
          messages.push(message);
          await this.runToolCalls(message.tool_calls, context, messages);
          // Continue the loop to get the AI's response after tool execution
          continue;
        }
        // No tool calls - we have a final response
        const text = message.content ?? "";
        logger.debug("AI response generated", {
          iterations,
          textLength: text.length
        });
        return { text };
      } catch (error: unknown) {
        logger.error("Failed to generate response with tools (askWithTools)", {
          method: "askWithTools",
          model: config.ai.model,
          iteration: iterations,
          messageCount: messages.length,
          toolCount: tools.length,
          ...(error instanceof Error ? {} : { rawError: error }),
        });
        logger.error("API error details", error);
        throw error;
      }
    }
    logger.warn("Max tool iterations reached");
    return { text: "I got stuck in a loop thinking about that..." };
  }

  /**
   * Parse the model's tool calls, execute them, and append one `tool` message
   * per call (in order) so the next completion turn can read the results.
   * Shared by the streaming and non-streaming askWithTools paths.
   */
  private async runToolCalls(
    rawCalls: ChatCompletionMessageToolCall[],
    context: ToolContext,
    messages: ChatCompletionMessageParam[],
  ): Promise<void> {
    const toolCalls = this.parseToolCalls(rawCalls);
    const results = await executeTools(toolCalls, context);
    for (let i = 0; i < toolCalls.length; i++) {
      messages.push({
        role: "tool",
        tool_call_id: toolCalls[i].id,
        content: results[i].result,
      });
    }
  }

  /**
   * Consume a streamed chat completion, forwarding accumulated text to
   * `onTextStream` after every content delta and reassembling any tool calls
   * from their fragmented deltas.
   *
   * @returns the full text plus tool calls ordered by stream index.
   */
  private async streamChatCompletion(
    params: {
      model: string;
      messages: ChatCompletionMessageParam[];
      tools?: ChatCompletionTool[];
      tool_choice?: "auto" | "none";
      max_tokens: number;
      temperature: number;
    },
    onTextStream: TextStreamHandler,
  ): Promise<StreamedCompletionResult> {
    const stream = await this.client.chat.completions.create({
      ...params,
      stream: true,
    });
    let text = "";
    // Tool calls arrive as fragments keyed by index; accumulate per index.
    const toolCalls = new Map<number, StreamedToolCall>();
    for await (const chunk of stream) {
      const choice = chunk.choices[0];
      if (!choice) {
        continue;
      }
      const delta = choice.delta;
      const content = delta.content ?? "";
      if (content) {
        text += content;
        // Handler always receives the full text-so-far, not the delta
        await onTextStream(text);
      }
      for (const toolCallDelta of delta.tool_calls ?? []) {
        const current = toolCalls.get(toolCallDelta.index) ?? {
          id: "",
          type: "function" as const,
          function: {
            name: "",
            arguments: "",
          },
        };
        if (toolCallDelta.id) {
          current.id = toolCallDelta.id;
        }
        if (toolCallDelta.function?.name) {
          current.function.name = toolCallDelta.function.name;
        }
        if (toolCallDelta.function?.arguments) {
          // Arguments stream in pieces of a single JSON string; concatenate
          current.function.arguments += toolCallDelta.function.arguments;
        }
        toolCalls.set(toolCallDelta.index, current);
      }
    }
    return {
      text,
      toolCalls: Array.from(toolCalls.entries())
        .sort((a, b) => a[0] - b[0])
        .map(([, toolCall]) => toolCall),
    };
  }

  /**
   * Convert API tool calls into our internal {@link ToolCall} shape,
   * JSON-parsing each call's argument string (empty string → `{}`).
   *
   * @throws re-throws JSON.parse failures after logging the offending call.
   */
  private parseToolCalls(toolCalls: ChatCompletionMessageToolCall[]): ToolCall[] {
    return toolCalls.map((toolCall) => {
      try {
        return {
          id: toolCall.id,
          name: toolCall.function.name,
          arguments: JSON.parse(toolCall.function.arguments || "{}"),
        };
      } catch (error) {
        logger.error("Failed to parse tool call arguments", {
          toolName: toolCall.function.name,
          toolCallId: toolCall.id,
          arguments: toolCall.function.arguments,
          error,
        });
        throw error;
      }
    });
  }

  /**
   * Analyze a message to extract memorable information
   *
   * Non-critical: any failure is logged and swallowed, never thrown.
   */
  async extractMemories(
    message: string,
    authorName: string,
    context: ToolContext
  ): Promise<void> {
    const systemPrompt = `You are analyzing a Discord message to determine if it contains any memorable or useful information about the user "${authorName}".
Look for:
- Personal information (name, age, location, job, hobbies)
- Preferences (likes, dislikes, favorites)
- Embarrassing admissions or confessions
- Strong opinions or hot takes
- Achievements or accomplishments
- Relationships or social information
- Recurring patterns or habits
If you find something worth remembering, use the extract_memory tool. Only extract genuinely interesting or useful information - don't save trivial things.
The user's Discord ID is: ${context.userId}`;
    try {
      const completion = await this.client.chat.completions.create({
        model: config.ai.model, // Use main model - needs tool support
        messages: [
          { role: "system", content: systemPrompt },
          { role: "user", content: `Analyze this message for memorable content:\n\n"${message}"` },
        ],
        tools: MEMORY_EXTRACTION_TOOLS,
        tool_choice: "auto",
        max_tokens: 200,
        temperature: 0.3,
      });
      const toolCalls = completion.choices[0]?.message?.tool_calls;
      if (toolCalls && toolCalls.length > 0) {
        const parsedCalls = this.parseToolCalls(toolCalls);
        await executeTools(parsedCalls, context);
        logger.debug("Memory extraction complete", {
          extracted: parsedCalls.length,
          authorName
        });
      }
    } catch (error) {
      // Don't throw - memory extraction is non-critical
      logger.error("Memory extraction failed", {
        method: "extractMemories",
        model: config.ai.model,
        authorName,
        messageLength: message.length,
        ...(error instanceof Error ? {} : { rawError: error }),
      });
      logger.error("Memory extraction error details", error);
    }
  }

  /** Type guard: true when `value` is one of the supported response styles. */
  private isMessageStyle(value: string): value is MessageStyle {
    return (STYLE_OPTIONS as readonly string[]).includes(value);
  }

  /**
   * Run a single-message classification prompt against each candidate model
   * in order, returning the first successful completion's content plus the
   * model that produced it. Returns undefined when every candidate fails;
   * intermediate failures are warned, the final one is logged as an error.
   * Log message strings are supplied by the caller so each classifier keeps
   * its own diagnostics.
   */
  private async completeClassification(options: {
    content: string;
    maxTokens: number;
    temperature: number;
    method: string;
    messageLength: number;
    warnMessage: string;
    errorMessage: string;
    errorDetailMessage: string;
  }): Promise<{ content: string | null | undefined; model: string } | undefined> {
    const models = this.getClassificationModelCandidates();
    for (let i = 0; i < models.length; i++) {
      const model = models[i];
      const hasNextModel = i < models.length - 1;
      try {
        const classification = await this.client.chat.completions.create({
          model,
          messages: [{ role: "user", content: options.content }],
          max_tokens: options.maxTokens,
          temperature: options.temperature,
        });
        return { content: classification.choices[0]?.message?.content, model };
      } catch (error) {
        if (hasNextModel) {
          logger.warn(options.warnMessage, {
            method: options.method,
            model,
            nextModel: models[i + 1],
            status: this.getErrorStatus(error),
          });
          continue;
        }
        logger.error(options.errorMessage, {
          method: options.method,
          model,
          messageLength: options.messageLength,
          attempts: models.length,
        });
        logger.error(options.errorDetailMessage, error);
      }
    }
    return undefined;
  }

  /**
   * Classify a message to determine the appropriate response style
   *
   * Falls back to "snarky" when all models fail or the reply is not a
   * recognized style.
   */
  async classifyMessage(message: string): Promise<MessageStyle> {
    const completed = await this.completeClassification({
      content: `Classify this message into exactly one category. Only respond with the category name, nothing else.
Message: "${message}"
Categories:
- story: User wants a story, narrative, or creative writing
- snarky: User is being sarcastic or deserves a witty comeback
- insult: User is being rude or hostile, respond with brutal roasts (non-sexual)
- explicit: User wants adult/NSFW content
- helpful: User has a genuine question or needs actual help
Category:`,
      maxTokens: 10,
      temperature: 0.1,
      method: "classifyMessage",
      messageLength: message.length,
      warnMessage: "Classification model failed, trying fallback",
      errorMessage: "Failed to classify message",
      errorDetailMessage: "Classification error details",
    });
    if (!completed) {
      return "snarky";
    }
    const result = completed.content?.toLowerCase().trim();
    if (result !== undefined && this.isMessageStyle(result)) {
      logger.debug("Message classified", { style: result, model: completed.model });
      return result;
    }
    logger.debug("Classification returned invalid style, defaulting to snarky", { result, model: completed.model });
    return "snarky";
  }

  /**
   * Cheap binary classifier to detect if a message is directed at Joel
   *
   * Returns false when all candidate models fail.
   */
  async classifyJoelDirected(message: string): Promise<boolean> {
    const completed = await this.completeClassification({
      content: `Determine if this Discord message is directed at Joel (the bot), or talking about Joel in a way Joel should respond.
Only respond with one token: YES or NO.
Guidance:
- YES if the user is asking Joel a question, requesting Joel to do something, replying conversationally to Joel, or maybe discussing Joel as a participant.
- NO if it's general chat between humans, statements that do not involve Joel.
Message: "${message}"
Answer:`,
      maxTokens: 3,
      temperature: 0,
      method: "classifyJoelDirected",
      messageLength: message.length,
      warnMessage: "Directed classification model failed, trying fallback",
      errorMessage: "Failed to classify directed message",
      errorDetailMessage: "Directed classification error details",
    });
    if (!completed) {
      return false;
    }
    return completed.content?.trim().toUpperCase().startsWith("YES") ?? false;
  }
}