Skip to content

Commit

Permalink
feat: add Vector Store (#14)
Browse files Browse the repository at this point in the history
* feat: add PostgreSQLEngine

* lint

* add header

* Update src/langchain_google_cloud_sql_pg/postgresql_engine.py

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>

* Update src/langchain_google_cloud_sql_pg/postgresql_engine.py

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>

* Update src/langchain_google_cloud_sql_pg/postgresql_engine.py

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>

* Update src/langchain_google_cloud_sql_pg/postgresql_engine.py

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>

* Update src/langchain_google_cloud_sql_pg/postgresql_engine.py

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>

* add json column

* Update comments

* fix

* Update pyproject.toml

* clean up

* feat: add Vector Store

* Update pyproject.toml

* add tests

* lint

* update

* add delete

* respond to comments

* Update cloudsql_vectorstore.py

* test

* remove asyncio

* clean up

* sleep

* sync vs

* fix

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
  • Loading branch information
averikitsch and kurtisvg committed Feb 13, 2024
1 parent b181f65 commit f3e1127
Show file tree
Hide file tree
Showing 6 changed files with 563 additions and 11 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">=3.8"
dependencies = [
"langchain==0.1.1",
"SQLAlchemy>=2.0.25",
"cloud-sql-python-connector[asyncpg]>=1.6.0",
"pgvector>=0.2.5"
"langchain-core>=0.1.1",
"pgvector>=0.2.5",
"SQLAlchemy>=2.0.25"
]

[project.urls]
Expand All @@ -23,6 +23,7 @@ test = [
"black==23.12.0",
"black[jupyter]==23.12.0",
"isort==5.13.2",
"langchain-community>=0.0.18",
"mypy==1.7.1",
"pytest-asyncio==0.23.0",
"pytest==7.4.4"
Expand Down
3 changes: 2 additions & 1 deletion src/langchain_google_cloud_sql_pg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain_google_cloud_sql_pg.cloudsql_vectorstore import CloudSQLVectorStore
from langchain_google_cloud_sql_pg.postgresql_engine import Column, PostgreSQLEngine

__all__ = ["PostgreSQLEngine", "Column"]
__all__ = ["PostgreSQLEngine", "Column", "CloudSQLVectorStore"]
292 changes: 292 additions & 0 deletions src/langchain_google_cloud_sql_pg/cloudsql_vectorstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

import json
import uuid
from typing import Any, Awaitable, Iterable, List, Optional

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from .postgresql_engine import PostgreSQLEngine


class CloudSQLVectorStore(VectorStore):
    """Google Cloud SQL for PostgreSQL Vector Store class.

    Instances must be created through the async ``create`` or sync
    ``create_sync`` factory methods, which validate the backing table's
    schema before construction.
    """

    # Private sentinel gating construction: only the factory methods hold a
    # reference to it, so direct __init__ calls are rejected.
    __create_key = object()

    def __init__(
        self,
        key,
        engine: PostgreSQLEngine,
        embedding_service: Embeddings,
        table_name: str,
        content_column: str = "content",
        embedding_column: str = "embedding",
        metadata_columns: Optional[List[str]] = None,
        id_column: str = "langchain_id",
        metadata_json_column: str = "langchain_metadata",
        store_metadata: bool = True,
    ):
        """Private constructor; use ``create`` or ``create_sync`` instead.

        Args:
            key: Private sentinel proving the call came from a factory method.
            engine (PostgreSQLEngine): Engine with pool connection to the
                postgres database.
            embedding_service (Embeddings): Text embedding model to use.
            table_name (str): Name of the existing table backing the store.
            content_column (str): Column holding Document.page_content.
            embedding_column (str): Column holding embedding vectors.
            metadata_columns (List[str]): Columns holding document metadata.
            id_column (str): Column holding the Document's id.
            metadata_json_column (str): Column storing extra metadata as JSON.
            store_metadata (bool): Whether to write the JSON metadata column.
        """
        if key != CloudSQLVectorStore.__create_key:
            raise Exception(
                "Only create class through 'create' or 'create_sync' methods!"
            )

        self.engine = engine
        self.embedding_service = embedding_service
        self.table_name = table_name
        self.content_column = content_column
        self.embedding_column = embedding_column
        # Normalize None to a fresh list to avoid a shared mutable default.
        self.metadata_columns = (
            metadata_columns if metadata_columns is not None else []
        )
        self.id_column = id_column
        self.metadata_json_column = metadata_json_column
        self.store_metadata = store_metadata

    @classmethod
    async def create(
        cls,
        engine: PostgreSQLEngine,
        embedding_service: Embeddings,
        table_name: str,
        content_column: str = "content",
        embedding_column: str = "embedding",
        metadata_columns: Optional[List[str]] = None,
        ignore_metadata_columns: Optional[List[str]] = None,
        id_column: str = "langchain_id",
        metadata_json_column: str = "langchain_metadata",
    ):
        """Constructor for CloudSQLVectorStore.

        Args:
            engine (PostgreSQLEngine): AsyncEngine with pool connection to the
                postgres database. Required.
            embedding_service (Embeddings): Text embedding model to use.
            table_name (str): Name of the existing table or the table to be created.
            content_column (str): Column that represent a Document's page_content.
                Defaults to "content".
            embedding_column (str): Column for embedding vectors.
                The embedding is generated from the document value. Defaults to "embedding".
            metadata_columns (List[str]): Column(s) that represent a document's metadata.
            ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing
                tables for a document's metadata. Can not be used with
                metadata_columns. Defaults to None.
            id_column (str): Column that represents the Document's id.
                Defaults to "langchain_id".
            metadata_json_column (str): Column to store metadata as JSON.
                Defaults to "langchain_metadata".

        Raises:
            ValueError: If both metadata_columns and ignore_metadata_columns are
                given, or if a required column is missing or has the wrong type.
        """
        if metadata_columns and ignore_metadata_columns:
            raise ValueError(
                "Can not use both metadata_columns and ignore_metadata_columns."
            )
        # Normalize None to a fresh list (mutable-default fix).
        metadata_columns = metadata_columns if metadata_columns is not None else []

        # Get field type information.
        # NOTE(review): table_name is interpolated directly into SQL; it must
        # come from trusted configuration, never untrusted user input.
        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'"
        results = await engine._afetch(stmt)
        columns = {field["column_name"]: field["data_type"] for field in results}

        # Check that required columns exist and have usable types.
        if id_column not in columns:
            raise ValueError(f"Id column, {id_column}, does not exist.")
        if content_column not in columns:
            raise ValueError(f"Content column, {content_column}, does not exist.")
        content_type = columns[content_column]
        if content_type != "text" and "char" not in content_type:
            raise ValueError(
                f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
            )
        if embedding_column not in columns:
            raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
        if columns[embedding_column] != "USER-DEFINED":
            raise ValueError(
                f"Embedding column, {embedding_column}, is not type Vector."
            )
        # Only store JSON metadata when the JSON column actually exists.
        # (Fixed: store_metadata was previously left undefined on the
        # missing-column path, raising UnboundLocalError at cls(...) below.)
        store_metadata = metadata_json_column in columns

        # If using metadata_columns check to make sure column exists.
        for column in metadata_columns:
            if column not in columns:
                raise ValueError(f"Metadata column, {column}, does not exist.")

        # If using ignore_metadata_columns, every remaining table column other
        # than the ignored and reserved ones becomes a metadata column.
        if ignore_metadata_columns:
            all_columns = columns
            for column in ignore_metadata_columns:
                del all_columns[column]

            del all_columns[id_column]
            del all_columns[content_column]
            del all_columns[embedding_column]
            # Fixed: `[k for k, _ in all_columns.keys()]` tried to unpack each
            # str key into two names and raised ValueError at runtime.
            metadata_columns = list(all_columns.keys())

        return cls(
            cls.__create_key,
            engine,
            embedding_service,
            table_name,
            content_column,
            embedding_column,
            metadata_columns,
            id_column,
            metadata_json_column,
            store_metadata,
        )

    @classmethod
    def create_sync(
        cls,
        engine: PostgreSQLEngine,
        embedding_service: Embeddings,
        table_name: str,
        content_column: str = "content",
        embedding_column: str = "embedding",
        metadata_columns: Optional[List[str]] = None,
        ignore_metadata_columns: Optional[List[str]] = None,
        id_column: str = "langchain_id",
        metadata_json_column: str = "langchain_metadata",
    ):
        """Synchronous wrapper around ``create``; see ``create`` for args."""
        coro = cls.create(
            engine,
            embedding_service,
            table_name,
            content_column,
            embedding_column,
            metadata_columns,
            ignore_metadata_columns,
            id_column,
            metadata_json_column,
        )
        # Runs the coroutine on the engine's background event loop.
        return engine.run_as_sync(coro)

    @property
    def embeddings(self) -> Embeddings:
        """The embedding model used to vectorize documents."""
        return self.embedding_service

    async def _aadd_embeddings(
        self,
        texts: Iterable[str],
        embeddings: List[List[float]],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Insert rows of (id, text, embedding, metadata) into the table.

        Returns:
            The list of ids inserted (generated UUIDs when not supplied).
        """
        if not ids:
            # Fixed: previously every row received the literal string "NULL",
            # producing duplicate ids; generate a unique id per row instead.
            ids = [str(uuid.uuid4()) for _ in texts]
        if not metadatas:
            metadatas = [{} for _ in texts]
        # NOTE(review): values are interpolated directly into SQL below; this
        # is injection-prone for untrusted content and should move to
        # parameterized queries when the engine API supports them.
        for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas):
            metadata_col_names = (
                ", " + ", ".join(self.metadata_columns)
                if len(self.metadata_columns) > 0
                else ""
            )
            insert_stmt = f"INSERT INTO {self.table_name}({self.id_column}, {self.content_column}, {self.embedding_column}{metadata_col_names}"
            values_stmt = f" VALUES ('{id}','{content}','{embedding}'"
            # Known metadata columns get their own values; everything left
            # over ("extra") is stored in the JSON column when enabled.
            extra = metadata
            for metadata_column in self.metadata_columns:
                if metadata_column in metadata:
                    values_stmt += f",'{metadata[metadata_column]}'"
                    del extra[metadata_column]
                else:
                    values_stmt += ",null"

            insert_stmt += (
                f", {self.metadata_json_column})" if self.store_metadata else ")"
            )
            values_stmt += f",'{json.dumps(extra)}')" if self.store_metadata else ")"
            query = insert_stmt + values_stmt
            await self.engine._aexecute(query)

        return ids

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed texts and insert them into the vector store.

        Returns:
            The list of ids of the inserted rows.
        """
        # NOTE(review): embed_documents is a synchronous call inside an async
        # method; it will block the event loop for large batches — consider
        # the async embedding API if the model supports it.
        embeddings = self.embedding_service.embed_documents(list(texts))
        ids = await self._aadd_embeddings(
            texts, embeddings, metadatas=metadatas, ids=ids, **kwargs
        )
        return ids

    async def aadd_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed documents and insert them into the vector store."""
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
        return ids

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Synchronous wrapper around ``aadd_texts``."""
        return self.engine.run_as_sync(self.aadd_texts(texts, metadatas, ids, **kwargs))

    def add_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Synchronous wrapper around ``aadd_documents``."""
        return self.engine.run_as_sync(self.aadd_documents(documents, ids, **kwargs))

    async def adelete(
        self,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Delete rows by id.

        Returns:
            False when no ids are given; True after issuing the delete.
        """
        if not ids:
            return False

        # NOTE(review): ids are interpolated directly into SQL; see the
        # injection note in _aadd_embeddings.
        id_list = ", ".join([f"'{id}'" for id in ids])
        query = f"DELETE FROM {self.table_name} WHERE {self.id_column} in ({id_list})"
        await self.engine._aexecute(query)
        return True

    def delete(
        self,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Optional[bool]:
        """Synchronous wrapper around ``adelete``."""
        return self.engine.run_as_sync(self.adelete(ids, **kwargs))

    @classmethod
    def from_texts(cls):
        """Not implemented; satisfies the abstract VectorStore interface."""
        pass

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Placeholder: similarity search is not implemented yet."""
        return []
16 changes: 9 additions & 7 deletions src/langchain_google_cloud_sql_pg/postgresql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ def from_instance(
database: str,
) -> PostgreSQLEngine:
# Running a loop in a background thread allows us to support
# async methods from non-async enviroments
# async methods from non-async environments
loop = asyncio.new_event_loop()
thread = Thread(target=loop.run_forever, daemon=True)
thread.start()
coro = cls.afrom_instance(project_id, region, instance, database)
coro = cls._create(
project_id, region, instance, database, loop=loop, thread=thread
)
return asyncio.run_coroutine_threadsafe(coro, loop).result()

@classmethod
Expand All @@ -124,8 +126,8 @@ async def _create(
cls._connector = await create_async_connector()

# anonymous function to be used for SQLAlchemy 'creator' argument
def getconn() -> asyncpg.Connection:
conn = cls._connector.connect_async( # type: ignore
async def getconn() -> asyncpg.Connection:
conn = await cls._connector.connect_async( # type: ignore
f"{project_id}:{region}:{instance}",
"asyncpg",
user=iam_database_user,
Expand Down Expand Up @@ -165,7 +167,7 @@ async def _afetch(self, query: str):

return result_fetch

def run_as_sync(self, coro: Awaitable[T]): # TODO: add return type
def run_as_sync(self, coro: Awaitable[T]) -> T:
if not self._loop:
raise Exception("Engine was initialized async.")
return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
Expand All @@ -177,7 +179,7 @@ async def init_vectorstore_table(
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: List[Column] = [],
metadata_json_columns: str = "langchain_metadata",
metadata_json_column: str = "langchain_metadata",
id_column: str = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
Expand All @@ -196,7 +198,7 @@ async def init_vectorstore_table(
"NOT NULL" if not column.nullable else ""
)
if store_metadata:
query += f",\n{metadata_json_columns} JSON"
query += f",\n{metadata_json_column} JSON"
query += "\n);"

await self._aexecute(query)
Loading

0 comments on commit f3e1127

Please sign in to comment.