Management Message Processing

This commit is contained in:
2024-10-29 16:47:58 +01:00
parent 04ccd488f8
commit fac784e013
6 changed files with 182 additions and 21 deletions

View File

@ -15,7 +15,7 @@ from datetime import datetime
from threading import Lock
from typing import Generator
from pydantic import validate_call
from pydantic import validate_call, ValidationError
from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow
@ -45,19 +45,20 @@ class DB():
recipient TEXT,
time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
json BLOB,
processed BOOL DEFAULT FALSE
processed BOOL DEFAULT FALSE,
solution BOOL DEFAULT NULL
)""")
self.db_lock.release()
@validate_call
def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
self.db_lock.acquire()
with self.db:
self.db.execute(
"""INSERT INTO Messages (
id, sender, recipient, json, processed
id, sender, recipient, json, processed
) VALUES (
:id, :sender, :recipient, :json, :processed
:id, :sender, :recipient, :json, :processed
)""", {
"id" : message.id,
"sender" : sender,
@ -82,13 +83,25 @@ class DB():
finally:
self.db_lock.release()
@validate_call
def set_solution(self, count:int, solution:bool) -> bool:
self.db_lock.acquire()
with self.db:
try:
self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
return True
except:
return False
finally:
self.db_lock.release()
def __iter__(self) -> Generator[MessageDbRow, None, None]:
yield from self.iterate()
@validate_call
def iterate(self,
id:str|None=None, sender:str|None=None, recipient:str|None=None,
processed:bool|None=None,
processed:bool|None=None, solution:bool|None=None,
time_after:int|None=None, time_before:int|None=None,
limit:int=20, offset:int=0, _count_only:bool=False
) -> Generator[MessageDbRow|int, None, None]:
@ -99,7 +112,11 @@ class DB():
"off": offset
}
for v,n in ((id,'id'), (sender,'sender'), (recipient,'recipient'), (processed,'processed')):
for v,n in (
(id,'id'),
(sender,'sender'), (recipient,'recipient'),
(processed,'processed'), (solution,'solution')
):
if not v is None:
where.append('{} = :{}'.format(n,n))
params[n] = v
@ -130,7 +147,7 @@ class DB():
"SELECT * FROM Messages {} ORDER BY time DESC LIMIT :lim OFFSET :off".format(where_clause),
params
):
yield self._create_row_object(row)
yield self._create_row_object(row, allow_lazy=True)
def __len__(self) -> int:
return self.len()
@ -142,14 +159,29 @@ class DB():
kwargs['_count_only'] = True
return next(self.iterate(**kwargs))
def _create_row_object(self, row:sqlite3.Row) -> MessageDbRow:
def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
try:
message = AgentMessage.model_validate_json(
row['json'],
context={"require_file_exists": not allow_lazy}
)
except ValidationError as e:
if allow_lazy:
message = AgentMessage(
id="error",
riddle={"context":str(e),"question":"Failed to load from Database!"}
)
else:
raise e
return MessageDbRow(
count=row['count'],
sender=row['sender'],
recipient=row['recipient'],
time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
message=AgentMessage.model_validate_json(row['json'], context={"require_file_exists":False}),
processed=row['processed']
message=message,
processed=row['processed'],
solution=row['solution']
)
def by_count(self, count:int) -> MessageDbRow|None: