diff --git a/.env.example b/.env.example index 5a06415..54ac2d3 100644 --- a/.env.example +++ b/.env.example @@ -6,6 +6,8 @@ BOT_OWNER_ID="" HF_TOKEN="" OPENAI_API_KEY="" OPENROUTER_API_KEY="" +AI_CLASSIFICATION_MODEL="google/gemma-3-12b-it:free" +AI_CLASSIFICATION_FALLBACK_MODELS="qwen/qwen-2.5-7b-instruct:free,mistralai/mistral-7b-instruct:free" REPLICATE_API_TOKEN="" ELEVENLABS_API_KEY="" ELEVENLABS_VOICE_ID="" diff --git a/README.md b/README.md index 2f4f4ff..522652f 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ src/ | `DISCORD_CLIENT_ID` | Discord application client ID | | `DISCORD_CLIENT_SECRET` | Discord application client secret | | `OPENROUTER_API_KEY` | OpenRouter API key for AI | +| `AI_CLASSIFICATION_FALLBACK_MODELS` | Comma-separated fallback model IDs for classification requests | | `KLIPY_API_KEY` | Klipy API key for GIF search (optional) | | `ELEVENLABS_API_KEY` | ElevenLabs API key for voiceover | | `ELEVENLABS_VOICE_ID` | Default ElevenLabs voice ID (optional) | diff --git a/src/core/config.ts b/src/core/config.ts index 08a531a..babf2d7 100644 --- a/src/core/config.ts +++ b/src/core/config.ts @@ -13,6 +13,7 @@ interface BotConfig { openRouterApiKey: string; model: string; classificationModel: string; + classificationFallbackModels: string[]; maxTokens: number; temperature: number; }; @@ -75,6 +76,18 @@ function getBooleanEnvOrDefault(key: string, defaultValue: boolean): boolean { return normalized === "1" || normalized === "true" || normalized === "yes"; } +function getCsvEnvOrDefault(key: string, defaultValues: string[]): string[] { + const raw = Bun.env[key]; + if (!raw) { + return defaultValues; + } + + return raw + .split(",") + .map((value) => value.trim()) + .filter((value) => value.length > 0); +} + export const config: BotConfig = { discord: { token: getEnvOrThrow("DISCORD_TOKEN"), @@ -91,6 +104,10 @@ export const config: BotConfig = { "AI_CLASSIFICATION_MODEL", "google/gemma-3-12b-it:free" ), + classificationFallbackModels: getCsvEnvOrDefault("AI_CLASSIFICATION_FALLBACK_MODELS", [ + "qwen/qwen-2.5-7b-instruct:free", + "mistralai/mistral-7b-instruct:free", + ]), maxTokens: parseInt(getEnvOrDefault("AI_MAX_TOKENS", "500")), temperature: parseFloat(getEnvOrDefault("AI_TEMPERATURE", "1.2")), }, diff --git a/src/services/ai/openrouter.ts b/src/services/ai/openrouter.ts index 382d7e3..e15bc11 100644 --- a/src/services/ai/openrouter.ts +++ b/src/services/ai/openrouter.ts @@ -43,6 +43,32 @@ export class OpenRouterProvider implements AiProvider { } } + private getClassificationModelCandidates(): string[] { + const models = [ + config.ai.classificationModel, + ...config.ai.classificationFallbackModels, + config.ai.model, + ]; + + return Array.from(new Set(models.map((model) => model.trim()).filter((model) => model.length > 0))); + } + + private getErrorStatus(error: unknown): number | undefined { + const err = error as { + status?: number; + code?: number | string; + error?: { code?: number | string }; + }; + + const code = typeof err.code === "number" + ? err.code + : typeof err.error?.code === "number" + ? err.error.code + : undefined; + + return err.status ?? code; + } + async ask(options: AskOptions): Promise { const { prompt, systemPrompt, maxTokens, temperature } = options; const model = config.ai.model; @@ -237,13 +263,19 @@ The user's Discord ID is: ${context.userId}`; * Classify a message to determine the appropriate response style */ async classifyMessage(message: string): Promise { - try { - const classification = await this.client.chat.completions.create({ - model: config.ai.classificationModel, - messages: [ - { - role: "user", - content: `Classify this message into exactly one category. Only respond with the category name, nothing else. + const models = this.getClassificationModelCandidates(); + + for (let i = 0; i < models.length; i++) { + const model = models[i]; + const hasNextModel = i < models.length - 1; + + try { + const classification = await this.client.chat.completions.create({ + model, + messages: [ + { + role: "user", + content: `Classify this message into exactly one category. Only respond with the category name, nothing else. Message: "${message}" @@ -255,44 +287,62 @@ Categories: - helpful: User has a genuine question or needs actual help Category:`, - }, - ], - max_tokens: 10, - temperature: 0.1, - }); + }, + ], + max_tokens: 10, + temperature: 0.1, + }); - const result = classification.choices[0]?.message?.content?.toLowerCase().trim() as MessageStyle; - - // Validate the result is a valid style - if (STYLE_OPTIONS.includes(result)) { - logger.debug("Message classified", { style: result }); - return result; + const result = classification.choices[0]?.message?.content?.toLowerCase().trim() as MessageStyle; + + if (STYLE_OPTIONS.includes(result)) { + logger.debug("Message classified", { style: result, model }); + return result; + } + + logger.debug("Classification returned invalid style, defaulting to snarky", { result, model }); + return "snarky"; + } catch (error) { + if (hasNextModel) { + logger.warn("Classification model failed, trying fallback", { + method: "classifyMessage", + model, + nextModel: models[i + 1], + status: this.getErrorStatus(error), + }); + continue; + } + + logger.error("Failed to classify message", { + method: "classifyMessage", + model, + messageLength: message.length, + attempts: models.length, + }); + logger.error("Classification error details", error); } - - logger.debug("Classification returned invalid style, defaulting to snarky", { result }); - return "snarky"; - } catch (error) { - logger.error("Failed to classify message", { - method: "classifyMessage", - model: config.ai.classificationModel, - messageLength: message.length, - }); - logger.error("Classification error details", error); - return "snarky"; // Default to snarky on error } + + return "snarky"; } /** * Cheap binary classifier to detect if a message is directed at Joel */ async classifyJoelDirected(message: string): Promise { - try { - const classification = await this.client.chat.completions.create({ - model: config.ai.classificationModel, - messages: [ - { - role: "user", - content: `Determine if this Discord message is directed at Joel (the bot), or talking about Joel in a way Joel should respond. + const models = this.getClassificationModelCandidates(); + + for (let i = 0; i < models.length; i++) { + const model = models[i]; + const hasNextModel = i < models.length - 1; + + try { + const classification = await this.client.chat.completions.create({ + model, + messages: [ + { + role: "user", + content: `Determine if this Discord message is directed at Joel (the bot), or talking about Joel in a way Joel should respond. Only respond with one token: YES or NO. @@ -303,22 +353,35 @@ Guidance: Message: "${message}" Answer:`, - }, - ], - max_tokens: 3, - temperature: 0, - }); + }, + ], + max_tokens: 3, + temperature: 0, + }); - const result = classification.choices[0]?.message?.content?.trim().toUpperCase(); - return result?.startsWith("YES") ?? false; - } catch (error) { - logger.error("Failed to classify directed message", { - method: "classifyJoelDirected", - model: config.ai.classificationModel, - messageLength: message.length, - }); - logger.error("Directed classification error details", error); - return false; + const result = classification.choices[0]?.message?.content?.trim().toUpperCase(); + return result?.startsWith("YES") ?? false; + } catch (error) { + if (hasNextModel) { + logger.warn("Directed classification model failed, trying fallback", { + method: "classifyJoelDirected", + model, + nextModel: models[i + 1], + status: this.getErrorStatus(error), + }); + continue; + } + + logger.error("Failed to classify directed message", { + method: "classifyJoelDirected", + model, + messageLength: message.length, + attempts: models.length, + }); + logger.error("Directed classification error details", error); + } } + + return false; } }