Skip to content

Commit

Permalink
feat: add indexing methods (#21)
Browse files Browse the repository at this point in the history
* feat: add indexing methods

* lint

* respond to comments

* Update test_cloudsql_vectorstore_index.py

* fix tests

* Update src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>

* Update src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>

* fix

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
  • Loading branch information
averikitsch and kurtisvg committed Feb 14, 2024
1 parent 89a9a13 commit 8eae440
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 5 deletions.
54 changes: 52 additions & 2 deletions src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,22 @@
from __future__ import annotations

import json
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type
from typing import Any, Iterable, List, Optional, Tuple, Type

import numpy as np
from langchain_community.vectorstores.utils import maximal_marginal_relevance
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from .indexes import DEFAULT_DISTANCE_STRATEGY, DistanceStrategy, QueryOptions
from .indexes import (
DEFAULT_DISTANCE_STRATEGY,
DEFAULT_INDEX_NAME,
BaseIndex,
DistanceStrategy,
ExactNearestNeighbor,
QueryOptions,
)
from .postgresql_engine import PostgreSQLEngine


Expand Down Expand Up @@ -713,3 +720,46 @@ def max_marginal_relevance_search_with_score_by_vector(
**kwargs,
)
return self.engine.run_as_sync(coro)

async def aapply_vector_index(
self,
index: BaseIndex,
name: Optional[str] = None,
concurrently: bool = False,
) -> None:
if isinstance(index, ExactNearestNeighbor):
await self.adrop_vector_index()
return

filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
params = "WITH " + index.index_options()
function = index.distance_strategy.index_function
name = name or index.name
stmt = f"CREATE INDEX {'CONCURRENTLY' if concurrently else ''} {name} ON {self.table_name} USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};"
if concurrently:
await self.engine._aexecute_outside_tx(stmt)
else:
await self.engine._aexecute(stmt)

async def areindex(self, index_name: str = DEFAULT_INDEX_NAME) -> None:
query = f"REINDEX INDEX {index_name};"
await self.engine._aexecute(query)

async def adrop_vector_index(
self,
index_name: str = DEFAULT_INDEX_NAME,
) -> None:
query = f"DROP INDEX IF EXISTS {index_name};"
await self.engine._aexecute(query)

async def is_valid_index(
self,
index_name: str = DEFAULT_INDEX_NAME,
) -> bool:
query = f"""
SELECT tablename, indexname
FROM pg_indexes
WHERE tablename = '{self.table_name}' AND indexname = '{index_name}';
"""
results = await self.engine._afetch(query)
return bool(len(results) == 1)
2 changes: 1 addition & 1 deletion src/langchain_google_cloud_sql_pg/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class BaseIndex(ABC):
distance_strategy: DistanceStrategy = field(
default_factory=lambda: DistanceStrategy.COSINE_DISTANCE
)
partial_indexes: Optional[List] = None
partial_indexes: Optional[List[str]] = None

@abstractmethod
def index_options(self) -> str:
Expand Down
11 changes: 9 additions & 2 deletions src/langchain_google_cloud_sql_pg/postgresql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ async def _get_iam_principal_email(
The email address associated with the current authenticated IAM
principal.
"""
# refresh credentials if they are not valid
if not credentials.valid:
request = google.auth.transport.requests.Request()
credentials.refresh(request)
if hasattr(credentials, "_service_account_email"):
email = credentials._service_account_email
# call OAuth2 api to get IAM principal email associated with OAuth2 token
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
async with aiohttp.ClientSession() as client:
Expand All @@ -65,7 +66,7 @@ async def _get_iam_principal_email(
"Failed to automatically obtain authenticated IAM princpal's "
"email address using environment's ADC credentials!"
)
return email
return email.replace(".gserviceaccount.com", "")


@dataclass
Expand Down Expand Up @@ -158,6 +159,12 @@ async def _aexecute(self, query: str):
await conn.execute(text(query))
await conn.commit()

async def _aexecute_outside_tx(self, query: str):
"""Execute a SQL query."""
async with self._engine.connect() as conn:
await conn.execute(text("COMMIT"))
await conn.execute(text(query))

async def _afetch(self, query: str):
async with self._engine.connect() as conn:
"""Fetch results from a SQL query."""
Expand Down
127 changes: 127 additions & 0 deletions tests/test_cloudsql_vectorstore_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import uuid
from typing import List

import pytest
import pytest_asyncio
from langchain_community.embeddings import DeterministicFakeEmbedding
from langchain_core.documents import Document

from langchain_google_cloud_sql_pg import CloudSQLVectorStore, Column, PostgreSQLEngine
from langchain_google_cloud_sql_pg.indexes import (
DEFAULT_INDEX_NAME,
DistanceStrategy,
HNSWIndex,
IVFFlatIndex,
)

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
VECTOR_SIZE = 768

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)

texts = ["foo", "bar", "baz"]
ids = [str(uuid.uuid4()) for i in range(len(texts))]
metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))]
docs = [
Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
]

embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))]


def get_env_var(key: str, desc: str) -> str:
v = os.environ.get(key)
if v is None:
raise ValueError(f"Must set env var {key} to: {desc}")
return v


@pytest.mark.asyncio(scope="class")
class TestIndex:
@pytest.fixture(scope="module")
def db_project(self) -> str:
return get_env_var("PROJECT_ID", "project id for google cloud")

@pytest.fixture(scope="module")
def db_region(self) -> str:
return get_env_var("REGION", "region for cloud sql instance")

@pytest.fixture(scope="module")
def db_instance(self) -> str:
return get_env_var("INSTANCE_ID", "instance for cloud sql")

@pytest.fixture(scope="module")
def db_name(self) -> str:
return get_env_var("DATABASE_ID", "instance for cloud sql")

@pytest_asyncio.fixture(scope="class")
async def engine(self, db_project, db_region, db_instance, db_name):
engine = await PostgreSQLEngine.afrom_instance(
project_id=db_project,
instance=db_instance,
region=db_region,
database=db_name,
)
yield engine

@pytest_asyncio.fixture(scope="class")
async def vs(self, engine):
await engine.init_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
vs = await CloudSQLVectorStore.create(
engine,
embedding_service=embeddings_service,
table_name=DEFAULT_TABLE,
)

await vs.aadd_texts(texts, ids=ids)
await vs.adrop_vector_index()
yield vs
await engine._aexecute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
await engine._engine.dispose()

async def test_aapply_vector_index(self, vs):
index = HNSWIndex()
await vs.aapply_vector_index(index)
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)

async def test_areindex(self, vs):
if not await vs.is_valid_index(DEFAULT_INDEX_NAME):
index = HNSWIndex()
await vs.aapply_vector_index(index)
await vs.areindex()
await vs.areindex(DEFAULT_INDEX_NAME)
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)

async def test_dropindex(self, vs):
await vs.adrop_vector_index()
result = await vs.is_valid_index(DEFAULT_INDEX_NAME)
assert not result

async def test_aapply_vector_index_ivfflat(self, vs):
index = IVFFlatIndex(distance_strategy=DistanceStrategy.EUCLIDEAN)
await vs.aapply_vector_index(index, concurrently=True)
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
index = IVFFlatIndex(
name="secondindex",
distance_strategy=DistanceStrategy.INNER_PRODUCT,
)
await vs.aapply_vector_index(index)
assert await vs.is_valid_index("secondindex")
await vs.adrop_vector_index("secondindex")

0 comments on commit 8eae440

Please sign in to comment.