Skip to content

Commit

Permalink
fix: missing quote in table name (#123)
Browse files Browse the repository at this point in the history
* fix: missing quote in table name

* fix quote

* fix tests

* Update vectorstore.py

* Update vectorstore.py

* Update vectorstore.py
  • Loading branch information
averikitsch committed Apr 30, 2024
1 parent 008a995 commit b490c81
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/langchain_google_cloud_sql_pg/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ async def __query_collection(
search_function = self.distance_strategy.search_function

filter = f"WHERE {filter}" if filter else ""
stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM {self.table_name} {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
stmt = f"SELECT *, {search_function}({self.embedding_column}, '{embedding}') as distance FROM \"{self.table_name}\" {filter} ORDER BY {self.embedding_column} {operator} '{embedding}' LIMIT {k};"
if self.index_query_options:
await self.engine._aexecute(
f"SET LOCAL {self.index_query_options.to_string()};"
Expand Down Expand Up @@ -742,7 +742,7 @@ async def aapply_vector_index(
params = "WITH " + index.index_options()
function = index.distance_strategy.index_function
name = name or index.name
stmt = f"CREATE INDEX {'CONCURRENTLY' if concurrently else ''} {name} ON {self.table_name} USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};"
stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {name} ON "{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};'
if concurrently:
await self.engine._aexecute_outside_tx(stmt)
else:
Expand Down
34 changes: 17 additions & 17 deletions tests/test_cloudsql_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore

DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
DEFAULT_TABLE_SYNC = "test_table_sync" + str(uuid.uuid4())
CUSTOM_TABLE = "test-table-custom" + str(uuid.uuid4())
VECTOR_SIZE = 768

Expand Down Expand Up @@ -94,7 +94,7 @@ def vs_sync(self, engine_sync):
table_name=DEFAULT_TABLE_SYNC,
)
yield vs
engine_sync._execute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}")
engine_sync._execute(f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"')
engine_sync._engine.dispose()

@pytest_asyncio.fixture(scope="class")
Expand All @@ -106,7 +106,7 @@ async def vs(self, engine):
table_name=DEFAULT_TABLE,
)
yield vs
await engine._aexecute(f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
await engine._aexecute(f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
await engine._engine.dispose()

@pytest_asyncio.fixture(scope="class")
Expand Down Expand Up @@ -149,45 +149,45 @@ async def test_post_init(self, engine):
async def test_aadd_texts(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3

ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, metadatas, ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 6
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_aadd_texts_edge_cases(self, engine, vs):
texts = ["Taylor's", '"Swift"', "best-friend"]
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_aadd_docs(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_documents(docs, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_aadd_embedding(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs._aadd_embeddings(texts, embeddings, metadatas, ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
await engine._aexecute(f"TRUNCATE TABLE {DEFAULT_TABLE}")
await engine._aexecute(f'TRUNCATE TABLE "{DEFAULT_TABLE}"')

async def test_adelete(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_texts(texts, ids=ids)
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 3
# delete an ID
await vs.adelete([ids[0]])
results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
results = await engine._afetch(f'SELECT * FROM "{DEFAULT_TABLE}"')
assert len(results) == 2

async def test_aadd_texts_custom(self, engine, vs_custom):
Expand Down Expand Up @@ -249,13 +249,13 @@ async def test_adelete_custom(self, engine, vs_custom):
async def test_add_docs(self, engine_sync, vs_sync):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
vs_sync.add_documents(docs, ids=ids)
results = engine_sync._fetch(f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
results = engine_sync._fetch(f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
assert len(results) == 3

async def test_add_texts(self, engine_sync, vs_sync):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
vs_sync.add_texts(texts, ids=ids)
results = engine_sync._fetch(f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
results = engine_sync._fetch(f'SELECT * FROM "{DEFAULT_TABLE_SYNC}"')
assert len(results) == 6

# Need tests for store metadata=False

0 comments on commit b490c81

Please sign in to comment.