All checks were successful
Build and push Docker images on git tags / build (push) Successful in 40m1s
195 lines
5.0 KiB
Python
195 lines
5.0 KiB
Python
# Agenten Plattform
|
|
#
|
|
# (c) 2024 Magnus Bender
|
|
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
|
|
# Universitaet Hamburg
|
|
# https://www.chai.uni-hamburg.de/~bender
|
|
#
|
|
# source code released under the terms of GNU Public License Version 3
|
|
# https://www.gnu.org/licenses/gpl-3.0.txt
|
|
|
|
import atexit
import os
import sqlite3

from datetime import datetime, timezone
from threading import Lock
from typing import Generator

from pydantic import validate_call, ValidationError

from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow
|
|
|
|
class DB():
	"""
	Persistent store of all messages, backed by SQLite.

	The messages are kept in the table `Messages` of the file `messages.db` in
	`PERSIST_PATH`. A single connection is shared by all threads
	(`check_same_thread=False`), so every use of it -- including the
	commit/rollback performed when a `with self.db:` block is left -- happens
	while holding `self.db_lock`.
	"""

	# Layout of the `time` column. Its values are written by SQLite's
	# CURRENT_TIMESTAMP default and are therefore in UTC.
	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

	def __init__(self):
		self.db = sqlite3.connect(
			os.path.join(PERSIST_PATH, 'messages.db'),
			check_same_thread=False,
			# (Python >= 3.12) changes only become persistent on commit(),
			# i.e., when a `with self.db:` block is left without an exception
			autocommit=False
		)
		self.db.row_factory = sqlite3.Row
		atexit.register(self.db.close)

		# serializes all access to the shared connection
		self.db_lock = Lock()

		self._assure_tables()

	def _assure_tables(self):
		"""Create the table `Messages` unless it already exists."""
		# `with` guarantees the lock is released even if the statement fails
		with self.db_lock, self.db:
			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
				count INTEGER PRIMARY KEY AUTOINCREMENT,
				id TEXT,
				sender TEXT,
				recipient TEXT,
				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
				json BLOB,
				processed BOOL DEFAULT FALSE,
				solution BOOL DEFAULT NULL
			)""")

	@validate_call
	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
		"""
		Store a new message; `time` is set by the database.

		Args:
			sender: Name of the sending party.
			recipient: Name of the receiving party.
			message: The message, stored as its JSON dump.
			processed: Initial value of the `processed` flag.

		Returns:
			The `count` (row id) assigned to the new row.
		"""
		params = {
			"id" : message.id,
			"sender" : sender,
			"recipient" : recipient,
			"json" : message.model_dump_json(),
			"processed" : processed
		}

		# The insert and its commit run under the lock; the lock is released
		# (and the transaction rolled back) even if one of them raises.
		with self.db_lock, self.db:
			cursor = self.db.execute(
				"""INSERT INTO Messages (
					id, sender, recipient, json, processed
				) VALUES (
					:id, :sender, :recipient, :json, :processed
				)""",
				params
			)
			return cursor.lastrowid

	def _set_column(self, column:str, value:bool|None, count:int) -> bool:
		"""Set `column` (a fixed literal from the callers) of the row `count`; False on database errors."""
		try:
			# commit happens inside the lock; on error the update is rolled back
			with self.db_lock, self.db:
				self.db.execute(
					"UPDATE Messages SET {} = ? WHERE count = ?".format(column),
					(value, count)
				)
		except sqlite3.Error:
			return False
		return True

	@validate_call
	def set_processed(self, count:int, processed:bool=True) -> bool:
		"""
		Set the `processed` flag of the message with the given `count`.

		Returns:
			False if the database reported an error, else True (also when no
			row with this `count` exists).
		"""
		return self._set_column('processed', processed, count)

	@validate_call
	def set_solution(self, count:int, solution:bool) -> bool:
		"""
		Set the `solution` flag of the message with the given `count`.

		Returns:
			False if the database reported an error, else True (also when no
			row with this `count` exists).
		"""
		return self._set_column('solution', solution, count)

	def __iter__(self) -> Generator[MessageDbRow, None, None]:
		"""Iterate over the newest messages, using the defaults of `iterate` (at most 20)."""
		yield from self.iterate()

	def _to_db_time(self, timestamp:int) -> str:
		"""Format a unix timestamp like the `time` column (UTC)."""
		return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

	def _from_db_time(self, value:str) -> int:
		"""Convert a `time` column value (UTC) to a unix timestamp."""
		return int(
			datetime.strptime(value, self._DB_TIME_FORMAT)
				.replace(tzinfo=timezone.utc)
				.timestamp()
		)

	@validate_call
	def iterate(self,
		id:str|None=None, sender:str|None=None, recipient:str|None=None,
		processed:bool|None=None, solution:bool|None=None,
		time_after:int|None=None, time_before:int|None=None,
		limit:int=20, offset:int=0, _count_only:bool=False
	) -> Generator[MessageDbRow|int, None, None]:
		"""
		Query stored messages, newest first.

		Filters left as `None` are ignored, the others are combined with AND.

		Args:
			id, sender, recipient, processed, solution: Exact-match filters.
			time_after, time_before: Exclusive bounds as unix timestamps;
				`0` disables the filter as well.
			limit, offset: Paging of the result (not applied when counting).
			_count_only: Internal (used by `len`); yield only the number of
				matching rows.

		Yields:
			The matching rows as `MessageDbRow`, or a single `int` when
			`_count_only` is set.
		"""
		where = []
		params = {
			"lim": limit,
			"off": offset
		}

		# column names only ever come from this fixed list; values are bound
		for value, column in (
			(id, 'id'),
			(sender, 'sender'), (recipient, 'recipient'),
			(processed, 'processed'), (solution, 'solution')
		):
			if value is not None:
				where.append('{} = :{}'.format(column, column))
				params[column] = value

		# truthiness test on purpose: 0 means "no filter" (existing behavior)
		if time_after:
			where.append("time > :t_after")
			params['t_after'] = self._to_db_time(time_after)

		if time_before:
			where.append("time < :t_before")
			params['t_before'] = self._to_db_time(time_before)

		where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

		# Results are fetched completely under the lock before yielding, so a
		# partially consumed or abandoned generator holds neither the lock nor
		# an open `with self.db:` block (whose exit on GeneratorExit would roll
		# back pending changes of other threads on the shared connection).
		if _count_only:
			with self.db_lock, self.db:
				count = self.db.execute(
					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
					params
				).fetchone()

			yield count['count']
		else:
			with self.db_lock, self.db:
				rows = self.db.execute(
					"SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
					params
				).fetchall()

			for row in rows:
				yield self._create_row_object(row, allow_lazy=True)

	def __len__(self) -> int:
		"""Total number of stored messages."""
		return self.len()

	def len(self, **kwargs) -> int:
		"""
		Count the messages matching the given filters.

		See `DB.iterate` for possible values of `kwargs`.
		"""
		return next(self.iterate(**{**kwargs, '_count_only': True}))

	def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
		"""
		Convert a raw `Messages` row into a `MessageDbRow`.

		`allow_lazy` is passed to `AgentMessage` validation as the context
		`require_file_exists=not allow_lazy`. If validation fails and
		`allow_lazy` is set, a placeholder message with id "error" carrying the
		validation error is used instead; otherwise the `ValidationError` is raised.
		"""
		try:
			message = AgentMessage.model_validate_json(
				row['json'],
				context={"require_file_exists": not allow_lazy}
			)
		except ValidationError as e:
			if not allow_lazy:
				raise
			message = AgentMessage(
				id="error",
				riddle={"context":str(e),"question":"Failed to load from Database!"}
			)

		return MessageDbRow(
			count=row['count'],
			sender=row['sender'],
			recipient=row['recipient'],
			time=self._from_db_time(row['time']),
			message=message,
			processed=row['processed'],
			solution=row['solution']
		)

	def by_count(self, count:int) -> MessageDbRow|None:
		"""
		Fetch the message with the given `count`.

		Returns:
			The row, or None if it does not exist or could not be loaded.
		"""
		try:
			with self.db_lock, self.db:
				row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()

			if row is None:
				return None

			return self._create_row_object(row)
		except Exception:
			# best effort lookup, as before: any failure yields None
			return None