const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { v4: uuidv4 } = require("uuid");
const {
  writeResponseChunk,
  clientAbortedHandler,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const fs = require("fs");
const path = require("path");
const { safeJsonParse } = require("../../http");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");

const cacheFolder = path.resolve(
  process.env.STORAGE_DIR
    ? path.resolve(process.env.STORAGE_DIR, "models", "apipie")
    : path.resolve(__dirname, `../../../storage/models/apipie`)
);

class ApiPieLLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.APIPIE_LLM_API_KEY)
      throw new Error("No ApiPie LLM API key was set.");

    const { OpenAI: OpenAIApi } = require("openai");
    this.basePath = "https://apipie.ai/v1";
    this.openai = new OpenAIApi({
      baseURL: this.basePath,
      apiKey: process.env.APIPIE_LLM_API_KEY ?? null,
    });
    this.model =
      modelPreference ||
      process.env.APIPIE_LLM_MODEL_PREF ||
      "openrouter/mistral-7b-instruct";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;

    if (!fs.existsSync(cacheFolder))
      fs.mkdirSync(cacheFolder, { recursive: true });
    this.cacheModelPath = path.resolve(cacheFolder, "models.json");
    this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
  }

  log(text, ...args) {
    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
  }
  // Checks whether the timestamp in the .cached_at file is more than one week
  // (in ms) older than the current date. If it is, we refetch from the API so
  // that the cached model list stays up to date.
  #cacheIsStale() {
    const MAX_STALE = 6.048e8; // 1 week in ms
    if (!fs.existsSync(this.cacheAtPath)) return true;

    const now = Number(new Date());
    const timestampMs = Number(fs.readFileSync(this.cacheAtPath));
    return now - timestampMs > MAX_STALE;
  }
  // Fetches the models from the ApiPie API and caches them locally.
  // ApiPie exposes a large number of models, and we need the proper token
  // context window for each one. Since those limits feed constructor
  // properties, they can only be resolved reliably once this cache exists.
  // We used to maintain this as a chore, but given there is an API for the
  // info that makes little sense. This may slow down the very first request,
  // but keeps the model metadata accurate.
  async #syncModels() {
    if (fs.existsSync(this.cacheModelPath) && !this.#cacheIsStale())
      return false;

    this.log("Model cache is not present or stale. Fetching from ApiPie API.");
    await fetchApiPieModels();
    return;
  }
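  /**
   * Formats retrieved context snippets into the block appended to the
   * system prompt. Returns an empty string when there is no context.
   */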
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }
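  /**
   * Reads the locally cached ApiPie model map from disk.
   * Returns an empty object if the cache file does not exist yet.
   */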
  models() {
    if (!fs.existsSync(this.cacheModelPath)) return {};
    return safeJsonParse(
      fs.readFileSync(this.cacheModelPath, { encoding: "utf-8" }),
      {}
    );
  }
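  /**
   * Filters the cached model map down to chat-capable models
   * (those whose subtype includes "chat" or "chatx").
   */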
  chatModels() {
    const allModels = this.models();
    return Object.entries(allModels).reduce(
      (chatModels, [modelId, modelInfo]) => {
        // Filter for chat models
        if (
          modelInfo.subtype &&
          (modelInfo.subtype.includes("chat") ||
            modelInfo.subtype.includes("chatx"))
        ) {
          chatModels[modelId] = modelInfo;
        }
        return chatModels;
      },
      {}
    );
  }

  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
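  /**
   * Static lookup of a model's context window from the on-disk cache.
   * Falls back to 4096 tokens when the model is unknown or the cache is missing.
   */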
  static promptWindowLimit(modelName) {
    const cacheModelPath = path.resolve(cacheFolder, "models.json");
    const availableModels = fs.existsSync(cacheModelPath)
      ? safeJsonParse(
          fs.readFileSync(cacheModelPath, { encoding: "utf-8" }),
          {}
        )
      : {};
    return availableModels[modelName]?.maxLength || 4096;
  }
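  // Context window of the currently selected model (4096 if unknown).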
  promptWindowLimit() {
    const availableModels = this.chatModels();
    return availableModels[this.model]?.maxLength || 4096;
  }
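  // Syncs the model cache (if stale) and checks that the given model id is a known chat model.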
  async isValidChatCompletionModel(model = "") {
    await this.#syncModels();
    const availableModels = this.chatModels();
    return availableModels.hasOwnProperty(model);
  }

  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return userPrompt;
    }

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: {
          url: attachment.contentString,
          detail: "auto",
        },
      });
    }
    return content.flat();
  }
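  /**
   * Builds the OpenAI-style message array: a system prompt (with appended
   * context), the formatted chat history, and the current user message.
   */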
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent),
      {
        role: "user",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
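  /**
   * Performs a non-streaming chat completion and returns the text response
   * plus token-usage metrics measured by LLMPerformanceMonitor.
   */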
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `ApiPie chat: ${this.model} is not valid for chat completion!`
      );

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature,
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage?.prompt_tokens || 0,
        completion_tokens: result.output.usage?.completion_tokens || 0,
        total_tokens: result.output.usage?.total_tokens || 0,
        outputTps:
          (result.output.usage?.completion_tokens || 0) / result.duration,
        duration: result.duration,
      },
    };
  }
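  /**
   * Opens a streaming chat completion and wraps it in a performance-measured
   * stream that is later consumed by handleStream.
   */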
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `ApiPie chat: ${this.model} is not valid for chat completion!`
      );

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
      }),
      messages
    );
    return measuredStreamRequest;
  }
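  /**
   * Relays streamed tokens to the client as textResponseChunk events,
   * handling early client aborts and closing the measurement when the
   * stream finishes or errors.
   */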
  handleStream(response, stream, responseProps) {
    const { uuid = uuidv4(), sources = [] } = responseProps;

    return new Promise(async (resolve) => {
      let fullText = "";

      // Establish listener to early-abort a streaming response
      // in case things go sideways or the user does not like the response.
      // We preserve the generated text but continue as if chat was completed
      // to preserve previously generated content.
      const handleAbort = () => {
        stream?.endMeasurement({
          completion_tokens: LLMPerformanceMonitor.countTokens(fullText),
        });
        clientAbortedHandler(resolve, fullText);
      };
      response.on("close", handleAbort);

      try {
        for await (const chunk of stream) {
          const message = chunk?.choices?.[0];
          const token = message?.delta?.content;

          if (token) {
            fullText += token;
            writeResponseChunk(response, {
              uuid,
              sources: [],
              type: "textResponseChunk",
              textResponse: token,
              close: false,
              error: false,
            });
          }

          if (message === undefined || message.finish_reason !== null) {
            writeResponseChunk(response, {
              uuid,
              sources,
              type: "textResponseChunk",
              textResponse: "",
              close: true,
              error: false,
            });
            response.removeListener("close", handleAbort);
            stream?.endMeasurement({
              completion_tokens: LLMPerformanceMonitor.countTokens(fullText),
            });
            resolve(fullText);
          }
        }
      } catch (e) {
        writeResponseChunk(response, {
          uuid,
          sources,
          type: "abort",
          textResponse: null,
          close: true,
          error: e.message,
        });
        response.removeListener("close", handleAbort);
        stream?.endMeasurement({
          completion_tokens: LLMPerformanceMonitor.countTokens(fullText),
        });
        resolve(fullText);
      }
    });
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }

  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}
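/**
 * Fetches the full model list from the ApiPie /v1/models endpoint, keys each
 * entry as `provider/model`, and writes the result (plus a .cached_at
 * timestamp) to the local cache folder. Returns an empty object on failure.
 */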
async function fetchApiPieModels(providedApiKey = null) {
  const apiKey = providedApiKey || process.env.APIPIE_LLM_API_KEY || null;
  return await fetch(`https://apipie.ai/v1/models`, {
    method: "GET",
    headers: {
      "Content-Type": "application/json",
      ...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
    },
  })
    .then((res) => res.json())
    .then(({ data = [] }) => {
      const models = {};
      data.forEach((model) => {
        models[`${model.provider}/${model.model}`] = {
          id: `${model.provider}/${model.model}`,
          name: `${model.provider}/${model.model}`,
          organization: model.provider,
          subtype: model.subtype,
          maxLength: model.max_tokens,
        };
      });

      // Cache all response information
      if (!fs.existsSync(cacheFolder))
        fs.mkdirSync(cacheFolder, { recursive: true });
      fs.writeFileSync(
        path.resolve(cacheFolder, "models.json"),
        JSON.stringify(models),
        {
          encoding: "utf-8",
        }
      );
      fs.writeFileSync(
        path.resolve(cacheFolder, ".cached_at"),
        String(Number(new Date())),
        {
          encoding: "utf-8",
        }
      );
      return models;
    })
    .catch((e) => {
      console.error(e);
      return {};
    });
}

module.exports = {
  ApiPieLLM,
  fetchApiPieModels,
};