feat: add fallback

This commit is contained in:
eric
2026-02-23 15:16:15 +01:00
parent a4650c14ae
commit e26d665bdf
4 changed files with 134 additions and 51 deletions

View File

@@ -6,6 +6,8 @@ BOT_OWNER_ID=""
HF_TOKEN="" HF_TOKEN=""
OPENAI_API_KEY="" OPENAI_API_KEY=""
OPENROUTER_API_KEY="" OPENROUTER_API_KEY=""
AI_CLASSIFICATION_MODEL="google/gemma-3-12b-it:free"
AI_CLASSIFICATION_FALLBACK_MODELS="qwen/qwen-2.5-7b-instruct:free,mistralai/mistral-7b-instruct:free"
REPLICATE_API_TOKEN="" REPLICATE_API_TOKEN=""
ELEVENLABS_API_KEY="" ELEVENLABS_API_KEY=""
ELEVENLABS_VOICE_ID="" ELEVENLABS_VOICE_ID=""

View File

@@ -42,6 +42,7 @@ src/
| `DISCORD_CLIENT_ID` | Discord application client ID | | `DISCORD_CLIENT_ID` | Discord application client ID |
| `DISCORD_CLIENT_SECRET` | Discord application client secret | | `DISCORD_CLIENT_SECRET` | Discord application client secret |
| `OPENROUTER_API_KEY` | OpenRouter API key for AI | | `OPENROUTER_API_KEY` | OpenRouter API key for AI |
| `AI_CLASSIFICATION_FALLBACK_MODELS` | Comma-separated fallback model IDs for classification requests |
| `KLIPY_API_KEY` | Klipy API key for GIF search (optional) | | `KLIPY_API_KEY` | Klipy API key for GIF search (optional) |
| `ELEVENLABS_API_KEY` | ElevenLabs API key for voiceover | | `ELEVENLABS_API_KEY` | ElevenLabs API key for voiceover |
| `ELEVENLABS_VOICE_ID` | Default ElevenLabs voice ID (optional) | | `ELEVENLABS_VOICE_ID` | Default ElevenLabs voice ID (optional) |

View File

@@ -13,6 +13,7 @@ interface BotConfig {
openRouterApiKey: string; openRouterApiKey: string;
model: string; model: string;
classificationModel: string; classificationModel: string;
classificationFallbackModels: string[];
maxTokens: number; maxTokens: number;
temperature: number; temperature: number;
}; };
@@ -75,6 +76,18 @@ function getBooleanEnvOrDefault(key: string, defaultValue: boolean): boolean {
return normalized === "1" || normalized === "true" || normalized === "yes"; return normalized === "1" || normalized === "true" || normalized === "yes";
} }
/**
 * Read a comma-separated environment variable as a string array.
 * Entries are trimmed and empty entries are dropped; when the variable
 * is unset or empty, the provided defaults are returned unchanged.
 */
function getCsvEnvOrDefault(key: string, defaultValues: string[]): string[] {
  const rawValue = Bun.env[key];
  if (!rawValue) {
    return defaultValues;
  }
  const entries: string[] = [];
  for (const piece of rawValue.split(",")) {
    const trimmed = piece.trim();
    if (trimmed.length > 0) {
      entries.push(trimmed);
    }
  }
  return entries;
}
export const config: BotConfig = { export const config: BotConfig = {
discord: { discord: {
token: getEnvOrThrow("DISCORD_TOKEN"), token: getEnvOrThrow("DISCORD_TOKEN"),
@@ -91,6 +104,10 @@ export const config: BotConfig = {
"AI_CLASSIFICATION_MODEL", "AI_CLASSIFICATION_MODEL",
"google/gemma-3-12b-it:free" "google/gemma-3-12b-it:free"
), ),
classificationFallbackModels: getCsvEnvOrDefault("AI_CLASSIFICATION_FALLBACK_MODELS", [
"qwen/qwen-2.5-7b-instruct:free",
"mistralai/mistral-7b-instruct:free",
]),
maxTokens: parseInt(getEnvOrDefault("AI_MAX_TOKENS", "500")), maxTokens: parseInt(getEnvOrDefault("AI_MAX_TOKENS", "500")),
temperature: parseFloat(getEnvOrDefault("AI_TEMPERATURE", "1.2")), temperature: parseFloat(getEnvOrDefault("AI_TEMPERATURE", "1.2")),
}, },

View File

@@ -43,6 +43,32 @@ export class OpenRouterProvider implements AiProvider {
} }
} }
/**
 * Ordered, de-duplicated list of model IDs to try for classification:
 * the primary classification model first, then configured fallbacks,
 * then the main chat model as a last resort. Blank entries are dropped.
 */
