ums.management.db
# Agenten Plattform
#
# (c) 2024 Magnus Bender
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
# Universitaet Hamburg
# https://www.chai.uni-hamburg.de/~bender
#
# source code released under the terms of GNU Public License Version 3
# https://www.gnu.org/licenses/gpl-3.0.txt

"""SQLite-backed storage of `AgentMessage`s, see `DB`."""

import atexit
import os
import sqlite3

from datetime import datetime
from threading import Lock
from typing import Generator

from pydantic import validate_call

from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow

class DB():
    """
    Message store backed by the SQLite file ``messages.db`` in ``PERSIST_PATH``.

    A single connection is shared by all threads (``check_same_thread=False``).
    Every statement runs while holding ``self.db_lock``. This prevents one
    thread's commit or rollback from touching another thread's pending statement.
    """

    # format of the `time` column as written by SQLite's CURRENT_TIMESTAMP
    _DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.db = sqlite3.connect(
            os.path.join(PERSIST_PATH, 'messages.db'),
            check_same_thread=False
        )
        # rows support access by column name, e.g. row['count']
        self.db.row_factory = sqlite3.Row
        atexit.register(self.db.close)

        self.db_lock = Lock()

        self._assure_tables()

    def _assure_tables(self):
        """Create the `Messages` table if it does not exist yet."""
        # Lock first, then the connection's transaction context. The commit
        # then happens before the lock is released, and the lock is released
        # even if the statement raises.
        with self.db_lock, self.db:
            self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
                    count INTEGER PRIMARY KEY AUTOINCREMENT,
                    id TEXT,
                    sender TEXT,
                    recipient TEXT,
                    time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    json BLOB,
                    processed BOOL DEFAULT FALSE
                )""")

    @validate_call
    def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
        """
        Store `message`, sent from `sender` to `recipient`.

        Args:
            sender: name of the sending party.
            recipient: name of the receiving party.
            message: the message; stored as its JSON dump.
            processed: initial value of the `processed` flag.

        Returns:
            The `count` (row id) assigned to the new row.
        """
        with self.db_lock, self.db:
            cursor = self.db.execute(
                """INSERT INTO Messages (
                        id, sender, recipient, json, processed
                    ) VALUES (
                        :id, :sender, :recipient, :json, :processed
                    )""", {
                    "id" : message.id,
                    "sender" : sender,
                    "recipient" : recipient,
                    "json" : message.model_dump_json(),
                    "processed" : processed
                })
            # `count` is the INTEGER PRIMARY KEY, i.e. the rowid of this insert
            new_count = cursor.lastrowid

        return new_count

    @validate_call
    def set_processed(self, count:int, processed:bool=True) -> bool:
        """
        Set the `processed` flag of the message with the given `count`.

        Returns:
            True if the UPDATE was executed and committed. This includes the
            case where no message has this `count`. False if SQLite raised an
            error; the transaction is rolled back in that case.
        """
        with self.db_lock:
            try:
                # commit (or rollback) happens while the lock is still held
                with self.db:
                    self.db.execute(
                        "UPDATE Messages SET processed = ? WHERE count = ?",
                        (processed, count)
                    )
            except sqlite3.Error:
                return False
        return True

    def __iter__(self) -> Generator[MessageDbRow, None, None]:
        """Iterate over the newest messages using the defaults of `iterate`."""
        yield from self.iterate()

    @validate_call
    def iterate(self,
        id:str|None=None, sender:str|None=None, recipient:str|None=None,
        processed:bool|None=None,
        time_after:int|None=None, time_before:int|None=None,
        limit:int=20, offset:int=0, _count_only:bool=False
    ) -> Generator[MessageDbRow|int, None, None]:
        """
        Yield stored messages matching all given filters, newest first.

        Args:
            id, sender, recipient, processed: equality filters on the column of
                the same name; `None` disables the filter.
            time_after, time_before: Unix timestamps; keep only rows whose
                `time` is strictly after / before them. `None` disables the filter.
            limit, offset: pagination of the result.
            _count_only: internal (used by `len`). Yield a single int, the
                number of matching rows (ignoring `limit`/`offset`), instead of rows.
        """
        where = []
        params = {"lim": limit, "off": offset}

        filters = ((id, 'id'), (sender, 'sender'), (recipient, 'recipient'), (processed, 'processed'))
        for value, column in filters:
            if value is not None:
                # column names come from the fixed tuple above, never from input
                where.append('{} = :{}'.format(column, column))
                params[column] = value

        # NOTE(review): the `time` column is filled by CURRENT_TIMESTAMP, which is
        # UTC. datetime.fromtimestamp() formats in the process's local time
        # zone. `_create_row_object` makes the mirror-image conversion, so the
        # filters agree with the `time` values this class returns. Both differ
        # from real Unix time unless the process runs in UTC. Confirm before
        # changing.
        if time_after is not None:
            where.append("time > :t_after")
            params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)

        if time_before is not None:
            where.append("time < :t_before")
            params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)

        where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

        # Fetch everything while holding the lock, then release it before
        # yielding. Callers may write (e.g. `set_processed`) while iterating,
        # and the lock is not reentrant.
        with self.db_lock:
            if _count_only:
                total = self.db.execute(
                    "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                    params
                ).fetchone()['count']
            else:
                # `count` breaks ties between rows stored within the same second,
                # so LIMIT/OFFSET pagination is stable
                rows = self.db.execute(
                    "SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
                    params
                ).fetchall()

        if _count_only:
            yield total
        else:
            for row in rows:
                yield self._create_row_object(row)

    def __len__(self) -> int:
        """Total number of stored messages."""
        return self.len()

    def len(self, **kwargs) -> int:
        """
        Number of messages matching the given filters.

        See `DB.iterate` for possible values of `kwargs`.
        """
        kwargs['_count_only'] = True
        return next(self.iterate(**kwargs))

    def _create_row_object(self, row:sqlite3.Row) -> MessageDbRow:
        """Convert a raw `Messages` row into a `MessageDbRow`."""
        return MessageDbRow(
            count=row['count'],
            sender=row['sender'],
            recipient=row['recipient'],
            # NOTE(review): the stored string is UTC (CURRENT_TIMESTAMP), but the
            # naive datetime is interpreted as local time here; see `iterate`.
            time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
            message=AgentMessage.model_validate_json(row['json']),
            processed=row['processed']
        )

    def by_count(self, count:int) -> MessageDbRow|None:
        """
        Return the message stored under `count`.

        Returns None if there is no such row, if SQLite raises an error, or
        if the stored row cannot be parsed (best effort).
        """
        try:
            with self.db_lock:
                row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
            if row is None:
                return None
            return self._create_row_object(row)
        except (sqlite3.Error, ValueError):
            # pydantic's ValidationError and strptime errors are ValueErrors
            return None
class
DB:
class DB():
    """
    Message store backed by the SQLite file ``messages.db`` in ``PERSIST_PATH``.

    A single connection is shared by all threads (``check_same_thread=False``).
    Every statement runs while holding ``self.db_lock``. This prevents one
    thread's commit or rollback from touching another thread's pending statement.
    """

    # format of the `time` column as written by SQLite's CURRENT_TIMESTAMP
    _DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.db = sqlite3.connect(
            os.path.join(PERSIST_PATH, 'messages.db'),
            check_same_thread=False
        )
        # rows support access by column name, e.g. row['count']
        self.db.row_factory = sqlite3.Row
        atexit.register(self.db.close)

        self.db_lock = Lock()

        self._assure_tables()

    def _assure_tables(self):
        """Create the `Messages` table if it does not exist yet."""
        # Lock first, then the connection's transaction context. The commit
        # then happens before the lock is released, and the lock is released
        # even if the statement raises.
        with self.db_lock, self.db:
            self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
                    count INTEGER PRIMARY KEY AUTOINCREMENT,
                    id TEXT,
                    sender TEXT,
                    recipient TEXT,
                    time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    json BLOB,
                    processed BOOL DEFAULT FALSE
                )""")

    @validate_call
    def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
        """
        Store `message`, sent from `sender` to `recipient`.

        Args:
            sender: name of the sending party.
            recipient: name of the receiving party.
            message: the message; stored as its JSON dump.
            processed: initial value of the `processed` flag.

        Returns:
            The `count` (row id) assigned to the new row.
        """
        with self.db_lock, self.db:
            cursor = self.db.execute(
                """INSERT INTO Messages (
                        id, sender, recipient, json, processed
                    ) VALUES (
                        :id, :sender, :recipient, :json, :processed
                    )""", {
                    "id" : message.id,
                    "sender" : sender,
                    "recipient" : recipient,
                    "json" : message.model_dump_json(),
                    "processed" : processed
                })
            # `count` is the INTEGER PRIMARY KEY, i.e. the rowid of this insert
            new_count = cursor.lastrowid

        return new_count

    @validate_call
    def set_processed(self, count:int, processed:bool=True) -> bool:
        """
        Set the `processed` flag of the message with the given `count`.

        Returns:
            True if the UPDATE was executed and committed. This includes the
            case where no message has this `count`. False if SQLite raised an
            error; the transaction is rolled back in that case.
        """
        with self.db_lock:
            try:
                # commit (or rollback) happens while the lock is still held
                with self.db:
                    self.db.execute(
                        "UPDATE Messages SET processed = ? WHERE count = ?",
                        (processed, count)
                    )
            except sqlite3.Error:
                return False
        return True

    def __iter__(self) -> Generator[MessageDbRow, None, None]:
        """Iterate over the newest messages using the defaults of `iterate`."""
        yield from self.iterate()

    @validate_call
    def iterate(self,
        id:str|None=None, sender:str|None=None, recipient:str|None=None,
        processed:bool|None=None,
        time_after:int|None=None, time_before:int|None=None,
        limit:int=20, offset:int=0, _count_only:bool=False
    ) -> Generator[MessageDbRow|int, None, None]:
        """
        Yield stored messages matching all given filters, newest first.

        Args:
            id, sender, recipient, processed: equality filters on the column of
                the same name; `None` disables the filter.
            time_after, time_before: Unix timestamps; keep only rows whose
                `time` is strictly after / before them. `None` disables the filter.
            limit, offset: pagination of the result.
            _count_only: internal (used by `len`). Yield a single int, the
                number of matching rows (ignoring `limit`/`offset`), instead of rows.
        """
        where = []
        params = {"lim": limit, "off": offset}

        filters = ((id, 'id'), (sender, 'sender'), (recipient, 'recipient'), (processed, 'processed'))
        for value, column in filters:
            if value is not None:
                # column names come from the fixed tuple above, never from input
                where.append('{} = :{}'.format(column, column))
                params[column] = value

        # NOTE(review): the `time` column is filled by CURRENT_TIMESTAMP, which is
        # UTC. datetime.fromtimestamp() formats in the process's local time
        # zone. `_create_row_object` makes the mirror-image conversion, so the
        # filters agree with the `time` values this class returns. Both differ
        # from real Unix time unless the process runs in UTC. Confirm before
        # changing.
        if time_after is not None:
            where.append("time > :t_after")
            params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)

        if time_before is not None:
            where.append("time < :t_before")
            params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)

        where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

        # Fetch everything while holding the lock, then release it before
        # yielding. Callers may write (e.g. `set_processed`) while iterating,
        # and the lock is not reentrant.
        with self.db_lock:
            if _count_only:
                total = self.db.execute(
                    "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                    params
                ).fetchone()['count']
            else:
                # `count` breaks ties between rows stored within the same second,
                # so LIMIT/OFFSET pagination is stable
                rows = self.db.execute(
                    "SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
                    params
                ).fetchall()

        if _count_only:
            yield total
        else:
            for row in rows:
                yield self._create_row_object(row)

    def __len__(self) -> int:
        """Total number of stored messages."""
        return self.len()

    def len(self, **kwargs) -> int:
        """
        Number of messages matching the given filters.

        See `DB.iterate` for possible values of `kwargs`.
        """
        kwargs['_count_only'] = True
        return next(self.iterate(**kwargs))

    def _create_row_object(self, row:sqlite3.Row) -> MessageDbRow:
        """Convert a raw `Messages` row into a `MessageDbRow`."""
        return MessageDbRow(
            count=row['count'],
            sender=row['sender'],
            recipient=row['recipient'],
            # NOTE(review): the stored string is UTC (CURRENT_TIMESTAMP), but the
            # naive datetime is interpreted as local time here; see `iterate`.
            time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
            message=AgentMessage.model_validate_json(row['json']),
            processed=row['processed']
        )

    def by_count(self, count:int) -> MessageDbRow|None:
        """
        Return the message stored under `count`.

        Returns None if there is no such row, if SQLite raises an error, or
        if the stored row cannot be parsed (best effort).
        """
        try:
            with self.db_lock:
                row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
            if row is None:
                return None
            return self._create_row_object(row)
        except (sqlite3.Error, ValueError):
            # pydantic's ValidationError and strptime errors are ValueErrors
            return None
