ums.management.db
# Agenten Plattform
#
# (c) 2024 Magnus Bender
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
# Universitaet Hamburg
# https://www.chai.uni-hamburg.de/~bender
#
# source code released under the terms of GNU Public License Version 3
# https://www.gnu.org/licenses/gpl-3.0.txt

import os
import sqlite3, atexit

from datetime import datetime, timezone
from threading import Lock
from typing import Generator

from pydantic import validate_call, ValidationError

from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow

class DB():
    """
    Message store backed by the SQLite file `messages.db` in `PERSIST_PATH`.

    One connection is shared by all threads (`check_same_thread=False`);
    every writing method serializes its access through `db_lock`.
    """

    # Format of SQLite's `CURRENT_TIMESTAMP` (used as default for `time`);
    # SQLite produces these values in UTC.
    _DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.db = sqlite3.connect(
            os.path.join(PERSIST_PATH, 'messages.db'),
            check_same_thread=False
        )
        self.db.row_factory = sqlite3.Row
        # close the connection when the interpreter exits
        atexit.register(lambda db : db.close(), self.db)

        self.db_lock = Lock()

        self._assure_tables()

    def _assure_tables(self):
        """Create the `Messages` table if it does not exist yet."""
        # `with` releases the lock even if the statement raises
        with self.db_lock, self.db:
            self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
                count INTEGER PRIMARY KEY AUTOINCREMENT,
                id TEXT,
                sender TEXT,
                recipient TEXT,
                time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                json BLOB,
                processed BOOL DEFAULT FALSE,
                solution BOOL DEFAULT NULL
            )""")

    @classmethod
    def _to_db_time(cls, timestamp:int) -> str:
        """Convert a unix timestamp into the UTC string format stored in `time`."""
        return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(cls._DB_TIME_FORMAT)

    @classmethod
    def _from_db_time(cls, value:str) -> int:
        """Convert a stored `time` value (UTC, see `_DB_TIME_FORMAT`) into a unix timestamp."""
        parsed = datetime.strptime(value, cls._DB_TIME_FORMAT).replace(tzinfo=timezone.utc)
        return int(parsed.timestamp())

    @validate_call
    def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
        """
        Insert `message` (serialized as JSON) and return the `count` of the new row.

        Args:
            sender: value stored in the `sender` column
            recipient: value stored in the `recipient` column
            message: the message; its `id` is stored in the `id` column
            processed: initial value of the `processed` flag
        """
        # the lock is released and the transaction rolled back even if the INSERT fails
        with self.db_lock, self.db:
            cursor = self.db.execute(
                """INSERT INTO Messages (
                    id, sender, recipient, json, processed
                ) VALUES (
                    :id, :sender, :recipient, :json, :processed
                )""", {
                    "id" : message.id,
                    "sender" : sender,
                    "recipient" : recipient,
                    "json" : message.model_dump_json(),
                    "processed" : processed
                })
            # `count` is the INTEGER PRIMARY KEY, i.e. the rowid of the inserted row
            new_count = cursor.lastrowid

        return new_count

    @validate_call
    def set_processed(self, count:int, processed:bool=True) -> bool:
        """
        Set the `processed` flag of the row with the given `count`.

        Returns False if the database reported an error, True otherwise
        (also if no row has this `count`).
        """
        with self.db_lock:
            try:
                # commit/rollback happens while the lock is still held
                with self.db:
                    self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
            except sqlite3.Error:
                return False
        return True

    @validate_call
    def set_solution(self, count:int, solution:bool) -> bool:
        """
        Set the `solution` flag of the row with the given `count`.

        Returns False if the database reported an error, True otherwise
        (also if no row has this `count`).
        """
        with self.db_lock:
            try:
                # commit/rollback happens while the lock is still held
                with self.db:
                    self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
            except sqlite3.Error:
                return False
        return True

    def __iter__(self) -> Generator[MessageDbRow, None, None]:
        """Iterate over the newest messages using the defaults of `iterate`."""
        yield from self.iterate()

    @validate_call
    def iterate(self,
        id:str|None=None, sender:str|None=None, recipient:str|None=None,
        processed:bool|None=None, solution:bool|None=None,
        time_after:int|None=None, time_before:int|None=None,
        limit:int=20, offset:int=0, _count_only:bool=False
    ) -> Generator[MessageDbRow|int, None, None]:
        """
        Yield the stored messages matching all given filters, newest first.

        Args:
            id, sender, recipient, processed, solution: keep only rows whose
                column equals the value; a filter is ignored when None
            time_after, time_before: unix timestamps; keep only rows stored
                strictly after / before; ignored when falsy
            limit, offset: pagination of the result rows
            _count_only: instead of rows, yield one int with the number of
                matching rows (limit/offset are not applied)
        """
        where = []
        params = {
            "lim": limit,
            "off": offset
        }

        # column names come from this fixed tuple only; values are bound as parameters
        for v,n in (
            (id,'id'),
            (sender,'sender'), (recipient,'recipient'),
            (processed,'processed'), (solution,'solution')
        ):
            if v is not None:
                where.append('{} = :{}'.format(n,n))
                params[n] = v

        if time_after:
            where.append("time > :t_after")
            params['t_after'] = self._to_db_time(time_after)

        if time_before:
            where.append("time < :t_before")
            params['t_before'] = self._to_db_time(time_before)

        if len(where) > 0:
            where_clause = "WHERE " + (' AND '.join(where))
        else:
            where_clause = ""

        with self.db:
            if _count_only:
                count = self.db.execute(
                    "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                    params
                ).fetchone()

                yield count['count']
            else:
                for row in self.db.execute(
                    "SELECT * FROM Messages {} ORDER BY time DESC LIMIT :lim OFFSET :off".format(where_clause),
                    params
                ):
                    yield self._create_row_object(row, allow_lazy=True)

    def __len__(self) -> int:
        """Number of all stored messages."""
        return self.len()

    def len(self, **kwargs) -> int:
        """
        See `DB.iterate` for possible values of `kwargs`.
        """
        kwargs['_count_only'] = True
        return next(self.iterate(**kwargs))

    def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
        """
        Build a `MessageDbRow` from a raw `Messages` row.

        The JSON is validated with context `require_file_exists = not allow_lazy`.
        With `allow_lazy`, a validation error yields a placeholder message with
        id "error" (its riddle carries the error text); without, it is re-raised.
        """
        try:
            message = AgentMessage.model_validate_json(
                row['json'],
                context={"require_file_exists": not allow_lazy}
            )
        except ValidationError as e:
            if allow_lazy:
                message = AgentMessage(
                    id="error",
                    riddle={"context":str(e),"question":"Failed to load from Database!"}
                )
            else:
                raise e

        return MessageDbRow(
            count=row['count'],
            sender=row['sender'],
            recipient=row['recipient'],
            time=self._from_db_time(row['time']),
            message=message,
            processed=row['processed'],
            solution=row['solution']
        )

    def by_count(self, count:int) -> MessageDbRow|None:
        """Return the row with the given `count`, or None if it is missing or cannot be loaded."""
        try:
            with self.db:
                row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
            # fetchone() returns None when no row has this count
            if row is None:
                return None
            return self._create_row_object(row)
        except Exception:
            # best effort: any failure while loading is reported as "not found"
            return None
class
DB:
class DB():
    """
    Message store backed by the SQLite file `messages.db` in `PERSIST_PATH`.

    One connection is shared by all threads (`check_same_thread=False`);
    every writing method serializes its access through `db_lock`.
    """

    # Format of SQLite's `CURRENT_TIMESTAMP` (used as default for `time`);
    # SQLite produces these values in UTC.
    _DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.db = sqlite3.connect(
            os.path.join(PERSIST_PATH, 'messages.db'),
            check_same_thread=False
        )
        self.db.row_factory = sqlite3.Row
        # close the connection when the interpreter exits
        atexit.register(lambda db : db.close(), self.db)

        self.db_lock = Lock()

        self._assure_tables()

    def _assure_tables(self):
        """Create the `Messages` table if it does not exist yet."""
        # `with` releases the lock even if the statement raises
        with self.db_lock, self.db:
            self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
                count INTEGER PRIMARY KEY AUTOINCREMENT,
                id TEXT,
                sender TEXT,
                recipient TEXT,
                time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                json BLOB,
                processed BOOL DEFAULT FALSE,
                solution BOOL DEFAULT NULL
            )""")

    @classmethod
    def _to_db_time(cls, timestamp:int) -> str:
        """Convert a unix timestamp into the UTC string format stored in `time`."""
        from datetime import timezone
        return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(cls._DB_TIME_FORMAT)

    @classmethod
    def _from_db_time(cls, value:str) -> int:
        """Convert a stored `time` value (UTC, see `_DB_TIME_FORMAT`) into a unix timestamp."""
        from datetime import timezone
        parsed = datetime.strptime(value, cls._DB_TIME_FORMAT).replace(tzinfo=timezone.utc)
        return int(parsed.timestamp())

    @validate_call
    def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
        """
        Insert `message` (serialized as JSON) and return the `count` of the new row.

        Args:
            sender: value stored in the `sender` column
            recipient: value stored in the `recipient` column
            message: the message; its `id` is stored in the `id` column
            processed: initial value of the `processed` flag
        """
        # the lock is released and the transaction rolled back even if the INSERT fails
        with self.db_lock, self.db:
            cursor = self.db.execute(
                """INSERT INTO Messages (
                    id, sender, recipient, json, processed
                ) VALUES (
                    :id, :sender, :recipient, :json, :processed
                )""", {
                    "id" : message.id,
                    "sender" : sender,
                    "recipient" : recipient,
                    "json" : message.model_dump_json(),
                    "processed" : processed
                })
            # `count` is the INTEGER PRIMARY KEY, i.e. the rowid of the inserted row
            new_count = cursor.lastrowid

        return new_count

    @validate_call
    def set_processed(self, count:int, processed:bool=True) -> bool:
        """
        Set the `processed` flag of the row with the given `count`.

        Returns False if the database reported an error, True otherwise
        (also if no row has this `count`).
        """
        with self.db_lock:
            try:
                # commit/rollback happens while the lock is still held
                with self.db:
                    self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
            except sqlite3.Error:
                return False
        return True

    @validate_call
    def set_solution(self, count:int, solution:bool) -> bool:
        """
        Set the `solution` flag of the row with the given `count`.

        Returns False if the database reported an error, True otherwise
        (also if no row has this `count`).
        """
        with self.db_lock:
            try:
                # commit/rollback happens while the lock is still held
                with self.db:
                    self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
            except sqlite3.Error:
                return False
        return True

    def __iter__(self) -> Generator[MessageDbRow, None, None]:
        """Iterate over the newest messages using the defaults of `iterate`."""
        yield from self.iterate()

    @validate_call
    def iterate(self,
        id:str|None=None, sender:str|None=None, recipient:str|None=None,
        processed:bool|None=None, solution:bool|None=None,
        time_after:int|None=None, time_before:int|None=None,
        limit:int=20, offset:int=0, _count_only:bool=False
    ) -> Generator[MessageDbRow|int, None, None]:
        """
        Yield the stored messages matching all given filters, newest first.

        Args:
            id, sender, recipient, processed, solution: keep only rows whose
                column equals the value; a filter is ignored when None
            time_after, time_before: unix timestamps; keep only rows stored
                strictly after / before; ignored when falsy
            limit, offset: pagination of the result rows
            _count_only: instead of rows, yield one int with the number of
                matching rows (limit/offset are not applied)
        """
        where = []
        params = {
            "lim": limit,
            "off": offset
        }

        # column names come from this fixed tuple only; values are bound as parameters
        for v,n in (
            (id,'id'),
            (sender,'sender'), (recipient,'recipient'),
            (processed,'processed'), (solution,'solution')
        ):
            if v is not None:
                where.append('{} = :{}'.format(n,n))
                params[n] = v

        if time_after:
            where.append("time > :t_after")
            params['t_after'] = self._to_db_time(time_after)

        if time_before:
            where.append("time < :t_before")
            params['t_before'] = self._to_db_time(time_before)

        if len(where) > 0:
            where_clause = "WHERE " + (' AND '.join(where))
        else:
            where_clause = ""

        with self.db:
            if _count_only:
                count = self.db.execute(
                    "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                    params
                ).fetchone()

                yield count['count']
            else:
                for row in self.db.execute(
                    "SELECT * FROM Messages {} ORDER BY time DESC LIMIT :lim OFFSET :off".format(where_clause),
                    params
                ):
                    yield self._create_row_object(row, allow_lazy=True)

    def __len__(self) -> int:
        """Number of all stored messages."""
        return self.len()

    def len(self, **kwargs) -> int:
        """
        See `DB.iterate` for possible values of `kwargs`.
        """
        kwargs['_count_only'] = True
        return next(self.iterate(**kwargs))

    def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
        """
        Build a `MessageDbRow` from a raw `Messages` row.

        The JSON is validated with context `require_file_exists = not allow_lazy`.
        With `allow_lazy`, a validation error yields a placeholder message with
        id "error" (its riddle carries the error text); without, it is re-raised.
        """
        try:
            message = AgentMessage.model_validate_json(
                row['json'],
                context={"require_file_exists": not allow_lazy}
            )
        except ValidationError as e:
            if allow_lazy:
                message = AgentMessage(
                    id="error",
                    riddle={"context":str(e),"question":"Failed to load from Database!"}
                )
            else:
                raise e

        return MessageDbRow(
            count=row['count'],
            sender=row['sender'],
            recipient=row['recipient'],
            time=self._from_db_time(row['time']),
            message=message,
            processed=row['processed'],
            solution=row['solution']
        )

    def by_count(self, count:int) -> MessageDbRow|None:
        """Return the row with the given `count`, or None if it is missing or cannot be loaded."""
        try:
            with self.db:
                row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
            # fetchone() returns None when no row has this count
            if row is None:
                return None
            return self._create_row_object(row)
        except Exception:
            # best effort: any failure while loading is reported as "not found"
            return None
