# Agenten Plattform
#
# (c) 2024 Magnus Bender
#     Institute of Humanities-Centered Artificial Intelligence (CHAI)
#     Universitaet Hamburg
#     https://www.chai.uni-hamburg.de/~bender
#
# source code released under the terms of GNU Public License Version 3
#     https://www.gnu.org/licenses/gpl-3.0.txt

import atexit
import os
import sqlite3

from datetime import datetime, timezone
from threading import Lock
from typing import Generator

from pydantic import validate_call, ValidationError

from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow


class DB:
    """
    Message store backed by the SQLite file `messages.db` in `PERSIST_PATH`.

    The single connection is shared across threads (`check_same_thread=False`);
    every write runs inside `db_lock` together with its transaction.
    """

    # format of SQLite's CURRENT_TIMESTAMP, which is always expressed in UTC
    _DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self) -> None:
        self.db = sqlite3.connect(
            os.path.join(PERSIST_PATH, 'messages.db'),
            check_same_thread=False
        )
        self.db.row_factory = sqlite3.Row
        # close the connection when the interpreter shuts down
        atexit.register(self.db.close)

        # serializes writes (and their commit) on the shared connection
        self.db_lock = Lock()

        self._assure_tables()

    def _assure_tables(self) -> None:
        """Create the `Messages` table if it does not exist yet."""
        # lock + transaction as context managers: released/rolled back even on errors
        with self.db_lock, self.db:
            self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
                count INTEGER PRIMARY KEY AUTOINCREMENT,
                id TEXT,
                sender TEXT,
                recipient TEXT,
                time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                json BLOB,
                processed BOOL DEFAULT FALSE,
                solution BOOL DEFAULT NULL
            )""")

    @validate_call
    def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
        """
        Store `message` sent from `sender` to `recipient`.

        Args:
            sender: name of the sending side.
            recipient: name of the receiving side.
            message: the message; stored as its JSON dump.
            processed: initial value of the `processed` flag.

        Returns:
            The `count` (row id / primary key) of the new row.
        """
        with self.db_lock, self.db:
            cursor = self.db.execute(
                """INSERT INTO Messages (
                        id, sender, recipient, json, processed
                    ) VALUES (
                        :id, :sender, :recipient, :json, :processed
                    )""", {
                    "id" : message.id,
                    "sender" : sender,
                    "recipient" : recipient,
                    "json" : message.model_dump_json(),
                    "processed" : processed
                })
            # the connection commits when leaving the `with`, then the lock is released
            return cursor.lastrowid

    def _update_column(self, count:int, column:str, value:bool) -> bool:
        """
        Set `column` of the row with primary key `count` to `value`.

        `column` is interpolated into the SQL text, so it must be a fixed,
        trusted column name (callers in this class pass literals only).

        Returns False if SQLite rejected the update (the transaction is rolled
        back), True otherwise -- also when no row has this `count`.
        """
        try:
            with self.db_lock, self.db:
                self.db.execute(
                    f"UPDATE Messages SET {column} = ? WHERE count = ?",
                    (value, count)
                )
        except sqlite3.Error:
            return False
        return True

    @validate_call
    def set_processed(self, count:int, processed:bool=True) -> bool:
        """
        Set the `processed` flag of message `count`.

        Returns False on a database error, True otherwise (also if no
        message with this `count` exists).
        """
        return self._update_column(count, "processed", processed)

    @validate_call
    def set_solution(self, count:int, solution:bool) -> bool:
        """
        Set the `solution` flag of message `count`.

        Returns False on a database error, True otherwise (also if no
        message with this `count` exists).
        """
        return self._update_column(count, "solution", solution)

    def __iter__(self) -> Generator[MessageDbRow, None, None]:
        # iterates with the defaults of `iterate`, i.e. at most 20 newest messages
        yield from self.iterate()

    @classmethod
    def _to_db_time(cls, timestamp:int) -> str:
        """Format a Unix timestamp like SQLite's `CURRENT_TIMESTAMP` (UTC)."""
        return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(cls._DB_TIME_FORMAT)

    @classmethod
    def _from_db_time(cls, value:str) -> int:
        """Parse a `CURRENT_TIMESTAMP` string (UTC) into a Unix timestamp."""
        parsed = datetime.strptime(value, cls._DB_TIME_FORMAT).replace(tzinfo=timezone.utc)
        return int(parsed.timestamp())

    @validate_call
    def iterate(self,
            id:str|None=None, sender:str|None=None, recipient:str|None=None,
            processed:bool|None=None, solution:bool|None=None,
            time_after:int|None=None, time_before:int|None=None,
            limit:int=20, offset:int=0, _count_only:bool=False
        ) -> Generator[MessageDbRow|int, None, None]:
        """
        Iterate over stored messages, newest first (`time DESC, count DESC`).

        Args:
            id, sender, recipient, processed, solution: if not None, only rows
                whose column equals the given value are selected.
            time_after, time_before: if not None, exclusive bounds on the
                message time, given as Unix timestamps.
            limit, offset: paging of the result rows.
            _count_only: instead of rows, yield a single int -- the number of
                matching rows (ignores `limit` and `offset`).

        Yields:
            `MessageDbRow` objects (messages that fail validation are replaced
            by a placeholder, see `_create_row_object`), or one int if
            `_count_only` is set.
        """
        where = []
        params = {"lim": limit, "off": offset}

        for name, value in (
            ('id', id), ('sender', sender), ('recipient', recipient),
            ('processed', processed), ('solution', solution)
        ):
            if value is not None:
                # `name` comes from the literal tuple above, never from user input
                where.append(f'{name} = :{name}')
                params[name] = value

        if time_after is not None:
            where.append("time > :t_after")
            params['t_after'] = self._to_db_time(time_after)

        if time_before is not None:
            where.append("time < :t_before")
            params['t_before'] = self._to_db_time(time_before)

        where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

        # No `with self.db:` here: plain SELECTs need no transaction, and the
        # connection context manager would commit/rollback the shared connection
        # (e.g. rollback when this generator is abandoned early, as in `len`).
        if _count_only:
            count = self.db.execute(
                f"SELECT COUNT(*) as count FROM Messages {where_clause}",
                params
            ).fetchone()
            yield count['count']
        else:
            for row in self.db.execute(
                f"SELECT * FROM Messages {where_clause} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off",
                params
            ):
                yield self._create_row_object(row, allow_lazy=True)

    def __len__(self) -> int:
        return self.len()

    def len(self, **kwargs) -> int:
        """
        Count the messages matching the filters.

        See `DB.iterate` for possible values of `kwargs`.
        """
        kwargs['_count_only'] = True
        return next(self.iterate(**kwargs))

    def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
        """
        Convert a `Messages` row into a `MessageDbRow`.

        Args:
            row: a row selected with `SELECT *` from `Messages`.
            allow_lazy: passed (negated) as `require_file_exists` validation
                context; if True, a message that fails validation is replaced by
                a placeholder `AgentMessage` with id "error" instead of raising.

        Raises:
            ValidationError: if the stored JSON is invalid and `allow_lazy` is False.
        """
        try:
            message = AgentMessage.model_validate_json(
                row['json'],
                context={"require_file_exists": not allow_lazy}
            )
        except ValidationError as e:
            if not allow_lazy:
                raise
            # keep listings usable: substitute a placeholder carrying the error text
            message = AgentMessage(
                id="error",
                riddle={"context":str(e),"question":"Failed to load from Database!"}
            )

        return MessageDbRow(
            count=row['count'],
            sender=row['sender'],
            recipient=row['recipient'],
            time=self._from_db_time(row['time']),
            message=message,
            processed=row['processed'],
            solution=row['solution']
        )

    def by_count(self, count:int) -> MessageDbRow|None:
        """
        Return the message with primary key `count`.

        Returns None if no such row exists, on a database error, or if the row
        cannot be converted (e.g. an unparsable time or invalid row data).
        """
        try:
            row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
            if row is None:
                return None
            return self._create_row_object(row)
        except (sqlite3.Error, ValueError):
            # pydantic's ValidationError is a subclass of ValueError
            return None