ums.management.db
# Agenten Plattform
#
# (c) 2024 Magnus Bender
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
# Universitaet Hamburg
# https://www.chai.uni-hamburg.de/~bender
#
# source code released under the terms of GNU Public License Version 3
# https://www.gnu.org/licenses/gpl-3.0.txt

import atexit
import os
import sqlite3

from datetime import datetime, timezone
from threading import Lock
from typing import Generator

from pydantic import validate_call, ValidationError

from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow

class DB():
    """
    Message store backed by the SQLite file `messages.db` inside `PERSIST_PATH`.

    A single connection (opened with `check_same_thread=False`) is shared by all
    threads; every statement on it is executed while holding `self.db_lock`.
    """

    # Format SQLite uses for `CURRENT_TIMESTAMP`, which is always UTC.
    _DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.db = sqlite3.connect(
            os.path.join(PERSIST_PATH, 'messages.db'),
            check_same_thread=False,
            autocommit=False
        )
        self.db.row_factory = sqlite3.Row
        # close the shared connection when the interpreter exits
        atexit.register(self.db.close)

        self.db_lock = Lock()

        self._assure_tables()

    def _assure_tables(self):
        """Create the `Messages` table if it does not exist yet."""
        # `with lock` releases the lock even if the statement raises
        with self.db_lock, self.db:
            self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
                count INTEGER PRIMARY KEY AUTOINCREMENT,
                id TEXT,
                sender TEXT,
                recipient TEXT,
                time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                json BLOB,
                processed BOOL DEFAULT FALSE,
                solution BOOL DEFAULT NULL
            )""")

    @validate_call
    def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
        """
        Store `message` and return the `count` (row id) assigned to it.

        The insert is committed before returning; the lock is released even if
        the insert fails, in which case the exception propagates.
        """
        with self.db_lock, self.db:
            cursor = self.db.execute(
                """INSERT INTO Messages (
                        id, sender, recipient, json, processed
                    ) VALUES (
                        :id, :sender, :recipient, :json, :processed
                    )""", {
                    "id" : message.id,
                    "sender" : sender,
                    "recipient" : recipient,
                    "json" : message.model_dump_json(),
                    "processed" : processed
                })
            # read while still holding the lock so no other insert interleaves
            new_count = cursor.lastrowid

        return new_count

    @validate_call
    def set_processed(self, count:int, processed:bool=True) -> bool:
        """
        Set the `processed` flag of the message with the given `count`.

        Returns `False` if the database reported an error (the change is rolled
        back), `True` otherwise.
        """
        with self.db_lock:
            try:
                with self.db:
                    self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
            except sqlite3.Error:
                return False
            return True

    @validate_call
    def set_solution(self, count:int, solution:bool) -> bool:
        """
        Set the `solution` flag of the message with the given `count`.

        Returns `False` if the database reported an error (the change is rolled
        back), `True` otherwise.
        """
        with self.db_lock:
            try:
                with self.db:
                    self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
            except sqlite3.Error:
                return False
            return True

    def __iter__(self) -> Generator[MessageDbRow, None, None]:
        yield from self.iterate()

    @validate_call
    def iterate(self,
        id:str|None=None, sender:str|None=None, recipient:str|None=None,
        processed:bool|None=None, solution:bool|None=None,
        time_after:int|None=None, time_before:int|None=None,
        limit:int=20, offset:int=0, _count_only:bool=False
    ) -> Generator[MessageDbRow|int, None, None]:
        """
        Yield stored messages matching all given filters, newest first.

        Filters left at `None` are ignored. `time_after`/`time_before` are Unix
        timestamps (a value of `0` is treated like `None`). With `_count_only`,
        a single int (the number of matching rows, ignoring `limit`/`offset`)
        is yielded instead of rows.
        """
        where = []
        params = {
            "lim": limit,
            "off": offset
        }

        for value, column in (
            (id,'id'),
            (sender,'sender'), (recipient,'recipient'),
            (processed,'processed'), (solution,'solution')
        ):
            if value is not None:
                # column names come from the fixed tuple above, values are bound
                where.append('{} = :{}'.format(column, column))
                params[column] = value

        # SQLite's CURRENT_TIMESTAMP is UTC, so compare against UTC strings
        if time_after:
            where.append("time > :t_after")
            params['t_after'] = datetime.fromtimestamp(time_after, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

        if time_before:
            where.append("time < :t_before")
            params['t_before'] = datetime.fromtimestamp(time_before, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

        where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

        # Fetch everything while holding the lock and yield afterwards: a
        # suspended generator must not keep the shared connection's transaction
        # open (closing it later would roll back other threads' work).
        with self.db_lock, self.db:
            if _count_only:
                total = self.db.execute(
                    "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                    params
                ).fetchone()['count']
            else:
                rows = self.db.execute(
                    "SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
                    params
                ).fetchall()

        if _count_only:
            yield total
        else:
            for row in rows:
                yield self._create_row_object(row, allow_lazy=True)

    def __len__(self) -> int:
        return self.len()

    def len(self, **kwargs) -> int:
        """
        See `DB.iterate` for possible values of `kwargs`.
        """
        kwargs['_count_only'] = True
        return next(self.iterate(**kwargs))

    def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
        """
        Convert a `Messages` row into a `MessageDbRow`.

        With `allow_lazy`, a message that fails validation is replaced by a
        placeholder `AgentMessage` with id "error" instead of raising.
        """
        try:
            message = AgentMessage.model_validate_json(
                row['json'],
                context={"require_file_exists": not allow_lazy}
            )
        except ValidationError as e:
            if not allow_lazy:
                raise
            message = AgentMessage(
                id="error",
                riddle={"context":str(e),"question":"Failed to load from Database!"}
            )

        # stored time is UTC (CURRENT_TIMESTAMP); don't interpret it as local time
        stored_time = datetime.strptime(row['time'], self._DB_TIME_FORMAT).replace(tzinfo=timezone.utc)

        return MessageDbRow(
            count=row['count'],
            sender=row['sender'],
            recipient=row['recipient'],
            time=int(stored_time.timestamp()),
            message=message,
            processed=row['processed'],
            solution=row['solution']
        )

    def by_count(self, count:int) -> MessageDbRow|None:
        """
        Return the message with the given `count`, or `None` if there is no
        such row or it cannot be read or converted.
        """
        try:
            with self.db_lock, self.db:
                row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
        except sqlite3.Error:
            return None

        if row is None:
            return None

        try:
            return self._create_row_object(row)
        except ValueError:
            # unparsable timestamp, or a pydantic ValidationError (a ValueError
            # subclass in pydantic v2) while building the row object
            return None
class
DB:
class DB():
    """
    Message store backed by the SQLite file `messages.db` inside `PERSIST_PATH`.

    A single connection (opened with `check_same_thread=False`) is shared by all
    threads; every statement on it is executed while holding `self.db_lock`.
    """

    # Format SQLite uses for `CURRENT_TIMESTAMP`, which is always UTC.
    _DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self):
        self.db = sqlite3.connect(
            os.path.join(PERSIST_PATH, 'messages.db'),
            check_same_thread=False,
            autocommit=False
        )
        self.db.row_factory = sqlite3.Row
        # close the shared connection when the interpreter exits
        atexit.register(self.db.close)

        self.db_lock = Lock()

        self._assure_tables()

    def _assure_tables(self):
        """Create the `Messages` table if it does not exist yet."""
        # `with lock` releases the lock even if the statement raises
        with self.db_lock, self.db:
            self.db.execute("""CREATE TABLE IF NOT EXISTS Messages (
                count INTEGER PRIMARY KEY AUTOINCREMENT,
                id TEXT,
                sender TEXT,
                recipient TEXT,
                time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                json BLOB,
                processed BOOL DEFAULT FALSE,
                solution BOOL DEFAULT NULL
            )""")

    @validate_call
    def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
        """
        Store `message` and return the `count` (row id) assigned to it.

        The insert is committed before returning; the lock is released even if
        the insert fails, in which case the exception propagates.
        """
        with self.db_lock, self.db:
            cursor = self.db.execute(
                """INSERT INTO Messages (
                        id, sender, recipient, json, processed
                    ) VALUES (
                        :id, :sender, :recipient, :json, :processed
                    )""", {
                    "id" : message.id,
                    "sender" : sender,
                    "recipient" : recipient,
                    "json" : message.model_dump_json(),
                    "processed" : processed
                })
            # read while still holding the lock so no other insert interleaves
            new_count = cursor.lastrowid

        return new_count

    @validate_call
    def set_processed(self, count:int, processed:bool=True) -> bool:
        """
        Set the `processed` flag of the message with the given `count`.

        Returns `False` if the database reported an error (the change is rolled
        back), `True` otherwise.
        """
        with self.db_lock:
            try:
                with self.db:
                    self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
            except sqlite3.Error:
                return False
            return True

    @validate_call
    def set_solution(self, count:int, solution:bool) -> bool:
        """
        Set the `solution` flag of the message with the given `count`.

        Returns `False` if the database reported an error (the change is rolled
        back), `True` otherwise.
        """
        with self.db_lock:
            try:
                with self.db:
                    self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
            except sqlite3.Error:
                return False
            return True

    def __iter__(self) -> Generator[MessageDbRow, None, None]:
        yield from self.iterate()

    @validate_call
    def iterate(self,
        id:str|None=None, sender:str|None=None, recipient:str|None=None,
        processed:bool|None=None, solution:bool|None=None,
        time_after:int|None=None, time_before:int|None=None,
        limit:int=20, offset:int=0, _count_only:bool=False
    ) -> Generator[MessageDbRow|int, None, None]:
        """
        Yield stored messages matching all given filters, newest first.

        Filters left at `None` are ignored. `time_after`/`time_before` are Unix
        timestamps (a value of `0` is treated like `None`). With `_count_only`,
        a single int (the number of matching rows, ignoring `limit`/`offset`)
        is yielded instead of rows.
        """
        from datetime import timezone

        where = []
        params = {
            "lim": limit,
            "off": offset
        }

        for value, column in (
            (id,'id'),
            (sender,'sender'), (recipient,'recipient'),
            (processed,'processed'), (solution,'solution')
        ):
            if value is not None:
                # column names come from the fixed tuple above, values are bound
                where.append('{} = :{}'.format(column, column))
                params[column] = value

        # SQLite's CURRENT_TIMESTAMP is UTC, so compare against UTC strings
        if time_after:
            where.append("time > :t_after")
            params['t_after'] = datetime.fromtimestamp(time_after, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

        if time_before:
            where.append("time < :t_before")
            params['t_before'] = datetime.fromtimestamp(time_before, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

        where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

        # Fetch everything while holding the lock and yield afterwards: a
        # suspended generator must not keep the shared connection's transaction
        # open (closing it later would roll back other threads' work).
        with self.db_lock, self.db:
            if _count_only:
                total = self.db.execute(
                    "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                    params
                ).fetchone()['count']
            else:
                rows = self.db.execute(
                    "SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
                    params
                ).fetchall()

        if _count_only:
            yield total
        else:
            for row in rows:
                yield self._create_row_object(row, allow_lazy=True)

    def __len__(self) -> int:
        return self.len()

    def len(self, **kwargs) -> int:
        """
        See `DB.iterate` for possible values of `kwargs`.
        """
        kwargs['_count_only'] = True
        return next(self.iterate(**kwargs))

    def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
        """
        Convert a `Messages` row into a `MessageDbRow`.

        With `allow_lazy`, a message that fails validation is replaced by a
        placeholder `AgentMessage` with id "error" instead of raising.
        """
        from datetime import timezone

        try:
            message = AgentMessage.model_validate_json(
                row['json'],
                context={"require_file_exists": not allow_lazy}
            )
        except ValidationError as e:
            if not allow_lazy:
                raise
            message = AgentMessage(
                id="error",
                riddle={"context":str(e),"question":"Failed to load from Database!"}
            )

        # stored time is UTC (CURRENT_TIMESTAMP); don't interpret it as local time
        stored_time = datetime.strptime(row['time'], self._DB_TIME_FORMAT).replace(tzinfo=timezone.utc)

        return MessageDbRow(
            count=row['count'],
            sender=row['sender'],
            recipient=row['recipient'],
            time=int(stored_time.timestamp()),
            message=message,
            processed=row['processed'],
            solution=row['solution']
        )

    def by_count(self, count:int) -> MessageDbRow|None:
        """
        Return the message with the given `count`, or `None` if there is no
        such row or it cannot be read or converted.
        """
        try:
            with self.db_lock, self.db:
                row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
        except sqlite3.Error:
            return None

        if row is None:
            return None

        try:
            return self._create_row_object(row)
        except ValueError:
            # unparsable timestamp, or a pydantic ValidationError (a ValueError
            # subclass in pydantic v2) while building the row object
            return None
@validate_call
def
add_message( self, sender: str, recipient: str, message: ums.utils.types.AgentMessage, processed: bool = False) -> int:
@validate_call
def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
    """
    Store `message` and return the `count` (row id) assigned to it.

    The insert is committed before returning; the lock is released even if
    the insert fails, in which case the exception propagates.
    """
    # `with lock` (instead of acquire/release) cannot leak the lock on error
    with self.db_lock, self.db:
        cursor = self.db.execute(
            """INSERT INTO Messages (
                    id, sender, recipient, json, processed
                ) VALUES (
                    :id, :sender, :recipient, :json, :processed
                )""", {
                "id" : message.id,
                "sender" : sender,
                "recipient" : recipient,
                "json" : message.model_dump_json(),
                "processed" : processed
            })
        # read while still holding the lock so no other insert interleaves
        new_count = cursor.lastrowid

    return new_count
@validate_call
def
set_processed(self, count: int, processed: bool = True) -> bool:
@validate_call
def set_processed(self, count:int, processed:bool=True) -> bool:
    """
    Set the `processed` flag of the message with the given `count`.

    Returns `False` if the database reported an error (the change is rolled
    back), `True` otherwise.
    """
    with self.db_lock:
        try:
            # the try wraps the transaction, so a failure rolls back instead of committing
            with self.db:
                self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
        except sqlite3.Error:
            return False
        return True
@validate_call
def
iterate( self, id: str | None = None, sender: str | None = None, recipient: str | None = None, processed: bool | None = None, solution: bool | None = None, time_after: int | None = None, time_before: int | None = None, limit: int = 20, offset: int = 0, _count_only: bool = False) -> Generator[ums.utils.types.MessageDbRow | int, NoneType, NoneType]:
@validate_call
def iterate(self,
    id:str|None=None, sender:str|None=None, recipient:str|None=None,
    processed:bool|None=None, solution:bool|None=None,
    time_after:int|None=None, time_before:int|None=None,
    limit:int=20, offset:int=0, _count_only:bool=False
) -> Generator[MessageDbRow|int, None, None]:
    """
    Yield stored messages matching all given filters, newest first.

    Filters left at `None` are ignored. `time_after`/`time_before` are Unix
    timestamps (a value of `0` is treated like `None`). With `_count_only`,
    a single int (the number of matching rows, ignoring `limit`/`offset`)
    is yielded instead of rows.
    """
    from datetime import timezone

    where = []
    params = {
        "lim": limit,
        "off": offset
    }

    for value, column in (
        (id,'id'),
        (sender,'sender'), (recipient,'recipient'),
        (processed,'processed'), (solution,'solution')
    ):
        if value is not None:
            # column names come from the fixed tuple above, values are bound
            where.append('{} = :{}'.format(column, column))
            params[column] = value

    # SQLite's CURRENT_TIMESTAMP is UTC, so compare against UTC strings
    if time_after:
        where.append("time > :t_after")
        params['t_after'] = datetime.fromtimestamp(time_after, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

    if time_before:
        where.append("time < :t_before")
        params['t_before'] = datetime.fromtimestamp(time_before, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

    where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

    # Fetch everything while holding the lock and yield afterwards: a
    # suspended generator must not keep the shared connection's transaction
    # open (closing it later would roll back other threads' work).
    with self.db_lock, self.db:
        if _count_only:
            total = self.db.execute(
                "SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
                params
            ).fetchone()['count']
        else:
            rows = self.db.execute(
                "SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
                params
            ).fetchall()

    if _count_only:
        yield total
    else:
        for row in rows:
            yield self._create_row_object(row, allow_lazy=True)
def
len(self, **kwargs) -> int:
def len(self, **kwargs) -> int:
    """
    See `DB.iterate` for possible values of `kwargs`.
    """
    # Same filters as `iterate`, but force the counting mode; the count query
    # yields exactly one int, so the first item is the result.
    count_filters = {**kwargs, '_count_only': True}
    counter = self.iterate(**count_filters)
    return next(counter)
See `DB.iterate` for possible values of `kwargs`.