feat: add fallback
This commit is contained in:
@@ -6,6 +6,8 @@ BOT_OWNER_ID=""
|
|||||||
HF_TOKEN=""
|
HF_TOKEN=""
|
||||||
OPENAI_API_KEY=""
|
OPENAI_API_KEY=""
|
||||||
OPENROUTER_API_KEY=""
|
OPENROUTER_API_KEY=""
|
||||||
|
AI_CLASSIFICATION_MODEL="google/gemma-3-12b-it:free"
|
||||||
|
AI_CLASSIFICATION_FALLBACK_MODELS="qwen/qwen-2.5-7b-instruct:free,mistralai/mistral-7b-instruct:free"
|
||||||
REPLICATE_API_TOKEN=""
|
REPLICATE_API_TOKEN=""
|
||||||
ELEVENLABS_API_KEY=""
|
ELEVENLABS_API_KEY=""
|
||||||
ELEVENLABS_VOICE_ID=""
|
ELEVENLABS_VOICE_ID=""
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ src/
|
|||||||
| `DISCORD_CLIENT_ID` | Discord application client ID |
|
| `DISCORD_CLIENT_ID` | Discord application client ID |
|
||||||
| `DISCORD_CLIENT_SECRET` | Discord application client secret |
|
| `DISCORD_CLIENT_SECRET` | Discord application client secret |
|
||||||
| `OPENROUTER_API_KEY` | OpenRouter API key for AI |
|
| `OPENROUTER_API_KEY` | OpenRouter API key for AI |
|
||||||
|
| `AI_CLASSIFICATION_FALLBACK_MODELS` | Comma-separated fallback model IDs for classification requests |
|
||||||
| `KLIPY_API_KEY` | Klipy API key for GIF search (optional) |
|
| `KLIPY_API_KEY` | Klipy API key for GIF search (optional) |
|
||||||
| `ELEVENLABS_API_KEY` | ElevenLabs API key for voiceover |
|
| `ELEVENLABS_API_KEY` | ElevenLabs API key for voiceover |
|
||||||
| `ELEVENLABS_VOICE_ID` | Default ElevenLabs voice ID (optional) |
|
| `ELEVENLABS_VOICE_ID` | Default ElevenLabs voice ID (optional) |
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ interface BotConfig {
|
|||||||
openRouterApiKey: string;
|
openRouterApiKey: string;
|
||||||
model: string;
|
model: string;
|
||||||
classificationModel: string;
|
classificationModel: string;
|
||||||
|
classificationFallbackModels: string[];
|
||||||
maxTokens: number;
|
maxTokens: number;
|
||||||
temperature: number;
|
temperature: number;
|
||||||
};
|
};
|
||||||
@@ -75,6 +76,18 @@ function getBooleanEnvOrDefault(key: string, defaultValue: boolean): boolean {
|
|||||||
return normalized === "1" || normalized === "true" || normalized === "yes";
|
return normalized === "1" || normalized === "true" || normalized === "yes";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getCsvEnvOrDefault(key: string, defaultValues: string[]): string[] {
|
||||||
|
const raw = Bun.env[key];
|
||||||
|
if (!raw) {
|
||||||
|
return defaultValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
return raw
|
||||||
|
.split(",")
|
||||||
|
.map((value) => value.trim())
|
||||||
|
.filter((value) => value.length > 0);
|
||||||
|
}
|
||||||
|
|
||||||
export const config: BotConfig = {
|
export const config: BotConfig = {
|
||||||
discord: {
|
discord: {
|
||||||
token: getEnvOrThrow("DISCORD_TOKEN"),
|
token: getEnvOrThrow("DISCORD_TOKEN"),
|
||||||
@@ -91,6 +104,10 @@ export const config: BotConfig = {
|
|||||||
"AI_CLASSIFICATION_MODEL",
|
"AI_CLASSIFICATION_MODEL",
|
||||||
"google/gemma-3-12b-it:free"
|
"google/gemma-3-12b-it:free"
|
||||||
),
|
),
|
||||||
|
classificationFallbackModels: getCsvEnvOrDefault("AI_CLASSIFICATION_FALLBACK_MODELS", [
|
||||||
|
"qwen/qwen-2.5-7b-instruct:free",
|
||||||
|
"mistralai/mistral-7b-instruct:free",
|
||||||
|
]),
|
||||||
maxTokens: parseInt(getEnvOrDefault("AI_MAX_TOKENS", "500")),
|
maxTokens: parseInt(getEnvOrDefault("AI_MAX_TOKENS", "500")),
|
||||||
temperature: parseFloat(getEnvOrDefault("AI_TEMPERATURE", "1.2")),
|
temperature: parseFloat(getEnvOrDefault("AI_TEMPERATURE", "1.2")),
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -43,6 +43,32 @@ export class OpenRouterProvider implements AiProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private getClassificationModelCandidates(): string[] {
|
||||||
|
const models = [
|
||||||
|
config.ai.classificationModel,
|
||||||
|
...config.ai.classificationFallbackModels,
|
||||||
|
config.ai.model,
|
||||||
|
];
|
||||||
|
|
||||||
|
return Array.from(new Set(models.map((model) => model.trim()).filter((model) => model.length > 0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
private getErrorStatus(error: unknown): number | undefined {
|
||||||
|
const err = error as {
|
||||||
|
status?: number;
|
||||||
|
code?: number | string;
|
||||||
|
error?: { code?: number | string };
|
||||||
|
};
|
||||||
|
|
||||||
|
const code = typeof err.code === "number"
|
||||||
|
? err.code
|
||||||
|
: typeof err.error?.code === "number"
|
||||||
|
? err.error.code
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
return err.status ?? code;
|
||||||
|
}
|
||||||
|
|
||||||
async ask(options: AskOptions): Promise<AiResponse> {
|
async ask(options: AskOptions): Promise<AiResponse> {
|
||||||
const { prompt, systemPrompt, maxTokens, temperature } = options;
|
const { prompt, systemPrompt, maxTokens, temperature } = options;
|
||||||
const model = config.ai.model;
|
const model = config.ai.model;
|
||||||
@@ -237,9 +263,15 @@ The user's Discord ID is: ${context.userId}`;
|
|||||||
* Classify a message to determine the appropriate response style
|
* Classify a message to determine the appropriate response style
|
||||||
*/
|
*/
|
||||||
async classifyMessage(message: string): Promise<MessageStyle> {
|
async classifyMessage(message: string): Promise<MessageStyle> {
|
||||||
|
const models = this.getClassificationModelCandidates();
|
||||||
|
|
||||||
|
for (let i = 0; i < models.length; i++) {
|
||||||
|
const model = models[i];
|
||||||
|
const hasNextModel = i < models.length - 1;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const classification = await this.client.chat.completions.create({
|
const classification = await this.client.chat.completions.create({
|
||||||
model: config.ai.classificationModel,
|
model,
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
@@ -263,32 +295,50 @@ Category:`,
|
|||||||
|
|
||||||
const result = classification.choices[0]?.message?.content?.toLowerCase().trim() as MessageStyle;
|
const result = classification.choices[0]?.message?.content?.toLowerCase().trim() as MessageStyle;
|
||||||
|
|
||||||
// Validate the result is a valid style
|
|
||||||
if (STYLE_OPTIONS.includes(result)) {
|
if (STYLE_OPTIONS.includes(result)) {
|
||||||
logger.debug("Message classified", { style: result });
|
logger.debug("Message classified", { style: result, model });
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug("Classification returned invalid style, defaulting to snarky", { result });
|
logger.debug("Classification returned invalid style, defaulting to snarky", { result, model });
|
||||||
return "snarky";
|
return "snarky";
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (hasNextModel) {
|
||||||
|
logger.warn("Classification model failed, trying fallback", {
|
||||||
|
method: "classifyMessage",
|
||||||
|
model,
|
||||||
|
nextModel: models[i + 1],
|
||||||
|
status: this.getErrorStatus(error),
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
logger.error("Failed to classify message", {
|
logger.error("Failed to classify message", {
|
||||||
method: "classifyMessage",
|
method: "classifyMessage",
|
||||||
model: config.ai.classificationModel,
|
model,
|
||||||
messageLength: message.length,
|
messageLength: message.length,
|
||||||
|
attempts: models.length,
|
||||||
});
|
});
|
||||||
logger.error("Classification error details", error);
|
logger.error("Classification error details", error);
|
||||||
return "snarky"; // Default to snarky on error
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return "snarky";
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cheap binary classifier to detect if a message is directed at Joel
|
* Cheap binary classifier to detect if a message is directed at Joel
|
||||||
*/
|
*/
|
||||||
async classifyJoelDirected(message: string): Promise<boolean> {
|
async classifyJoelDirected(message: string): Promise<boolean> {
|
||||||
|
const models = this.getClassificationModelCandidates();
|
||||||
|
|
||||||
|
for (let i = 0; i < models.length; i++) {
|
||||||
|
const model = models[i];
|
||||||
|
const hasNextModel = i < models.length - 1;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const classification = await this.client.chat.completions.create({
|
const classification = await this.client.chat.completions.create({
|
||||||
model: config.ai.classificationModel,
|
model,
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
@@ -312,13 +362,26 @@ Answer:`,
|
|||||||
const result = classification.choices[0]?.message?.content?.trim().toUpperCase();
|
const result = classification.choices[0]?.message?.content?.trim().toUpperCase();
|
||||||
return result?.startsWith("YES") ?? false;
|
return result?.startsWith("YES") ?? false;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (hasNextModel) {
|
||||||
|
logger.warn("Directed classification model failed, trying fallback", {
|
||||||
|
method: "classifyJoelDirected",
|
||||||
|
model,
|
||||||
|
nextModel: models[i + 1],
|
||||||
|
status: this.getErrorStatus(error),
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
logger.error("Failed to classify directed message", {
|
logger.error("Failed to classify directed message", {
|
||||||
method: "classifyJoelDirected",
|
method: "classifyJoelDirected",
|
||||||
model: config.ai.classificationModel,
|
model,
|
||||||
messageLength: message.length,
|
messageLength: message.length,
|
||||||
|
attempts: models.length,
|
||||||
});
|
});
|
||||||
logger.error("Directed classification error details", error);
|
logger.error("Directed classification error details", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user