Switch default AI provider from Replicate to OpenRouter and add message style classification

This commit is contained in:
2026-02-01 16:55:14 +01:00
parent 6dbcadcaee
commit e2f69e68cd
8 changed files with 199 additions and 75 deletions

View File

@@ -3,8 +3,8 @@
*/
import { createLogger } from "../../core/logger";
import { ReplicateProvider } from "./replicate";
import type { AiProvider, AiResponse } from "./types";
import { OpenRouterProvider } from "./openrouter";
import type { AiProvider, AiResponse, MessageStyle } from "./types";
const logger = createLogger("AI:Service");
@@ -12,7 +12,7 @@ export class AiService {
private provider: AiProvider;
constructor(provider?: AiProvider) {
this.provider = provider ?? new ReplicateProvider();
this.provider = provider ?? new OpenRouterProvider();
}
async health(): Promise<boolean> {
@@ -26,6 +26,17 @@ export class AiService {
logger.debug("Generating response", { promptLength: prompt.length });
return this.provider.ask({ prompt, systemPrompt });
}
/**
 * Determine which response style fits the given message.
 *
 * Delegates to the underlying provider when it implements
 * `classifyMessage`; providers without classification support
 * fall back to the "snarky" style.
 */
async classifyMessage(message: string): Promise<MessageStyle> {
  // Optional call: classification is an optional part of the provider contract.
  return (await this.provider.classifyMessage?.(message)) ?? "snarky";
}
}
// Singleton instance
@@ -38,4 +49,4 @@ export function getAiService(): AiService {
return aiService;
}
export type { AiProvider, AiResponse } from "./types";
export type { AiProvider, AiResponse, MessageStyle } from "./types";

View File

@@ -0,0 +1,107 @@
/**
* OpenRouter AI provider implementation
*/
import OpenAI from "openai";
import { config } from "../../core/config";
import { createLogger } from "../../core/logger";
import type { AiProvider, AiResponse, AskOptions, MessageStyle } from "./types";
const logger = createLogger("AI:OpenRouter");
// Style classification options accepted from the classifier model.
// NOTE(review): must stay in sync with the MessageStyle union in ./types.
const STYLE_OPTIONS: MessageStyle[] = ["story", "snarky", "insult", "explicit", "helpful"];
export class OpenRouterProvider implements AiProvider {
  private client: OpenAI;

  constructor() {
    // OpenRouter exposes an OpenAI-compatible API; the extra default headers
    // are the attribution headers OpenRouter recommends senders include.
    this.client = new OpenAI({
      baseURL: "https://openrouter.ai/api/v1",
      apiKey: config.ai.openRouterApiKey,
      defaultHeaders: {
        "HTTP-Referer": "https://github.com/crunk-bun",
        "X-Title": "Joel Discord Bot",
      },
    });
  }

  /**
   * Check connectivity to OpenRouter by listing available models.
   *
   * @returns true when the API is reachable with the configured key,
   *          false (after logging) on any failure.
   */
  async health(): Promise<boolean> {
    try {
      // Simple health check - verify we can list models
      await this.client.models.list();
      return true;
    } catch (error) {
      logger.error("Health check failed", error);
      return false;
    }
  }

  /**
   * Generate a chat completion for the given prompt.
   *
   * @param options prompt, system prompt, and optional sampling overrides;
   *                maxTokens/temperature fall back to the configured defaults.
   * @returns the completion text, truncated for Discord's message limit.
   * @throws rethrows any API error after logging it.
   */
  async ask(options: AskOptions): Promise<AiResponse> {
    const { prompt, systemPrompt, maxTokens, temperature } = options;
    try {
      const completion = await this.client.chat.completions.create({
        model: config.ai.model,
        messages: [
          { role: "system", content: systemPrompt },
          { role: "user", content: prompt },
        ],
        max_tokens: maxTokens ?? config.ai.maxTokens,
        temperature: temperature ?? config.ai.temperature,
      });

      const text = completion.choices[0]?.message?.content ?? "";
      // Discord message limit safety (2000-char cap; 1900 leaves headroom).
      return { text: text.slice(0, 1900) };
    } catch (error: unknown) {
      logger.error("Failed to generate response", error);
      throw error;
    }
  }

  /**
   * Classify a message into one of STYLE_OPTIONS using a cheap
   * low-temperature completion.
   *
   * @param message the user message to classify.
   * @returns the detected style, or "snarky" when the model returns an
   *          unrecognized value or the API call fails.
   */
  async classifyMessage(message: string): Promise<MessageStyle> {
    try {
      const classification = await this.client.chat.completions.create({
        model: config.ai.classificationModel,
        messages: [
          {
            role: "user",
            content: `Classify this message into exactly one category. Only respond with the category name, nothing else.
Message: "${message}"
Categories:
- story: User wants a story, narrative, or creative writing
- snarky: User is being sarcastic or deserves a witty comeback
- insult: User is being rude or hostile, respond with brutal insults
- explicit: User wants adult/NSFW content
- helpful: User has a genuine question or needs actual help
Category:`,
          },
        ],
        max_tokens: 10,
        temperature: 0.1,
      });

      const raw = classification.choices[0]?.message?.content?.toLowerCase().trim() ?? "";
      // Narrow via lookup instead of an unchecked `as MessageStyle` cast, so an
      // unexpected (or missing) model reply can never masquerade as a valid style.
      const style = STYLE_OPTIONS.find((option) => option === raw);
      if (style) {
        logger.debug("Message classified", { style });
        return style;
      }

      logger.debug("Classification returned invalid style, defaulting to snarky", { result: raw });
      return "snarky";
    } catch (error) {
      logger.error("Failed to classify message", error);
      return "snarky"; // Default to snarky on error
    }
  }
}

View File

@@ -1,63 +0,0 @@
/**
* Replicate AI provider implementation
*/
import Replicate from "replicate";
import { config } from "../../core/config";
import { createLogger } from "../../core/logger";
import type { AiProvider, AiResponse, AskOptions } from "./types";
const logger = createLogger("AI:Replicate");
export class ReplicateProvider implements AiProvider {
  private client: Replicate;

  constructor() {
    this.client = new Replicate({ auth: config.ai.replicateApiToken });
  }

  /**
   * Report provider health. Only client construction is verified,
   * so this resolves true unless something unexpected throws.
   */
  async health(): Promise<boolean> {
    try {
      return true;
    } catch (error) {
      logger.error("Health check failed", error);
      return false;
    }
  }

  /**
   * Stream a completion from Replicate for the given prompt, stopping
   * early once enough text for a Discord message has accumulated.
   *
   * @param options prompt, system prompt, and optional sampling overrides.
   * @returns streamed text truncated to Discord's safe length.
   * @throws rethrows any streaming error after logging it.
   */
  async ask(options: AskOptions): Promise<AiResponse> {
    const { prompt, systemPrompt, maxTokens, temperature } = options;
    try {
      // ChatML-style framing expected by the configured model.
      const formattedPrompt = `<|im_start|>system
${systemPrompt}<|im_end|>
<|im_start|>user
${prompt}<|im_end|>
<|im_start|>assistant
`;

      let buffer = "";
      const stream = this.client.stream(config.ai.model as `${string}/${string}:${string}`, {
        input: {
          prompt: formattedPrompt,
          temperature: temperature ?? config.ai.temperature,
          max_new_tokens: maxTokens ?? config.ai.maxTokens,
        },
      });

      for await (const chunk of stream) {
        buffer += chunk;
        // Discord message limit safety
        if (buffer.length >= 1900) break;
      }

      return { text: buffer.slice(0, 1900) };
    } catch (error: unknown) {
      logger.error("Failed to generate response", error);
      throw error;
    }
  }
}

View File

@@ -7,6 +7,11 @@ export interface AiResponse {
text: string;
}
/**
* Message style classification options
*/
export type MessageStyle = "story" | "snarky" | "insult" | "explicit" | "helpful";
export interface AiProvider {
/**
* Generate a response to a prompt
@@ -17,6 +22,11 @@ export interface AiProvider {
* Check if the AI service is healthy
*/
health(): Promise<boolean>;
/**
* Classify a message to determine response style
*/
classifyMessage?(message: string): Promise<MessageStyle>;
}
export interface AskOptions {