@validate_call
def
add_message( self, sender: str, recipient: str, message: ums.utils.types.AgentMessage, processed: bool = False) -> int:
@validate_call
def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
    """
    Insert `message` (serialized as JSON) and return the `count` of the new row.

    Args:
        sender: value stored in the `sender` column
        recipient: value stored in the `recipient` column
        message: the message; its `id` is stored in the `id` column
        processed: initial value of the `processed` flag
    """
    # the lock is released and the transaction rolled back even if the INSERT fails
    with self.db_lock, self.db:
        cursor = self.db.execute(
            """INSERT INTO Messages (
                id, sender, recipient, json, processed
            ) VALUES (
                :id, :sender, :recipient, :json, :processed
            )""", {
                "id" : message.id,
                "sender" : sender,
                "recipient" : recipient,
                "json" : message.model_dump_json(),
                "processed" : processed
            })
        # `count` is the INTEGER PRIMARY KEY, i.e. the rowid of the inserted row
        new_count = cursor.lastrowid

    return new_count
@validate_call
def
set_processed(self, count: int, processed: bool = True) -> bool:
@validate_call
def set_processed(self, count:int, processed:bool=True) -> bool:
    """
    Set the `processed` flag of the row with the given `count`.

    Returns False if the database reported an error, True otherwise
    (also if no row has this `count`).
    """
    with self.db_lock:
        try:
            # commit/rollback happens while the lock is still held
            with self.db:
                self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
        except sqlite3.Error:
            return False
    return True
