You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

398 lines
13 KiB

11 months ago
  1. const { AstraDB: AstraClient } = require("@datastax/astra-db-ts");
  2. const { TextSplitter } = require("../../TextSplitter");
  3. const { SystemSettings } = require("../../../models/systemSettings");
  4. const { storeVectorResult, cachedVectorInformation } = require("../../files");
  5. const { v4: uuidv4 } = require("uuid");
  6. const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
  7. const { sourceIdentifier } = require("../../chats");
  8. const AstraDB = {
  9. name: "AstraDB",
  10. connect: async function () {
  11. if (process.env.VECTOR_DB !== "astra")
  12. throw new Error("AstraDB::Invalid ENV settings");
  13. const client = new AstraClient(
  14. process?.env?.ASTRA_DB_APPLICATION_TOKEN,
  15. process?.env?.ASTRA_DB_ENDPOINT
  16. );
  17. return { client };
  18. },
  19. heartbeat: async function () {
  20. return { heartbeat: Number(new Date()) };
  21. },
  22. // Astra interface will return a valid collection object even if the collection
  23. // does not actually exist. So we run a simple check which will always throw
  24. // when the table truly does not exist. Faster than iterating all collections.
  25. isRealCollection: async function (astraCollection = null) {
  26. if (!astraCollection) return false;
  27. return await astraCollection
  28. .countDocuments()
  29. .then(() => true)
  30. .catch(() => false);
  31. },
  32. totalVectors: async function () {
  33. const { client } = await this.connect();
  34. const collectionNames = await this.allNamespaces(client);
  35. var totalVectors = 0;
  36. for (const name of collectionNames) {
  37. const collection = await client.collection(name).catch(() => null);
  38. const count = await collection.countDocuments().catch(() => 0);
  39. totalVectors += count ? count : 0;
  40. }
  41. return totalVectors;
  42. },
  43. namespaceCount: async function (_namespace = null) {
  44. const { client } = await this.connect();
  45. const namespace = await this.namespace(client, _namespace);
  46. return namespace?.vectorCount || 0;
  47. },
  48. namespace: async function (client, namespace = null) {
  49. if (!namespace) throw new Error("No namespace value provided.");
  50. const collection = await client.collection(namespace).catch(() => null);
  51. if (!(await this.isRealCollection(collection))) return null;
  52. const count = await collection.countDocuments().catch((e) => {
  53. console.error("Astra::namespaceExists", e.message);
  54. return null;
  55. });
  56. return {
  57. name: namespace,
  58. ...collection,
  59. vectorCount: typeof count === "number" ? count : 0,
  60. };
  61. },
  62. hasNamespace: async function (namespace = null) {
  63. if (!namespace) return false;
  64. const { client } = await this.connect();
  65. return await this.namespaceExists(client, namespace);
  66. },
  67. namespaceExists: async function (client, namespace = null) {
  68. if (!namespace) throw new Error("No namespace value provided.");
  69. const collection = await client.collection(namespace);
  70. return await this.isRealCollection(collection);
  71. },
  72. deleteVectorsInNamespace: async function (client, namespace = null) {
  73. await client.dropCollection(namespace);
  74. return true;
  75. },
  76. // AstraDB requires a dimension aspect for collection creation
  77. // we pass this in from the first chunk to infer the dimensions like other
  78. // providers do.
  79. getOrCreateCollection: async function (client, namespace, dimensions = null) {
  80. const isExists = await this.namespaceExists(client, namespace);
  81. if (!isExists) {
  82. if (!dimensions)
  83. throw new Error(
  84. `AstraDB:getOrCreateCollection Unable to infer vector dimension from input. Open an issue on GitHub for support.`
  85. );
  86. await client.createCollection(namespace, {
  87. vector: {
  88. dimension: dimensions,
  89. metric: "cosine",
  90. },
  91. });
  92. }
  93. return await client.collection(namespace);
  94. },
  95. addDocumentToNamespace: async function (
  96. namespace,
  97. documentData = {},
  98. fullFilePath = null,
  99. skipCache = false
  100. ) {
  101. const { DocumentVectors } = require("../../../models/vectors");
  102. try {
  103. let vectorDimension = null;
  104. const { pageContent, docId, ...metadata } = documentData;
  105. if (!pageContent || pageContent.length == 0) return false;
  106. console.log("Adding new vectorized document into namespace", namespace);
  107. if (!skipCache) {
  108. const cacheResult = await cachedVectorInformation(fullFilePath);
  109. if (cacheResult.exists) {
  110. const { client } = await this.connect();
  111. const { chunks } = cacheResult;
  112. const documentVectors = [];
  113. vectorDimension = chunks[0][0].values.length || null;
  114. const collection = await this.getOrCreateCollection(
  115. client,
  116. namespace,
  117. vectorDimension
  118. );
  119. if (!(await this.isRealCollection(collection)))
  120. throw new Error("Failed to create new AstraDB collection!", {
  121. namespace,
  122. });
  123. for (const chunk of chunks) {
  124. // Before sending to Astra and saving the records to our db
  125. // we need to assign the id of each chunk that is stored in the cached file.
  126. const newChunks = chunk.map((chunk) => {
  127. const _id = uuidv4();
  128. documentVectors.push({ docId, vectorId: _id });
  129. return {
  130. _id: _id,
  131. $vector: chunk.values,
  132. metadata: chunk.metadata || {},
  133. };
  134. });
  135. await collection.insertMany(newChunks);
  136. }
  137. await DocumentVectors.bulkInsert(documentVectors);
  138. return { vectorized: true, error: null };
  139. }
  140. }
  141. const EmbedderEngine = getEmbeddingEngineSelection();
  142. const textSplitter = new TextSplitter({
  143. chunkSize: TextSplitter.determineMaxChunkSize(
  144. await SystemSettings.getValueOrFallback({
  145. label: "text_splitter_chunk_size",
  146. }),
  147. EmbedderEngine?.embeddingMaxChunkLength
  148. ),
  149. chunkOverlap: await SystemSettings.getValueOrFallback(
  150. { label: "text_splitter_chunk_overlap" },
  151. 20
  152. ),
  153. chunkHeaderMeta: TextSplitter.buildHeaderMeta(metadata),
  154. });
  155. const textChunks = await textSplitter.splitText(pageContent);
  156. console.log("Chunks created from document:", textChunks.length);
  157. const documentVectors = [];
  158. const vectors = [];
  159. const vectorValues = await EmbedderEngine.embedChunks(textChunks);
  160. if (!!vectorValues && vectorValues.length > 0) {
  161. for (const [i, vector] of vectorValues.entries()) {
  162. if (!vectorDimension) vectorDimension = vector.length;
  163. const vectorRecord = {
  164. _id: uuidv4(),
  165. $vector: vector,
  166. metadata: { ...metadata, text: textChunks[i] },
  167. };
  168. vectors.push(vectorRecord);
  169. documentVectors.push({ docId, vectorId: vectorRecord._id });
  170. }
  171. } else {
  172. throw new Error(
  173. "Could not embed document chunks! This document will not be recorded."
  174. );
  175. }
  176. const { client } = await this.connect();
  177. const collection = await this.getOrCreateCollection(
  178. client,
  179. namespace,
  180. vectorDimension
  181. );
  182. if (!(await this.isRealCollection(collection)))
  183. throw new Error("Failed to create new AstraDB collection!", {
  184. namespace,
  185. });
  186. if (vectors.length > 0) {
  187. const chunks = [];
  188. console.log("Inserting vectorized chunks into Astra DB.");
  189. // AstraDB has maximum upsert size of 20 records per-request so we have to use a lower chunk size here
  190. // in order to do the queries - this takes a lot more time than other providers but there
  191. // is no way around it. This will save the vector-cache with the same layout, so we don't
  192. // have to chunk again for cached files.
  193. for (const chunk of toChunks(vectors, 20)) {
  194. chunks.push(
  195. chunk.map((c) => {
  196. return { id: c._id, values: c.$vector, metadata: c.metadata };
  197. })
  198. );
  199. await collection.insertMany(chunk);
  200. }
  201. await storeVectorResult(chunks, fullFilePath);
  202. }
  203. await DocumentVectors.bulkInsert(documentVectors);
  204. return { vectorized: true, error: null };
  205. } catch (e) {
  206. console.error("addDocumentToNamespace", e.message);
  207. return { vectorized: false, error: e.message };
  208. }
  209. },
  210. deleteDocumentFromNamespace: async function (namespace, docId) {
  211. const { DocumentVectors } = require("../../../models/vectors");
  212. const { client } = await this.connect();
  213. if (!(await this.namespaceExists(client, namespace)))
  214. throw new Error(
  215. "Invalid namespace - has it been collected and populated yet?"
  216. );
  217. const collection = await client.collection(namespace);
  218. const knownDocuments = await DocumentVectors.where({ docId });
  219. if (knownDocuments.length === 0) return;
  220. const vectorIds = knownDocuments.map((doc) => doc.vectorId);
  221. for (const id of vectorIds) {
  222. await collection.deleteMany({
  223. _id: id,
  224. });
  225. }
  226. const indexes = knownDocuments.map((doc) => doc.id);
  227. await DocumentVectors.deleteIds(indexes);
  228. return true;
  229. },
  230. performSimilaritySearch: async function ({
  231. namespace = null,
  232. input = "",
  233. LLMConnector = null,
  234. similarityThreshold = 0.25,
  235. topN = 4,
  236. filterIdentifiers = [],
  237. }) {
  238. if (!namespace || !input || !LLMConnector)
  239. throw new Error("Invalid request to performSimilaritySearch.");
  240. const { client } = await this.connect();
  241. if (!(await this.namespaceExists(client, namespace))) {
  242. return {
  243. contextTexts: [],
  244. sources: [],
  245. message:
  246. "Invalid query - no namespace found for workspace in vector db!",
  247. };
  248. }
  249. const queryVector = await LLMConnector.embedTextInput(input);
  250. const { contextTexts, sourceDocuments } = await this.similarityResponse({
  251. client,
  252. namespace,
  253. queryVector,
  254. similarityThreshold,
  255. topN,
  256. filterIdentifiers,
  257. });
  258. const sources = sourceDocuments.map((metadata, i) => {
  259. return { ...metadata, text: contextTexts[i] };
  260. });
  261. return {
  262. contextTexts,
  263. sources: this.curateSources(sources),
  264. message: false,
  265. };
  266. },
  267. similarityResponse: async function ({
  268. client,
  269. namespace,
  270. queryVector,
  271. similarityThreshold = 0.25,
  272. topN = 4,
  273. filterIdentifiers = [],
  274. }) {
  275. const result = {
  276. contextTexts: [],
  277. sourceDocuments: [],
  278. scores: [],
  279. };
  280. const collection = await client.collection(namespace);
  281. const responses = await collection
  282. .find(
  283. {},
  284. {
  285. sort: { $vector: queryVector },
  286. limit: topN,
  287. includeSimilarity: true,
  288. }
  289. )
  290. .toArray();
  291. responses.forEach((response) => {
  292. if (response.$similarity < similarityThreshold) return;
  293. if (filterIdentifiers.includes(sourceIdentifier(response.metadata))) {
  294. console.log(
  295. "AstraDB: A source was filtered from context as it's parent document is pinned."
  296. );
  297. return;
  298. }
  299. result.contextTexts.push(response.metadata.text);
  300. result.sourceDocuments.push(response);
  301. result.scores.push(response.$similarity);
  302. });
  303. return result;
  304. },
  305. allNamespaces: async function (client) {
  306. try {
  307. let header = new Headers();
  308. header.append("Token", client?.httpClient?.applicationToken);
  309. header.append("Content-Type", "application/json");
  310. let raw = JSON.stringify({
  311. findCollections: {},
  312. });
  313. let requestOptions = {
  314. method: "POST",
  315. headers: header,
  316. body: raw,
  317. redirect: "follow",
  318. };
  319. const call = await fetch(client?.httpClient?.baseUrl, requestOptions);
  320. const resp = await call?.text();
  321. const collections = resp ? JSON.parse(resp)?.status?.collections : [];
  322. return collections;
  323. } catch (e) {
  324. console.error("Astra::AllNamespace", e);
  325. return [];
  326. }
  327. },
  328. "namespace-stats": async function (reqBody = {}) {
  329. const { namespace = null } = reqBody;
  330. if (!namespace) throw new Error("namespace required");
  331. const { client } = await this.connect();
  332. if (!(await this.namespaceExists(client, namespace)))
  333. throw new Error("Namespace by that name does not exist.");
  334. const stats = await this.namespace(client, namespace);
  335. return stats
  336. ? stats
  337. : { message: "No stats were able to be fetched from DB for namespace" };
  338. },
  339. "delete-namespace": async function (reqBody = {}) {
  340. const { namespace = null } = reqBody;
  341. const { client } = await this.connect();
  342. if (!(await this.namespaceExists(client, namespace)))
  343. throw new Error("Namespace by that name does not exist.");
  344. const details = await this.namespace(client, namespace);
  345. await this.deleteVectorsInNamespace(client, namespace);
  346. return {
  347. message: `Namespace ${namespace} was deleted along with ${
  348. details?.vectorCount || "all"
  349. } vectors.`,
  350. };
  351. },
  352. curateSources: function (sources = []) {
  353. const documents = [];
  354. for (const source of sources) {
  355. if (Object.keys(source).length > 0) {
  356. const metadata = source.hasOwnProperty("metadata")
  357. ? source.metadata
  358. : source;
  359. documents.push({
  360. ...metadata,
  361. });
  362. }
  363. }
  364. return documents;
  365. },
  366. };
  367. module.exports.AstraDB = AstraDB;