const { toChunks } = require("../../helpers");
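
/**
 * Embedder for any OpenAI-compatible embeddings endpoint. The base URL,
 * API key, and model are read from environment variables at construction.
 */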
class GenericOpenAiEmbedder {
  constructor() {
    if (!process.env.EMBEDDING_BASE_PATH)
      throw new Error(
        "GenericOpenAI must have a valid base path to use for the API."
      );
    const { OpenAI: OpenAIApi } = require("openai");
    this.basePath = process.env.EMBEDDING_BASE_PATH;
    this.openai = new OpenAIApi({
      baseURL: this.basePath,
      apiKey: process.env.GENERIC_OPEN_AI_EMBEDDING_API_KEY ?? null,
    });
    this.model = process.env.EMBEDDING_MODEL_PREF ?? null;

    // this.maxConcurrentChunks is delegated to the getter below.
    // Refer to the specific model and provider you use with this class to
    // determine a valid maxChunkLength.
    this.embeddingMaxChunkLength = 8_191;
  }

  /**
   * returns the `GENERIC_OPEN_AI_EMBEDDING_MAX_CONCURRENT_CHUNKS` env variable as a number
   * or 500 if the env variable is not set or is not a number.
   * @returns {number}
   */
  get maxConcurrentChunks() {
    if (!process.env.GENERIC_OPEN_AI_EMBEDDING_MAX_CONCURRENT_CHUNKS)
      return 500;
    if (
      isNaN(Number(process.env.GENERIC_OPEN_AI_EMBEDDING_MAX_CONCURRENT_CHUNKS))
    )
      return 500;
    return Number(process.env.GENERIC_OPEN_AI_EMBEDDING_MAX_CONCURRENT_CHUNKS);
  }
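
  /**
   * Embeds a single text input (or the first element of an array input)
   * and returns its embedding vector.
   * @param {string|string[]} textInput
   * @returns {Promise<number[]>} the first embedding, or [] when nothing was returned.
   */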
  async embedTextInput(textInput) {
    const result = await this.embedChunks(
      Array.isArray(textInput) ? textInput : [textInput]
    );
    return result?.[0] || [];
  }
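
  /**
   * Embeds an array of text chunks, batching the requests by
   * `maxConcurrentChunks` and running the batches concurrently.
   * @param {string[]} textChunks
   * @returns {Promise<Array<number[]>|null>} one embedding per chunk, or null if any response was malformed.
   * @throws {Error} when any batch fails to embed.
   */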
  async embedChunks(textChunks = []) {
    // Because there is a hard POST body limit on how many chunks can be sent
    // at once to OpenAI (~8 MB), we split the chunks into maximally sized
    // batches and execute those batches concurrently.
    // Refer to the maxConcurrentChunks getter for more info.
    const embeddingRequests = [];
    for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) {
      embeddingRequests.push(
        new Promise((resolve) => {
          this.openai.embeddings
            .create({
              model: this.model,
              input: chunk,
            })
            .then((result) => {
              resolve({ data: result?.data, error: null });
            })
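            // Normalize SDK/HTTP failures into a consistent shape (type +
            // message) so duplicate failures can be collapsed when reported.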
            .catch((e) => {
              e.type =
                e?.response?.data?.error?.code ||
                e?.response?.status ||
                "failed_to_embed";
              e.message = e?.response?.data?.error?.message || e.message;
              resolve({ data: [], error: e });
            });
        })
      );
    }

    const { data = [], error = null } = await Promise.all(
      embeddingRequests
    ).then((results) => {
      // If any errors were returned from OpenAI, abort the entire sequence
      // because the embeddings will be incomplete.
      const errors = results
        .filter((res) => !!res.error)
        .map((res) => res.error)
        .flat();
      if (errors.length > 0) {
        const uniqueErrors = new Set();
        errors.forEach((error) =>
          uniqueErrors.add(`[${error.type}]: ${error.message}`)
        );
        return {
          data: [],
          error: Array.from(uniqueErrors).join(", "),
        };
      }
      return {
        data: results.map((res) => res?.data || []).flat(),
        error: null,
      };
    });

    if (!!error) throw new Error(`GenericOpenAI failed to embed: ${error}`);
    return data.length > 0 &&
      data.every((embd) => embd.hasOwnProperty("embedding"))
      ? data.map((embd) => embd.embedding)
      : null;
  }
}

module.exports = {
  GenericOpenAiEmbedder,
};
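
// Example usage (illustrative sketch only; the require path and environment
// values below are assumptions, not part of this module):
//
//   process.env.EMBEDDING_BASE_PATH = "http://localhost:8080/v1";
//   process.env.EMBEDDING_MODEL_PREF = "my-embedding-model";
//   const { GenericOpenAiEmbedder } = require("./path/to/this/module");
//   const embedder = new GenericOpenAiEmbedder();
//   embedder.embedTextInput("Hello world").then((vector) => {
//     console.log(vector.length); // dimensionality of the embedding
//   });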