@validate_call
def
iterate( self, id: str | None = None, sender: str | None = None, recipient: str | None = None, processed: bool | None = None, solution: bool | None = None, time_after: int | None = None, time_before: int | None = None, limit: int = 20, offset: int = 0, _count_only: bool = False) -> Generator[ums.utils.types.MessageDbRow | int, NoneType, NoneType]:
@validate_call
def iterate(self,
    id:str|None=None, sender:str|None=None, recipient:str|None=None,
    processed:bool|None=None, solution:bool|None=None,
    time_after:int|None=None, time_before:int|None=None,
    limit:int=20, offset:int=0, _count_only:bool=False
) -> Generator[MessageDbRow|int, None, None]:
    """
    Yield the stored messages matching all given filters, newest first.

    Args:
        id, sender, recipient, processed, solution: keep only rows whose
            column equals the value; a filter is ignored when None
        time_after, time_before: unix timestamps; keep only rows stored
            strictly after / before; ignored when falsy
        limit, offset: pagination of the result rows
        _count_only: instead of rows, yield one int with the number of
            matching rows (limit/offset are not applied)
    """
    # equality filters in a fixed order; column names never come from callers
    equality_filters = {
        'id': id,
        'sender': sender,
        'recipient': recipient,
        'processed': processed,
        'solution': solution,
    }
    bindings = {column: value for column, value in equality_filters.items() if value is not None}
    conditions = ['{} = :{}'.format(column, column) for column in bindings]
    bindings["lim"] = limit
    bindings["off"] = offset

    # NOTE(review): SQLite's CURRENT_TIMESTAMP is UTC, fromtimestamp() yields
    # local time -- the comparison is only exact if the server runs in UTC.
    if time_after:
        conditions.append("time > :t_after")
        bindings['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
    if time_before:
        conditions.append("time < :t_before")
        bindings['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)

    where_clause = ("WHERE " + ' AND '.join(conditions)) if conditions else ""

    with self.db:
        if _count_only:
            result = self.db.execute(
                "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                bindings
            ).fetchone()
            yield result['count']
            return

        query = "SELECT * FROM Messages {} ORDER BY time DESC LIMIT :lim OFFSET :off".format(where_clause)
        for db_row in self.db.execute(query, bindings):
            yield self._create_row_object(db_row, allow_lazy=True)
def
len(self, **kwargs) -> int:
def len(self, **kwargs) -> int:
    """
    Count the stored messages matching the filters in `kwargs`.

    See `DB.iterate` for possible values of `kwargs`.
    """
    # `_count_only` always wins, even if the caller passed it explicitly
    counting = self.iterate(**{**kwargs, '_count_only': True})
    return next(counting)
See `DB.iterate` for possible values of `kwargs`.