@validate_call
def
add_message(self, sender: str, recipient: str, message: ums.utils.types.AgentMessage, processed: bool = False) -> int:
@validate_call
def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
    """
    Store `message`, sent from `sender` to `recipient`.

    Args:
        sender: name of the sending party.
        recipient: name of the receiving party.
        message: the message; stored as its JSON dump.
        processed: initial value of the `processed` flag.

    Returns:
        The `count` (row id) assigned to the new row.
    """
    # The `with` statement releases the lock even if the INSERT raises. The
    # connection context commits (or rolls back) before the lock is released.
    with self.db_lock, self.db:
        cursor = self.db.execute(
            """INSERT INTO Messages (
                    id, sender, recipient, json, processed
                ) VALUES (
                    :id, :sender, :recipient, :json, :processed
                )""", {
                "id" : message.id,
                "sender" : sender,
                "recipient" : recipient,
                "json" : message.model_dump_json(),
                "processed" : processed
            })
        # `count` is the INTEGER PRIMARY KEY, i.e. the rowid of this insert
        new_count = cursor.lastrowid

    return new_count
@validate_call
def
set_processed(self, count: int, processed: bool = True) -> bool:
@validate_call
def set_processed(self, count:int, processed:bool=True) -> bool:
    """
    Set the `processed` flag of the message with the given `count`.

    Returns:
        True if the UPDATE was executed and committed. This includes the
        case where no message has this `count`. False if SQLite raised an
        error; the transaction is rolled back in that case.
    """
    with self.db_lock:
        try:
            # commit (or rollback) happens while the lock is still held
            with self.db:
                self.db.execute(
                    "UPDATE Messages SET processed = ? WHERE count = ?",
                    (processed, count)
                )
        except sqlite3.Error:
            return False
    return True