private getClassificationModelCandidates(): string[] {
  const seen = new Set<string>();
  const candidates = [
    config.ai.classificationModel,
    ...config.ai.classificationFallbackModels,
    config.ai.model,
  ];
  for (const candidate of candidates) {
    const normalized = candidate.trim();
    if (normalized.length > 0) {
      seen.add(normalized);
    }
  }
  return [...seen];
}
/**
 * Best-effort extraction of an HTTP status code from a thrown value.
 * Looks at `status`, then a numeric `code`, then a numeric nested
 * `error.code` (shapes the OpenRouter/OpenAI SDK errors expose).
 * Returns undefined when no numeric status can be found.
 *
 * Guarded against non-object throwables: JavaScript allows throwing
 * strings/null/undefined, and dereferencing those via `as` would throw
 * a TypeError inside the error-logging path, masking the real failure.
 */
private getErrorStatus(error: unknown): number | undefined {
  if (typeof error !== "object" || error === null) {
    return undefined;
  }
  const err = error as {
    status?: number;
    code?: number | string;
    error?: { code?: number | string };
  };
  // Preserve original precedence: status first, then code, then error.code.
  if (typeof err.status === "number") {
    return err.status;
  }
  if (typeof err.code === "number") {
    return err.code;
  }
  if (typeof err.error?.code === "number") {
    return err.error.code;
  }
  return undefined;
}
async ask(options: AskOptions): Promise<AiResponse> { async ask(options: AskOptions): Promise<AiResponse> {
const { prompt, systemPrompt, maxTokens, temperature } = options; const { prompt, systemPrompt, maxTokens, temperature } = options;
const model = config.ai.model; const model = config.ai.model;
@@ -237,9 +263,15 @@ The user's Discord ID is: ${context.userId}`;
* Classify a message to determine the appropriate response style * Classify a message to determine the appropriate response style
*/ */
async classifyMessage(message: string): Promise<MessageStyle> { async classifyMessage(message: string): Promise<MessageStyle> {
const models = this.getClassificationModelCandidates();
for (let i = 0; i < models.length; i++) {
const model = models[i];
const hasNextModel = i < models.length - 1;
try { try {
const classification = await this.client.chat.completions.create({ const classification = await this.client.chat.completions.create({
model: config.ai.classificationModel, model,
messages: [ messages: [
{ {
role: "user", role: "user",
@@ -263,32 +295,50 @@ Category:`,
const result = classification.choices[0]?.message?.content?.toLowerCase().trim() as MessageStyle; const result = classification.choices[0]?.message?.content?.toLowerCase().trim() as MessageStyle;
// Validate the result is a valid style
if (STYLE_OPTIONS.includes(result)) { if (STYLE_OPTIONS.includes(result)) {
logger.debug("Message classified", { style: result }); logger.debug("Message classified", { style: result, model });
return result; return result;
} }
logger.debug("Classification returned invalid style, defaulting to snarky", { result }); logger.debug("Classification returned invalid style, defaulting to snarky", { result, model });
return "snarky"; return "snarky";
} catch (error) { } catch (error) {
if (hasNextModel) {
logger.warn("Classification model failed, trying fallback", {
method: "classifyMessage",
model,
nextModel: models[i + 1],
status: this.getErrorStatus(error),
});
continue;
}
logger.error("Failed to classify message", { logger.error("Failed to classify message", {
method: "classifyMessage", method: "classifyMessage",
model: config.ai.classificationModel, model,
messageLength: message.length, messageLength: message.length,
attempts: models.length,
}); });
logger.error("Classification error details", error); logger.error("Classification error details", error);
return "snarky"; // Default to snarky on error
} }
} }
return "snarky";
}
/** /**
* Cheap binary classifier to detect if a message is directed at Joel * Cheap binary classifier to detect if a message is directed at Joel
*/ */
async classifyJoelDirected(message: string): Promise<boolean> { async classifyJoelDirected(message: string): Promise<boolean> {
const models = this.getClassificationModelCandidates();
for (let i = 0; i < models.length; i++) {
const model = models[i];
const hasNextModel = i < models.length - 1;
try { try {
const classification = await this.client.chat.completions.create({ const classification = await this.client.chat.completions.create({
model: config.ai.classificationModel, model,
messages: [ messages: [
{ {
role: "user", role: "user",
@@ -312,13 +362,26 @@ Answer:`,
const result = classification.choices[0]?.message?.content?.trim().toUpperCase(); const result = classification.choices[0]?.message?.content?.trim().toUpperCase();
return result?.startsWith("YES") ?? false; return result?.startsWith("YES") ?? false;
} catch (error) { } catch (error) {
if (hasNextModel) {
logger.warn("Directed classification model failed, trying fallback", {
method: "classifyJoelDirected",
model,
nextModel: models[i + 1],
status: this.getErrorStatus(error),
});
continue;
}
logger.error("Failed to classify directed message", { logger.error("Failed to classify directed message", {
method: "classifyJoelDirected", method: "classifyJoelDirected",
model: config.ai.classificationModel, model,
messageLength: message.length, messageLength: message.length,
attempts: models.length,
}); });
logger.error("Directed classification error details", error); logger.error("Directed classification error details", error);
}
}
return false; return false;
} }
}
} }