const path = require("path");
const fs = require("fs");
const { toChunks } = require("../../helpers");
const { v4 } = require("uuid");

class NativeEmbedder {
  // This is a folder that Mintplex Labs hosts for those who cannot reach the
  // HF model download endpoint for various reasons. This endpoint is not
  // guaranteed to be active or maintained and may go offline at any time at
  // Mintplex Labs's discretion.
  #fallbackHost = "https://cdn.anythingllm.com/support/models/";

  constructor() {
    // Model Card: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
    this.model = "Xenova/all-MiniLM-L6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, "Xenova", "all-MiniLM-L6-v2");
    this.modelDownloaded = fs.existsSync(this.modelPath);
    // Limit of how many strings we can process in a single pass to stay within
    // resource or network limits
    this.maxConcurrentChunks = 25;
    this.embeddingMaxChunkLength = 1_000;

    // Make the model directory when it does not exist in existing installations
    if (!fs.existsSync(this.cacheDir))
      fs.mkdirSync(this.cacheDir, { recursive: true });
    this.log("Initialized");
  }

  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbedder]\x1b[0m ${text}`, ...args);
  }
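  // Builds a unique path for a temp file under STORAGE_DIR/tmp (or the local
  // storage/tmp fallback), creating the directory if it does not exist yet.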
  #tempfilePath() {
    const filename = `${v4()}.tmp`;
    const tmpPath = process.env.STORAGE_DIR
      ? path.resolve(process.env.STORAGE_DIR, "tmp")
      : path.resolve(__dirname, `../../../storage/tmp`);
    if (!fs.existsSync(tmpPath)) fs.mkdirSync(tmpPath, { recursive: true });
    return path.resolve(tmpPath, filename);
  }
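  // Appends UTF-8 data to the temp file. Errors are logged rather than thrown
  // so a failed write does not abort the embedding loop.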
  async #writeToTempfile(filePath, data) {
    try {
      await fs.promises.appendFile(filePath, data, { encoding: "utf8" });
    } catch (e) {
      console.error(`Error writing to tempfile: ${e}`);
    }
  }
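  // Attempts to construct the feature-extraction pipeline, optionally against
  // an alternate download host. Returns { pipeline, retry, error } so the
  // caller can decide whether to make a single fallback attempt.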
  async #fetchWithHost(hostOverride = null) {
    try {
      // Use a dynamic import() so this ESM-only library can be loaded from CommonJS.
      const pipeline = (...args) =>
        import("@xenova/transformers").then(({ pipeline, env }) => {
          if (!this.modelDownloaded) {
            // If the model is not yet downloaded, log where we are fetching it from.
            if (hostOverride) {
              env.remoteHost = hostOverride;
              env.remotePathTemplate = "{model}/"; // Our S3 fallback URL does not support the revision file structure.
            }
            this.log(`Downloading ${this.model} from ${env.remoteHost}`);
          }
          return pipeline(...args);
        });
      return {
        pipeline: await pipeline("feature-extraction", this.model, {
          cache_dir: this.cacheDir,
          ...(!this.modelDownloaded
            ? {
                // Show download progress if we need to download any files
                progress_callback: (data) => {
                  if (!data.hasOwnProperty("progress")) return;
                  console.log(
                    `\x1b[36m[NativeEmbedder - Downloading model]\x1b[0m ${
                      data.file
                    } ${~~data?.progress}%`
                  );
                },
              }
            : {}),
        }),
        retry: false,
        error: null,
      };
    } catch (error) {
      return {
        pipeline: null,
        retry: hostOverride === null ? this.#fallbackHost : false,
        error,
      };
    }
  }
  // This function will make a single fallback attempt (not recursive on purpose) to grab the
  // embedder model on first embed, since at times some clients cannot properly download the
  // model from HF servers for a number of reasons (IP, VPN, etc).
  // Given this model is critical and nobody reads the GitHub issues before submitting the bug, we get the same bug
  // report 20 times a day: https://github.com/Mintplex-Labs/anything-llm/issues/821
  // So to attempt to monkey-patch this we have a single fallback URL to help alleviate duplicate bug reports.
  async embedderClient() {
    if (!this.modelDownloaded)
      this.log(
        "The native embedding model has never been run and will be downloaded right now. Subsequent runs will be faster. (~23MB)"
      );

    let fetchResponse = await this.#fetchWithHost();
    if (fetchResponse.pipeline !== null) {
      this.modelDownloaded = true;
      return fetchResponse.pipeline;
    }

    this.log(
      `Failed to download model from primary URL. Using fallback ${fetchResponse.retry}`
    );
    if (!!fetchResponse.retry)
      fetchResponse = await this.#fetchWithHost(fetchResponse.retry);
    if (fetchResponse.pipeline !== null) {
      this.modelDownloaded = true;
      return fetchResponse.pipeline;
    }
    throw fetchResponse.error;
  }
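  // Embeds a single string (or the first element of an array input) and
  // returns its embedding vector, or an empty array if nothing was produced.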
  async embedTextInput(textInput) {
    const result = await this.embedChunks(
      Array.isArray(textInput) ? textInput : [textInput]
    );
    return result?.[0] || [];
  }
  // If you are thinking you want to edit this function - you probably don't.
  // This process was benchmarked heavily on a t3.small (2GB RAM 1vCPU)
  // and without careful memory management for the V8 garbage collector
  // this function will likely result in an OOM on any resource-constrained deployment.
  // To help manage very large documents we run a concurrent write-log each iteration
  // to keep the embedding result out of memory. `maxConcurrentChunks` is set to 25,
  // as 50 seems to overflow no matter what. Given the above, memory use hovers around ~30%
  // during a very large document (>100K words) but can spike up to 70% before gc.
  // This seems repeatable for all document sizes.
  // While this does take a while, it is zero-setup, 100% free, and on-instance.
  // It still may crash depending on other elements at play - so no promises it works under all conditions.
  async embedChunks(textChunks = []) {
    const tmpFilePath = this.#tempfilePath();
    const chunks = toChunks(textChunks, this.maxConcurrentChunks);
    const chunkLen = chunks.length;

    for (let [idx, chunk] of chunks.entries()) {
      if (idx === 0) await this.#writeToTempfile(tmpFilePath, "[");
      let data;
      let pipeline = await this.embedderClient();
      let output = await pipeline(chunk, {
        pooling: "mean",
        normalize: true,
      });

      if (output.length === 0) {
        pipeline = null;
        output = null;
        data = null;
        continue;
      }

      data = JSON.stringify(output.tolist());
      await this.#writeToTempfile(tmpFilePath, data);
      this.log(`Embedded Chunk ${idx + 1} of ${chunkLen}`);
      if (chunkLen - 1 !== idx) await this.#writeToTempfile(tmpFilePath, ",");
      if (chunkLen - 1 === idx) await this.#writeToTempfile(tmpFilePath, "]");
      pipeline = null;
      output = null;
      data = null;
    }

    const embeddingResults = JSON.parse(
      fs.readFileSync(tmpFilePath, { encoding: "utf-8" })
    );
    fs.rmSync(tmpFilePath, { force: true });
    return embeddingResults.length > 0 ? embeddingResults.flat() : null;
  }
}

module.exports = {
  NativeEmbedder,
};
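
// Example usage (a minimal sketch, not part of the original file; the require
// path is an assumption since this module's location in the repo is not shown
// here):
//
//   const { NativeEmbedder } = require("./NativeEmbedder");
//   const embedder = new NativeEmbedder();
//   embedder
//     .embedTextInput("What is AnythingLLM?")
//     .then((vector) => console.log(`Got a ${vector.length}-dimension vector`))
//     .catch(console.error);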