@validate_call
def
iterate(self, id: str | None = None, sender: str | None = None, recipient: str | None = None, processed: bool | None = None, time_after: int | None = None, time_before: int | None = None, limit: int = 20, offset: int = 0, _count_only: bool = False) -> Generator[ums.utils.types.MessageDbRow | int, None, None]:
@validate_call
def iterate(self,
    id:str|None=None, sender:str|None=None, recipient:str|None=None,
    processed:bool|None=None,
    time_after:int|None=None, time_before:int|None=None,
    limit:int=20, offset:int=0, _count_only:bool=False
) -> Generator[MessageDbRow|int, None, None]:
    """
    Yield stored messages matching all given filters, newest first.

    Args:
        id, sender, recipient, processed: equality filters on the column of
            the same name; `None` disables the filter.
        time_after, time_before: Unix timestamps; keep only rows whose
            `time` is strictly after / before them. `None` disables the filter.
        limit, offset: pagination of the result.
        _count_only: internal (used by `len`). Yield a single int, the
            number of matching rows (ignoring `limit`/`offset`), instead of rows.
    """
    where = []
    params = {"lim": limit, "off": offset}

    filters = ((id, 'id'), (sender, 'sender'), (recipient, 'recipient'), (processed, 'processed'))
    for value, column in filters:
        if value is not None:
            # column names come from the fixed tuple above, never from input
            where.append('{} = :{}'.format(column, column))
            params[column] = value

    # NOTE(review): the `time` column is filled by CURRENT_TIMESTAMP, which is
    # UTC. datetime.fromtimestamp() formats in the process's local time zone.
    # DB._create_row_object makes the mirror-image conversion, so the filters
    # agree with the `time` values DB returns. Both differ from real Unix time
    # unless the process runs in UTC. Confirm before changing.
    if time_after is not None:
        where.append("time > :t_after")
        params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)

    if time_before is not None:
        where.append("time < :t_before")
        params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)

    where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

    # Fetch everything while holding the lock, then release it before
    # yielding. Callers may write (e.g. `set_processed`) while iterating, and
    # the lock is not reentrant.
    with self.db_lock:
        if _count_only:
            total = self.db.execute(
                "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                params
            ).fetchone()['count']
        else:
            # `count` breaks ties between rows stored within the same second,
            # so LIMIT/OFFSET pagination is stable
            rows = self.db.execute(
                "SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
                params
            ).fetchall()

    if _count_only:
        yield total
    else:
        for row in rows:
            yield self._create_row_object(row)
def
len(self, **kwargs) -> int:
def len(self, **kwargs) -> int:
    """
    Count the messages matching the given keyword filters.

    Accepts the same keyword arguments as `DB.iterate`. `_count_only` is
    always forced to True, and the first value produced by `iterate` is
    returned.
    """
    count_query = self.iterate(**{**kwargs, '_count_only': True})
    return next(count_query)
See `DB.iterate` for possible values of `kwargs`.