-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add CloudSQL Postgresql chatmessagehistory with Integration Tes…
…ts (#23) * CloudSQL Postgresql chatmessagehistory with Integration Tests * chore: fix requested changes * Update test_postgresql_chatmessagehistory.py * Update postgresql_chat_message_history.py * Update postgresql_engine.py * Update test_postgresql_chatmessagehistory.py * clean up * lint * Update test_postgresql_chatmessagehistory.py * Update test_postgresql_chatmessagehistory.py --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
- Loading branch information
1 parent
a1e9cb6
commit 3ab9d4e
Showing
4 changed files
with
154 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
src/langchain_google_cloud_sql_pg/postgresql_chat_message_history.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import json | ||
from typing import List | ||
|
||
from langchain_core.chat_history import BaseChatMessageHistory | ||
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict | ||
|
||
from langchain_google_cloud_sql_pg.postgresql_engine import PostgreSQLEngine | ||
|
||
|
||
class PostgreSQLChatMessageHistory(BaseChatMessageHistory): | ||
"""Chat message history stored in a Postgres database.""" | ||
|
||
def __init__(self, engine: PostgreSQLEngine, session_id: str, table_name: str): | ||
self.engine = engine | ||
self.session_id = session_id | ||
self.table_name = table_name | ||
|
||
@property | ||
def messages(self) -> List[BaseMessage]: # type: ignore | ||
"""Retrieve the messages from PostgreSQL""" | ||
query = f"""SELECT data, type FROM "{self.table_name}" WHERE session_id = :session_id ORDER BY id;""" | ||
results = self.engine.run_as_sync( | ||
self.engine._afetch(query, {"session_id": self.session_id}) | ||
) | ||
if not results: | ||
return [] | ||
|
||
items = [{"data": result["data"], "type": result["type"]} for result in results] | ||
messages = messages_from_dict(items) | ||
return messages | ||
|
||
async def aadd_message(self, message: BaseMessage) -> None: | ||
"""Append the message to the record in PostgreSQL""" | ||
query = f"""INSERT INTO "{self.table_name}"(session_id, data, type) | ||
VALUES (:session_id, :data, :type); | ||
""" | ||
await self.engine._aexecute( | ||
query, | ||
{ | ||
"session_id": self.session_id, | ||
"data": json.dumps(message.dict()), | ||
"type": message.type, | ||
}, | ||
) | ||
|
||
def add_message(self, message: BaseMessage) -> None: | ||
self.engine.run_as_sync(self.aadd_message(message)) | ||
|
||
async def aclear(self) -> None: | ||
"""Clear session memory from PostgreSQL""" | ||
query = f"""DELETE FROM "{self.table_name}" WHERE session_id = :session_id;""" | ||
await self.engine._aexecute(query, {"session_id": self.session_id}) | ||
|
||
def clear(self) -> None: | ||
self.engine.run_as_sync(self.aclear()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import uuid | ||
from typing import Generator | ||
|
||
import pytest | ||
import sqlalchemy | ||
from langchain_core.messages.ai import AIMessage | ||
from langchain_core.messages.human import HumanMessage | ||
|
||
from langchain_google_cloud_sql_pg import PostgreSQLChatMessageHistory, PostgreSQLEngine | ||
|
||
project_id = os.environ["PROJECT_ID"] | ||
region = os.environ["REGION"] | ||
instance_id = os.environ["INSTANCE_ID"] | ||
db_name = os.environ["DATABASE_ID"] | ||
table_name = "message_store_test" + str(uuid.uuid4()) | ||
|
||
@pytest.fixture(name="memory_engine") | ||
def setup() -> Generator: | ||
engine = PostgreSQLEngine.from_instance( | ||
project_id=project_id, | ||
region=region, | ||
instance=instance_id, | ||
database=db_name, | ||
) | ||
engine.run_as_sync(engine.init_chat_history_table(table_name=table_name)) | ||
yield engine | ||
|
||
|
||
def test_chat_message_history(memory_engine: PostgreSQLEngine) -> None: | ||
history = PostgreSQLChatMessageHistory( | ||
engine=memory_engine, session_id="test", table_name=table_name | ||
) | ||
history.add_user_message("hi!") | ||
history.add_ai_message("whats up?") | ||
messages = history.messages | ||
|
||
# verify messages are correct | ||
assert messages[0].content == "hi!" | ||
assert type(messages[0]) is HumanMessage | ||
assert messages[1].content == "whats up?" | ||
assert type(messages[1]) is AIMessage | ||
|
||
# verify clear() clears message history | ||
history.clear() | ||
assert len(history.messages) == 0 | ||
memory_engine.run_as_sync(memory_engine._aexecute(f'DROP TABLE IF EXISTS "{table_name}"')) |