Skip to content

Commit

Permalink
feat: Add CloudSQL Postgresql chatmessagehistory with Integration Tes…
Browse files Browse the repository at this point in the history
…ts (#23)

* CloudSQL Postgresql chatmessagehistory with Integration Tests

* chore: fix requested changes

* Update test_postgresql_chatmessagehistory.py

* Update postgresql_chat_message_history.py

* Update postgresql_engine.py

* Update test_postgresql_chatmessagehistory.py

* clean up

* lint

* Update test_postgresql_chatmessagehistory.py

* Update test_postgresql_chatmessagehistory.py

---------

Co-authored-by: Averi Kitsch <akitsch@google.com>
  • Loading branch information
yashdeepkumar-searce and averikitsch committed Feb 15, 2024
1 parent a1e9cb6 commit 3ab9d4e
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 5 deletions.
10 changes: 9 additions & 1 deletion src/langchain_google_cloud_sql_pg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
# limitations under the License.

from langchain_google_cloud_sql_pg.cloudsql_vectorstore import CloudSQLVectorStore
from langchain_google_cloud_sql_pg.postgresql_chat_message_history import (
PostgreSQLChatMessageHistory,
)
from langchain_google_cloud_sql_pg.postgresql_engine import Column, PostgreSQLEngine

__all__ = ["PostgreSQLEngine", "Column", "CloudSQLVectorStore"]
__all__ = [
"PostgreSQLEngine",
"Column",
"CloudSQLVectorStore",
"PostgreSQLChatMessageHistory",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
from typing import List

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict

from langchain_google_cloud_sql_pg.postgresql_engine import PostgreSQLEngine


class PostgreSQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Postgres database."""

def __init__(self, engine: PostgreSQLEngine, session_id: str, table_name: str):
self.engine = engine
self.session_id = session_id
self.table_name = table_name

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from PostgreSQL"""
query = f"""SELECT data, type FROM "{self.table_name}" WHERE session_id = :session_id ORDER BY id;"""
results = self.engine.run_as_sync(
self.engine._afetch(query, {"session_id": self.session_id})
)
if not results:
return []

items = [{"data": result["data"], "type": result["type"]} for result in results]
messages = messages_from_dict(items)
return messages

async def aadd_message(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL"""
query = f"""INSERT INTO "{self.table_name}"(session_id, data, type)
VALUES (:session_id, :data, :type);
"""
await self.engine._aexecute(
query,
{
"session_id": self.session_id,
"data": json.dumps(message.dict()),
"type": message.type,
},
)

def add_message(self, message: BaseMessage) -> None:
self.engine.run_as_sync(self.aadd_message(message))

async def aclear(self) -> None:
"""Clear session memory from PostgreSQL"""
query = f"""DELETE FROM "{self.table_name}" WHERE session_id = :session_id;"""
await self.engine._aexecute(query, {"session_id": self.session_id})

def clear(self) -> None:
self.engine.run_as_sync(self.aclear())
17 changes: 13 additions & 4 deletions src/langchain_google_cloud_sql_pg/postgresql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ async def afrom_instance(
password,
)

async def _aexecute(self, query: str):
async def _aexecute(self, query: str, params: Optional[dict] = None):
"""Execute a SQL query."""
async with self._engine.connect() as conn:
await conn.execute(text(query))
await conn.execute(text(query), params)
await conn.commit()

async def _aexecute_outside_tx(self, query: str):
Expand All @@ -200,10 +200,10 @@ async def _aexecute_outside_tx(self, query: str):
await conn.execute(text("COMMIT"))
await conn.execute(text(query))

async def _afetch(self, query: str):
async def _afetch(self, query: str, params: Optional[dict] = None):
async with self._engine.connect() as conn:
"""Fetch results from a SQL query."""
result = await conn.execute(text(query))
result = await conn.execute(text(query), params)
result_map = result.mappings()
result_fetch = result_map.fetchall()

Expand Down Expand Up @@ -244,3 +244,12 @@ async def init_vectorstore_table(
query += "\n);"

await self._aexecute(query)

async def init_chat_history_table(self, table_name) -> None:
create_table_query = f"""CREATE TABLE IF NOT EXISTS "{table_name}"(
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
data JSONB NOT NULL,
type TEXT NOT NULL
);"""
await self._aexecute(create_table_query)
61 changes: 61 additions & 0 deletions tests/test_postgresql_chatmessagehistory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid
from typing import Generator

import pytest
import sqlalchemy
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage

from langchain_google_cloud_sql_pg import PostgreSQLChatMessageHistory, PostgreSQLEngine

project_id = os.environ["PROJECT_ID"]
region = os.environ["REGION"]
instance_id = os.environ["INSTANCE_ID"]
db_name = os.environ["DATABASE_ID"]
table_name = "message_store_test" + str(uuid.uuid4())

@pytest.fixture(name="memory_engine")
def setup() -> Generator:
engine = PostgreSQLEngine.from_instance(
project_id=project_id,
region=region,
instance=instance_id,
database=db_name,
)
engine.run_as_sync(engine.init_chat_history_table(table_name=table_name))
yield engine


def test_chat_message_history(memory_engine: PostgreSQLEngine) -> None:
history = PostgreSQLChatMessageHistory(
engine=memory_engine, session_id="test", table_name=table_name
)
history.add_user_message("hi!")
history.add_ai_message("whats up?")
messages = history.messages

# verify messages are correct
assert messages[0].content == "hi!"
assert type(messages[0]) is HumanMessage
assert messages[1].content == "whats up?"
assert type(messages[1]) is AIMessage

# verify clear() clears message history
history.clear()
assert len(history.messages) == 0
memory_engine.run_as_sync(memory_engine._aexecute(f'DROP TABLE IF EXISTS "{table_name}"'))

0 comments on commit 3ab9d4e

Please sign in to comment.