|
|
const path = require("path");const fs = require("fs");
class NativeEmbeddingReranker {
  /** Shared singleton model instance (loaded once per process). */
  static #model = null;
  /** Shared singleton tokenizer instance (loaded once per process). */
  static #tokenizer = null;
  /** Cached `@xenova/transformers` exports: { AutoModelForSequenceClassification, AutoTokenizer, env }. */
  static #transformers = null;

  // This is a folder that Mintplex Labs hosts for those who cannot capture the HF model download
  // endpoint for various reasons. This endpoint is not guaranteed to be active or maintained
  // and may go offline at any time at Mintplex Labs's discretion.
  #fallbackHost = "https://cdn.anythingllm.com/support/models/";

  constructor() {
    // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s)
    // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s)
    this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
    // Make directory when it does not exist in existing installations.
    // FIX: `recursive: true` so a missing parent (e.g. a fresh STORAGE_DIR that
    // was never created) does not make mkdirSync throw ENOENT.
    if (!fs.existsSync(this.cacheDir))
      fs.mkdirSync(this.cacheDir, { recursive: true });
    // FIX: reuse the pre-computed, cross-platform `modelPath` instead of
    // re-joining cacheDir with the slash-delimited model id.
    this.modelDownloaded = fs.existsSync(this.modelPath);
    this.log("Initialized");
  }

