You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

89 lines
2.7 KiB

11 months ago
  1. const { toChunks, maximumChunkLength } = require("../../helpers");
  /**
   * Embedder that calls an OpenAI-compatible `/embeddings` endpoint (LocalAI)
   * through the `openai` client library.
   *
   * Environment variables read:
   *   - EMBEDDING_BASE_PATH  (required) base URL handed to the client.
   *   - EMBEDDING_MODEL_PREF (required) model name sent with every request.
   *   - LOCAL_AI_API_KEY     (optional) API key; passed as `null` when unset.
   */
  class LocalAiEmbedder {
    /**
     * @throws {Error} When EMBEDDING_BASE_PATH or EMBEDDING_MODEL_PREF is unset.
     */
    constructor() {
      if (!process.env.EMBEDDING_BASE_PATH)
        throw new Error("No embedding base path was set.");
      if (!process.env.EMBEDDING_MODEL_PREF)
        throw new Error("No embedding model was set.");

      const { OpenAI: OpenAIApi } = require("openai");
      this.openai = new OpenAIApi({
        baseURL: process.env.EMBEDDING_BASE_PATH,
        apiKey: process.env.LOCAL_AI_API_KEY ?? null,
      });

      // Maximum number of strings sent in one embeddings request, to stay within
      // resource or network limits. Note: all batches are still issued in parallel.
      this.maxConcurrentChunks = 50;
      this.embeddingMaxChunkLength = maximumChunkLength();
    }

    /**
     * Embed a single input.
     * @param {string|string[]} textInput - A string, or an array whose first
     *   element's embedding is returned.
     * @returns {Promise<number[]>} The first embedding, or `[]` if none was produced.
     */
    async embedTextInput(textInput) {
      const result = await this.embedChunks(
        Array.isArray(textInput) ? textInput : [textInput]
      );
      return result?.[0] || [];
    }

    /**
     * Embed many strings, splitting them into batches of `maxConcurrentChunks`
     * and requesting all batches in parallel.
     * @param {string[]} [textChunks=[]] - Strings to embed.
     * @returns {Promise<number[][]|null>} One vector per returned item, in batch
     *   order; `null` when nothing was returned or any item lacks `embedding`.
     * @throws {Error} `LocalAI Failed to embed: ...` listing each distinct batch
     *   error once, if any batch failed.
     */
    async embedChunks(textChunks = []) {
      const batches = toChunks(textChunks, this.maxConcurrentChunks);
      const results = await Promise.all(
        batches.map(async (batch) => {
          try {
            const response = await this.openai.embeddings.create({
              model: process.env.EMBEDDING_MODEL_PREF,
              input: batch,
            });
            return { data: response?.data, error: null };
          } catch (e) {
            // Must never throw from here: a throw would reject this batch instead of
            // reporting it. Accept both the legacy axios-style shape (e.response)
            // and openai v4 APIError fields (e.code / e.status).
            const type =
              e?.response?.data?.error?.code ||
              e?.response?.status ||
              e?.code ||
              e?.status ||
              "failed_to_embed";
            const message =
              e?.response?.data?.error?.message || e?.message || String(e);
            return { data: [], error: `[${type}]: ${message}` };
          }
        })
      );

      // If any errors were returned from LocalAI abort the entire sequence because
      // the embeddings would be incomplete. A Set keeps each distinct error once,
      // in first-seen order.
      const uniqueErrors = [
        ...new Set(results.filter((res) => res.error).map((res) => res.error)),
      ];
      if (uniqueErrors.length > 0)
        throw new Error(`LocalAI Failed to embed: ${uniqueErrors.join(", ")}`);

      const data = results.flatMap((res) => res?.data || []);
      const allHaveEmbedding =
        data.length > 0 &&
        data.every(
          (item) =>
            item != null &&
            Object.prototype.hasOwnProperty.call(item, "embedding")
        );
      return allHaveEmbedding ? data.map((item) => item.embedding) : null;
    }
  }
  // Public API of this module: the LocalAI embedding provider.
  module.exports = { LocalAiEmbedder };