Skip to content

Commit

Permalink
Do not go through the LLM provider when embedding documents (#1428)
Browse files Browse the repository at this point in the history
  • Loading branch information
timothycarambat committed May 17, 2024
1 parent 01cf2fe commit cae6cee
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 66 deletions.
4 changes: 2 additions & 2 deletions server/utils/helpers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ function getLLMProvider({ provider = null, model = null } = {}) {
}

function getEmbeddingEngineSelection() {
const { NativeEmbedder } = require("../EmbeddingEngines/native");
const engineSelection = process.env.EMBEDDING_ENGINE;
switch (engineSelection) {
case "openai":
Expand All @@ -117,7 +118,6 @@ function getEmbeddingEngineSelection() {
const { OllamaEmbedder } = require("../EmbeddingEngines/ollama");
return new OllamaEmbedder();
case "native":
const { NativeEmbedder } = require("../EmbeddingEngines/native");
return new NativeEmbedder();
case "lmstudio":
const { LMStudioEmbedder } = require("../EmbeddingEngines/lmstudio");
Expand All @@ -126,7 +126,7 @@ function getEmbeddingEngineSelection() {
const { CohereEmbedder } = require("../EmbeddingEngines/cohere");
return new CohereEmbedder();
default:
return null;
return new NativeEmbedder();
}
}

Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/astra/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");

const AstraDB = {
Expand Down Expand Up @@ -149,12 +145,13 @@ const AstraDB = {
return { vectorized: true, error: null };
}

const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -164,10 +161,9 @@ const AstraDB = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);

if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/chroma/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { parseAuthHeader } = require("../../http");
const { sourceIdentifier } = require("../../chats");

Expand Down Expand Up @@ -192,12 +188,13 @@ const Chroma = {
// We have to do this manually as opposed to using LangChains `Chroma.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -207,10 +204,9 @@ const Chroma = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);
const submission = {
ids: [],
embeddings: [],
Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/lance/index.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
const lancedb = require("vectordb");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
Expand Down Expand Up @@ -190,12 +186,13 @@ const LanceDb = {
// We have to do this manually as opposed to using LangChains `xyz.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -205,11 +202,10 @@ const LanceDb = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const submissions = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);

if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/milvus/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");

const Milvus = {
Expand Down Expand Up @@ -184,12 +180,13 @@ const Milvus = {
return { vectorized: true, error: null };
}

const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -199,10 +196,9 @@ const Milvus = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);

if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/pinecone/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");

const PineconeDB = {
Expand Down Expand Up @@ -135,12 +131,13 @@ const PineconeDB = {
// because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L167
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -150,10 +147,9 @@ const PineconeDB = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);

if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/qdrant/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");

const QDrant = {
Expand Down Expand Up @@ -209,12 +205,13 @@ const QDrant = {
// We have to do this manually as opposed to using LangChains `Qdrant.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -224,10 +221,9 @@ const QDrant = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);
const submission = {
ids: [],
vectors: [],
Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/weaviate/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const { v4: uuidv4 } = require("uuid");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { camelCase } = require("../../helpers/camelcase");
const { sourceIdentifier } = require("../../chats");

Expand Down Expand Up @@ -251,12 +247,13 @@ const Weaviate = {
// We have to do this manually as opposed to using LangChains `Chroma.fromDocuments`
// because we then cannot atomically control our namespace to granularly find/remove documents
// from vectordb.
const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -266,10 +263,9 @@ const Weaviate = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);
const submission = {
ids: [],
vectors: [],
Expand Down
12 changes: 4 additions & 8 deletions server/utils/vectorDbProviders/zilliz/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ const { TextSplitter } = require("../../TextSplitter");
const { SystemSettings } = require("../../../models/systemSettings");
const { v4: uuidv4 } = require("uuid");
const { storeVectorResult, cachedVectorInformation } = require("../../files");
const {
toChunks,
getLLMProvider,
getEmbeddingEngineSelection,
} = require("../../helpers");
const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
const { sourceIdentifier } = require("../../chats");

// Zilliz is basically a copy of Milvus DB class with a different constructor
Expand Down Expand Up @@ -185,12 +181,13 @@ const Zilliz = {
return { vectorized: true, error: null };
}

const EmbedderEngine = getEmbeddingEngineSelection();
const textSplitter = new TextSplitter({
chunkSize: TextSplitter.determineMaxChunkSize(
await SystemSettings.getValueOrFallback({
label: "text_splitter_chunk_size",
}),
getEmbeddingEngineSelection()?.embeddingMaxChunkLength
EmbedderEngine?.embeddingMaxChunkLength
),
chunkOverlap: await SystemSettings.getValueOrFallback(
{ label: "text_splitter_chunk_overlap" },
Expand All @@ -200,10 +197,9 @@ const Zilliz = {
const textChunks = await textSplitter.splitText(pageContent);

console.log("Chunks created from document:", textChunks.length);
const LLMConnector = getLLMProvider();
const documentVectors = [];
const vectors = [];
const vectorValues = await LLMConnector.embedChunks(textChunks);
const vectorValues = await EmbedderEngine.embedChunks(textChunks);

if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
Expand Down

0 comments on commit cae6cee

Please sign in to comment.