KIMB-technologies aae167cf11
All checks were successful
Build and push Docker images on git tags / build (push) Successful in 40m1s
Fix DB issues with Python 3.12
2024-11-20 12:03:38 +01:00

195 lines
5.0 KiB
Python

# Agenten Plattform
#
# (c) 2024 Magnus Bender
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
# Universitaet Hamburg
# https://www.chai.uni-hamburg.de/~bender
#
# source code released under the terms of GNU Public License Version 3
# https://www.gnu.org/licenses/gpl-3.0.txt
import atexit
import os
import sqlite3
from datetime import datetime, timezone
from threading import Lock
from typing import Generator

from pydantic import validate_call, ValidationError

from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow
class DB():
	"""
	Store of `AgentMessage`s in the SQLite database `messages.db` inside `PERSIST_PATH`.

	One connection is shared between threads (`check_same_thread=False`).
	Every use of it is serialized by `self.db_lock`, and every statement runs
	inside `with self.db:`, which commits on success and rolls back on error.
	"""

	# Layout of the `time` column as written by SQLite's CURRENT_TIMESTAMP,
	# which SQLite always produces in UTC.
	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

	def __init__(self):
		self.db = sqlite3.connect(
			os.path.join(PERSIST_PATH, 'messages.db'),
			check_same_thread=False,
			# explicit transactions (Python >= 3.12); `with self.db:` commits/rolls back
			autocommit=False
		)
		self.db.row_factory = sqlite3.Row
		atexit.register(self.db.close)

		self.db_lock = Lock()
		self._assure_tables()

	def _assure_tables(self):
		"""Create the `Messages` table if it does not exist yet."""
		# `with` on the lock guarantees it is released even if SQLite raises.
		with self.db_lock, self.db:
			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
				count INTEGER PRIMARY KEY AUTOINCREMENT,
				id TEXT,
				sender TEXT,
				recipient TEXT,
				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
				json BLOB,
				processed BOOL DEFAULT FALSE,
				solution BOOL DEFAULT NULL
			)""")

	@validate_call
	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
		"""
		Insert `message` (serialized as JSON) sent from `sender` to `recipient`.

		The `time` column is filled by SQLite (CURRENT_TIMESTAMP, UTC).

		Returns:
			The `count` (row id) of the newly inserted row.
		"""
		with self.db_lock, self.db:
			cursor = self.db.execute(
				"""INSERT INTO Messages (
						id, sender, recipient, json, processed
					) VALUES (
						:id, :sender, :recipient, :json, :processed
					)""", {
					"id" : message.id,
					"sender" : sender,
					"recipient" : recipient,
					"json" : message.model_dump_json(),
					"processed" : processed
				})
			# row id of the INSERT just executed on this cursor
			new_count = cursor.lastrowid
		return new_count

	def _execute_update(self, sql:str, params:tuple) -> bool:
		"""
		Run one write statement under the lock, inside a transaction.

		Returns:
			False if SQLite raised an error (the transaction is rolled back).
			Also False if binding `params` failed, e.g. an int too large for
			SQLite raises OverflowError. Otherwise True, which does NOT imply
			that any row matched.
		"""
		try:
			with self.db_lock, self.db:
				self.db.execute(sql, params)
		except (sqlite3.Error, OverflowError):
			return False
		return True

	@validate_call
	def set_processed(self, count:int, processed:bool=True) -> bool:
		"""
		Set the `processed` flag of the message with the given `count`.

		Returns False on a database error, True otherwise (also if no such row exists).
		"""
		return self._execute_update(
			"UPDATE Messages SET processed = ? WHERE count = ?", (processed, count)
		)

	@validate_call
	def set_solution(self, count:int, solution:bool) -> bool:
		"""
		Set the `solution` flag of the message with the given `count`.

		Returns False on a database error, True otherwise (also if no such row exists).
		"""
		return self._execute_update(
			"UPDATE Messages SET solution = ? WHERE count = ?", (solution, count)
		)

	def __iter__(self) -> Generator[MessageDbRow, None, None]:
		# Uses the defaults of `iterate`, i.e. at most 20 newest messages.
		yield from self.iterate()

	@validate_call
	def iterate(self,
		id:str|None=None, sender:str|None=None, recipient:str|None=None,
		processed:bool|None=None, solution:bool|None=None,
		time_after:int|None=None, time_before:int|None=None,
		limit:int=20, offset:int=0, _count_only:bool=False
	) -> Generator[MessageDbRow|int, None, None]:
		"""
		Iterate over stored messages, newest first.

		Args:
			id, sender, recipient, processed, solution: if not None, only rows
				whose column equals this value are returned.
			time_after, time_before: Unix timestamps (seconds). If given and
				non-zero, only rows with `time` strictly after / before them
				are returned.
			limit, offset: paging of the result (SQL LIMIT / OFFSET).
			_count_only: internal (used by `len`). Yield only the number of
				matching rows, ignoring `limit`/`offset`.

		Yields:
			`MessageDbRow`s ordered by `time`, then `count`, both descending,
			or a single int if `_count_only`.
		"""
		params = {
			"lim": limit,
			"off": offset
		}
		conditions = []

		# Column names come from this fixed mapping only; values are bound as parameters.
		equality_filters = {
			'id': id,
			'sender': sender, 'recipient': recipient,
			'processed': processed, 'solution': solution,
		}
		for column, value in equality_filters.items():
			if value is not None:
				conditions.append(f"{column} = :{column}")
				params[column] = value

		# Falsy timestamps (None or 0) mean "no bound", as before.
		if time_after:
			conditions.append("time > :t_after")
			params['t_after'] = self._to_db_time(time_after)
		if time_before:
			conditions.append("time < :t_before")
			params['t_before'] = self._to_db_time(time_before)

		where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

		# Fetch everything under the lock, then yield outside of it. This way no
		# cursor/transaction stays open (and no lock stays held) while the
		# caller consumes the generator.
		with self.db_lock, self.db:
			if _count_only:
				count_row = self.db.execute(
					f"SELECT COUNT(*) as count FROM Messages {where_clause}",
					params
				).fetchone()
			else:
				rows = self.db.execute(
					f"SELECT * FROM Messages {where_clause} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off",
					params
				).fetchall()

		if _count_only:
			yield count_row['count']
		else:
			for row in rows:
				yield self._create_row_object(row, allow_lazy=True)

	def __len__(self) -> int:
		return self.len()

	def len(self, **kwargs) -> int:
		"""
		See `DB.iterate` for possible values of `kwargs`.
		"""
		kwargs['_count_only'] = True
		return next(self.iterate(**kwargs))

	@classmethod
	def _to_db_time(cls, timestamp:int) -> str:
		"""Convert a Unix timestamp to the UTC text format of the `time` column."""
		return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(cls._DB_TIME_FORMAT)

	@classmethod
	def _from_db_time(cls, value:str) -> int:
		"""Convert a UTC `time` column value back to a Unix timestamp."""
		parsed = datetime.strptime(value, cls._DB_TIME_FORMAT).replace(tzinfo=timezone.utc)
		return int(parsed.timestamp())

	def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
		"""
		Build a `MessageDbRow` from a raw `Messages` row.

		With `allow_lazy`, a message that fails validation is replaced by a
		placeholder `AgentMessage` with id "error". The validation error text
		goes into its riddle context. Without `allow_lazy`, the
		`ValidationError` is re-raised.
		"""
		try:
			message = AgentMessage.model_validate_json(
				row['json'],
				context={"require_file_exists": not allow_lazy}
			)
		except ValidationError as e:
			if not allow_lazy:
				raise
			message = AgentMessage(
				id="error",
				riddle={"context":str(e),"question":"Failed to load from Database!"}
			)

		return MessageDbRow(
			count=row['count'],
			sender=row['sender'],
			recipient=row['recipient'],
			time=self._from_db_time(row['time']),
			message=message,
			processed=row['processed'],
			solution=row['solution']
		)

	def by_count(self, count:int) -> MessageDbRow|None:
		"""
		Return the message with the given `count`.

		Returns None if no such row exists or loading it failed (best effort).
		"""
		try:
			with self.db_lock, self.db:
				row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
			if row is None:
				return None
			return self._create_row_object(row)
		except Exception:
			# deliberate best-effort lookup: any database/parsing failure yields None
			return None