  /**
   * Namespaced console logger for this service.
   * @param {string} text - Message to print.
   * @param {...any} args - Extra values forwarded to console.log.
   */
  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
  }

  /**
   * This function will return the host of the current reranker suite.
   * If the reranker suite is not initialized, it will return the default HF host.
   * @returns {string} The host of the current reranker suite.
   */
  get host() {
    if (!NativeEmbeddingReranker.#transformers) return "https://huggingface.co";
    try {
      return new URL(NativeEmbeddingReranker.#transformers.env.remoteHost).host;
    } catch (e) {
      return this.#fallbackHost;
    }
  }

  /**
   * This function will preload the reranker suite and tokenizer.
   * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such
   * to avoid having to wait for the models to download on the first rerank call.
   * Failures are intentionally swallowed (logged only) so startup is never blocked.
   */
  async preload() {
    try {
      this.log(`Preloading reranker suite...`);
      await this.initClient();
      this.log(
        `Preloaded reranker suite. Reranking is available as a service now.`
      );
      return;
    } catch (e) {
      console.error(e);
      this.log(
        `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
      );
      return;
    }
  }

  /**
   * Lazily imports `@xenova/transformers` and loads the model + tokenizer.
   * Safe to call repeatedly — reuses the static singletons after first init.
   */
  async initClient() {
    if (NativeEmbeddingReranker.#transformers) {
      this.log(`Reranker suite already initialized - reusing.`);
      return;
    }

    await import("@xenova/transformers").then(
      async ({ AutoModelForSequenceClassification, AutoTokenizer, env }) => {
        this.log(`Loading reranker suite...`);
        NativeEmbeddingReranker.#transformers = {
          AutoModelForSequenceClassification,
          AutoTokenizer,
          env,
        };
        // Attempt to load the model and tokenizer in this order:
        // 1. From local file system cache
        // 2. Download and cache from remote host (hf.co)
        // 3. Download and cache from fallback host (cdn.anythingllm.com)
        await this.#getPreTrainedModel();
        await this.#getPreTrainedTokenizer();
      }
    );
    return;
  }

  /**
   * This function will load the model from the local file system cache, or download and cache it from the remote host.
   * If the model is not found in the local file system cache, it will download and cache it from the remote host.
   * If the model is not found in the remote host, it will download and cache it from the fallback host.
   * Recursion terminates because once remoteHost === fallbackHost a failure rethrows.
   * @returns {Promise<any>} The loaded model.
   */
  async #getPreTrainedModel() {
    if (NativeEmbeddingReranker.#model) {
      this.log(`Loading model from singleton...`);
      return NativeEmbeddingReranker.#model;
    }

    try {
      const model =
        await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
          this.model,
          {
            progress_callback: (p) => {
              // Only emit progress noise when the model is not already cached locally.
              if (!this.modelDownloaded && p.status === "progress") {
                this.log(
                  `[${this.host}] Loading model ${this.model}... ${p?.progress}%`
                );
              }
            },
            cache_dir: this.cacheDir,
          }
        );
      this.log(`Loaded model ${this.model}`);
      NativeEmbeddingReranker.#model = model;
      return model;
    } catch (e) {
      this.log(
        `Failed to load model ${this.model} from ${this.host}.`,
        e.message,
        e.stack
      );
      // Already on the fallback host — nothing left to try.
      if (
        NativeEmbeddingReranker.#transformers.env.remoteHost ===
        this.#fallbackHost
      ) {
        this.log(`Failed to load model ${this.model} from fallback host.`);
        throw e;
      }

      this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
      NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost;
      NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/";
      return await this.#getPreTrainedModel();
    }
  }

  /**
   * This function will load the tokenizer from the local file system cache, or download and cache it from the remote host.
   * If the tokenizer is not found in the local file system cache, it will download and cache it from the remote host.
   * If the tokenizer is not found in the remote host, it will download and cache it from the fallback host.
   * Recursion terminates because once remoteHost === fallbackHost a failure rethrows.
   * @returns {Promise<any>} The loaded tokenizer.
   */
  async #getPreTrainedTokenizer() {
    if (NativeEmbeddingReranker.#tokenizer) {
      this.log(`Loading tokenizer from singleton...`);
      return NativeEmbeddingReranker.#tokenizer;
    }

    try {
      const tokenizer =
        await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
          this.model,
          {
            progress_callback: (p) => {
              // Only emit progress noise when the model is not already cached locally.
              if (!this.modelDownloaded && p.status === "progress") {
                this.log(
                  `[${this.host}] Loading tokenizer ${this.model}... ${p?.progress}%`
                );
              }
            },
            cache_dir: this.cacheDir,
          }
        );
      this.log(`Loaded tokenizer ${this.model}`);
      NativeEmbeddingReranker.#tokenizer = tokenizer;
      return tokenizer;
    } catch (e) {
      this.log(
        `Failed to load tokenizer ${this.model} from ${this.host}.`,
        e.message,
        e.stack
      );
      // Already on the fallback host — nothing left to try.
      if (
        NativeEmbeddingReranker.#transformers.env.remoteHost ===
        this.#fallbackHost
      ) {
        this.log(`Failed to load tokenizer ${this.model} from fallback host.`);
        throw e;
      }

      this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
      NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost;
      NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/";
      return await this.#getPreTrainedTokenizer();
    }
  }

  /**
   * Reranks a list of documents based on the query.
   * @param {string} query - The query to rerank the documents against.
   * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
   * @param {Object} options - The options for the reranking.
   * @param {number} options.topK - The number of top documents to return. Defaults to 4.
   * @returns {Promise<any[]>} - The reranked list of documents.
   */
  async rerank(query, documents, options = { topK: 4 }) {
    // FIX: nothing to score — short-circuit instead of tokenizing an empty batch.
    if (!documents?.length) return [];
    // FIX: honor the documented default even when a caller passes an options
    // object without topK (previously `slice(0, undefined)` returned everything).
    const topK = options?.topK ?? 4;

    await this.initClient();
    const model = NativeEmbeddingReranker.#model;
    const tokenizer = NativeEmbeddingReranker.#tokenizer;

    const start = Date.now();
    this.log(`Reranking ${documents.length} documents...`);
    // Score each (query, document) pair as a sequence-classification batch.
    const inputs = tokenizer(new Array(documents.length).fill(query), {
      text_pair: documents.map((doc) => doc.text),
      padding: true,
      truncation: true,
    });
    const { logits } = await model(inputs);
    const reranked = logits
      .sigmoid()
      .tolist()
      .map(([score], i) => ({
        rerank_corpus_id: i,
        rerank_score: score,
        ...documents[i],
      }))
      .sort((a, b) => b.rerank_score - a.rerank_score)
      .slice(0, topK);

    this.log(
      `Reranking ${documents.length} documents to top ${topK} took ${Date.now() - start}ms`
    );
    return reranked;
  }
}
// Public surface of this module: the native reranker service class only.
module.exports = { NativeEmbeddingReranker };
|