ums.management.db

  1# Agenten Plattform
  2#
  3# (c) 2024 Magnus Bender
  4# 	Institute of Humanities-Centered Artificial Intelligence (CHAI)
  5# 	Universitaet Hamburg
  6# 	https://www.chai.uni-hamburg.de/~bender
  7#  
  8# source code released under the terms of GNU Public License Version 3
  9# https://www.gnu.org/licenses/gpl-3.0.txt
 10
 11import os
 12import sqlite3, atexit
 13
 14from datetime import datetime
 15from threading import Lock
 16from typing import Generator
 17
 18from pydantic import validate_call, ValidationError
 19
 20from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow
 21
 22class DB():
 23
 24	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 25	
 26	def __init__(self):
 27		self.db = sqlite3.connect(
 28			os.path.join(PERSIST_PATH, 'messages.db'),
 29			check_same_thread=False
 30		)
 31		self.db.row_factory = sqlite3.Row
 32		atexit.register(lambda db : db.close(), self.db)
 33
 34		self.db_lock = Lock()
 35
 36		self._assure_tables()
 37
 38	def _assure_tables(self):
 39		self.db_lock.acquire()
 40		with self.db:
 41			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 42				count INTEGER PRIMARY KEY AUTOINCREMENT,
 43				id TEXT, 
 44				sender TEXT,
 45			 	recipient TEXT,
 46				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 47				json BLOB,
 48				processed BOOL DEFAULT FALSE,
 49				solution BOOL DEFAULT NULL
 50			)""")
 51		self.db_lock.release()
 52
 53	@validate_call
 54	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
 55		self.db_lock.acquire()
 56		with self.db:
 57			self.db.execute(
 58				"""INSERT INTO Messages (
 59						id, sender, recipient, json, processed
 60					) VALUES (
 61						:id, :sender, :recipient, :json, :processed
 62				)""", {
 63					"id" : message.id,
 64					"sender" : sender,
 65					"recipient" : recipient,
 66					"json" : message.model_dump_json(),
 67					"processed" : processed
 68				})
 69			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
 70		self.db_lock.release()
 71
 72		return new_count['last']
 73
 74	@validate_call
 75	def set_processed(self, count:int, processed:bool=True) -> bool:
 76		self.db_lock.acquire()
 77		with self.db:
 78			try:
 79				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
 80				return True
 81			except:
 82				return False
 83			finally:
 84				self.db_lock.release()
 85
 86	@validate_call
 87	def set_solution(self, count:int, solution:bool) -> bool:
 88		self.db_lock.acquire()
 89		with self.db:
 90			try:
 91				self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
 92				return True
 93			except:
 94				return False
 95			finally:
 96				self.db_lock.release()
 97
 98	def __iter__(self) -> Generator[MessageDbRow, None, None]:
 99		yield from self.iterate()
100
101	@validate_call
102	def iterate(self,
103		id:str|None=None, sender:str|None=None, recipient:str|None=None,
104		processed:bool|None=None, solution:bool|None=None,
105		time_after:int|None=None, time_before:int|None=None,
106		limit:int=20, offset:int=0, _count_only:bool=False
107	) -> Generator[MessageDbRow|int, None, None]:
108		
109		where = []
110		params = {
111			"lim": limit,
112			"off": offset
113		}
114
115		for v,n in (
116				(id,'id'),
117				(sender,'sender'), (recipient,'recipient'),
118				(processed,'processed'), (solution,'solution')
119			):
120			if not v is None:
121				where.append('{} = :{}'.format(n,n))
122				params[n] = v
123
124		if time_after:
125			where.append("time > :t_after")
126			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
127
128		if time_before:
129			where.append("time < :t_before")
130			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
131
132		if len(where) > 0:
133			where_clause = "WHERE " + (' AND '.join(where))
134		else:
135			where_clause = ""
136
137		with self.db:
138			if _count_only:
139				count = self.db.execute(
140					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
141					params
142				).fetchone()
143
144				yield count['count']
145			else:
146				for row in self.db.execute(
147					"SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
148					params
149				):
150					yield self._create_row_object(row, allow_lazy=True)
151
152	def __len__(self) -> int:
153		return self.len()
154
155	def len(self, **kwargs) -> int:
156		"""
157			See `DB.iterate` for possible values of `kwargs`.
158		"""
159		kwargs['_count_only'] = True
160		return next(self.iterate(**kwargs))
161
162	def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
163		try:
164			message = AgentMessage.model_validate_json(
165					row['json'],
166					context={"require_file_exists": not allow_lazy}
167				)
168		except ValidationError as e:
169			if allow_lazy:
170				message = AgentMessage(
171					id="error",
172					riddle={"context":str(e),"question":"Failed to load from Database!"}
173				) 
174			else:
175				raise e
176		
177		return MessageDbRow(
178			count=row['count'],
179			sender=row['sender'],
180			recipient=row['recipient'],
181			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
182			message=message,
183			processed=row['processed'],
184			solution=row['solution']
185		)
186
187	def by_count(self, count:int) -> MessageDbRow|None:
188		with self.db:
189			try:
190				return self._create_row_object(
191					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
192				)
193			except:
194				return None
195	
class DB:
 23class DB():
 24
 25	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 26	
 27	def __init__(self):
 28		self.db = sqlite3.connect(
 29			os.path.join(PERSIST_PATH, 'messages.db'),
 30			check_same_thread=False
 31		)
 32		self.db.row_factory = sqlite3.Row
 33		atexit.register(lambda db : db.close(), self.db)
 34
 35		self.db_lock = Lock()
 36
 37		self._assure_tables()
 38
 39	def _assure_tables(self):
 40		self.db_lock.acquire()
 41		with self.db:
 42			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 43				count INTEGER PRIMARY KEY AUTOINCREMENT,
 44				id TEXT, 
 45				sender TEXT,
 46			 	recipient TEXT,
 47				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 48				json BLOB,
 49				processed BOOL DEFAULT FALSE,
 50				solution BOOL DEFAULT NULL
 51			)""")
 52		self.db_lock.release()
 53
 54	@validate_call
 55	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
 56		self.db_lock.acquire()
 57		with self.db:
 58			self.db.execute(
 59				"""INSERT INTO Messages (
 60						id, sender, recipient, json, processed
 61					) VALUES (
 62						:id, :sender, :recipient, :json, :processed
 63				)""", {
 64					"id" : message.id,
 65					"sender" : sender,
 66					"recipient" : recipient,
 67					"json" : message.model_dump_json(),
 68					"processed" : processed
 69				})
 70			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
 71		self.db_lock.release()
 72
 73		return new_count['last']
 74
 75	@validate_call
 76	def set_processed(self, count:int, processed:bool=True) -> bool:
 77		self.db_lock.acquire()
 78		with self.db:
 79			try:
 80				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
 81				return True
 82			except:
 83				return False
 84			finally:
 85				self.db_lock.release()
 86
 87	@validate_call
 88	def set_solution(self, count:int, solution:bool) -> bool:
 89		self.db_lock.acquire()
 90		with self.db:
 91			try:
 92				self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
 93				return True
 94			except:
 95				return False
 96			finally:
 97				self.db_lock.release()
 98
 99	def __iter__(self) -> Generator[MessageDbRow, None, None]:
100		yield from self.iterate()
101
102	@validate_call
103	def iterate(self,
104		id:str|None=None, sender:str|None=None, recipient:str|None=None,
105		processed:bool|None=None, solution:bool|None=None,
106		time_after:int|None=None, time_before:int|None=None,
107		limit:int=20, offset:int=0, _count_only:bool=False
108	) -> Generator[MessageDbRow|int, None, None]:
109		
110		where = []
111		params = {
112			"lim": limit,
113			"off": offset
114		}
115
116		for v,n in (
117				(id,'id'),
118				(sender,'sender'), (recipient,'recipient'),
119				(processed,'processed'), (solution,'solution')
120			):
121			if not v is None:
122				where.append('{} = :{}'.format(n,n))
123				params[n] = v
124
125		if time_after:
126			where.append("time > :t_after")
127			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
128
129		if time_before:
130			where.append("time < :t_before")
131			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
132
133		if len(where) > 0:
134			where_clause = "WHERE " + (' AND '.join(where))
135		else:
136			where_clause = ""
137
138		with self.db:
139			if _count_only:
140				count = self.db.execute(
141					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
142					params
143				).fetchone()
144
145				yield count['count']
146			else:
147				for row in self.db.execute(
148					"SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
149					params
150				):
151					yield self._create_row_object(row, allow_lazy=True)
152
153	def __len__(self) -> int:
154		return self.len()
155
156	def len(self, **kwargs) -> int:
157		"""
158			See `DB.iterate` for possible values of `kwargs`.
159		"""
160		kwargs['_count_only'] = True
161		return next(self.iterate(**kwargs))
162
163	def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
164		try:
165			message = AgentMessage.model_validate_json(
166					row['json'],
167					context={"require_file_exists": not allow_lazy}
168				)
169		except ValidationError as e:
170			if allow_lazy:
171				message = AgentMessage(
172					id="error",
173					riddle={"context":str(e),"question":"Failed to load from Database!"}
174				) 
175			else:
176				raise e
177		
178		return MessageDbRow(
179			count=row['count'],
180			sender=row['sender'],
181			recipient=row['recipient'],
182			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
183			message=message,
184			processed=row['processed'],
185			solution=row['solution']
186		)
187
188	def by_count(self, count:int) -> MessageDbRow|None:
189		with self.db:
190			try:
191				return self._create_row_object(
192					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
193				)
194			except:
195				return None
db
db_lock
@validate_call
def add_message( self, sender: str, recipient: str, message: ums.utils.types.AgentMessage, processed: bool = False) -> int:
@validate_call
def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
	"""
		Store `message`, sent from `sender` to `recipient`.

		Returns the `count` of the new row (`count` is the table's
		INTEGER PRIMARY KEY, i.e. its rowid).
	"""
	# Lock first, connection second: the INSERT is committed (or rolled back)
	# before the lock is released, and the lock is released even on error.
	with self.db_lock, self.db:
		cursor = self.db.execute(
			"""INSERT INTO Messages (
					id, sender, recipient, json, processed
				) VALUES (
					:id, :sender, :recipient, :json, :processed
			)""", {
				"id" : message.id,
				"sender" : sender,
				"recipient" : recipient,
				"json" : message.model_dump_json(),
				"processed" : processed
			})
		new_count = cursor.lastrowid
	return new_count
@validate_call
def set_processed(self, count: int, processed: bool = True) -> bool:
@validate_call
def set_processed(self, count:int, processed:bool=True) -> bool:
	"""
		Set the `processed` flag of the message with the given `count`.

		Returns `False` if SQLite reported an error (including on commit),
		otherwise `True` -- also when no row has this `count`.
	"""
	try:
		# commit happens inside the lock, before it is released
		with self.db_lock, self.db:
			self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
	except sqlite3.Error:
		return False
	return True
@validate_call
def set_solution(self, count: int, solution: bool) -> bool:
@validate_call
def set_solution(self, count:int, solution:bool) -> bool:
	"""
		Set the `solution` flag of the message with the given `count`.

		Returns `False` if SQLite reported an error (including on commit),
		otherwise `True` -- also when no row has this `count`.
	"""
	try:
		# commit happens inside the lock, before it is released
		with self.db_lock, self.db:
			self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
	except sqlite3.Error:
		return False
	return True
@validate_call
def iterate( self, id: str | None = None, sender: str | None = None, recipient: str | None = None, processed: bool | None = None, solution: bool | None = None, time_after: int | None = None, time_before: int | None = None, limit: int = 20, offset: int = 0, _count_only: bool = False) -> Generator[ums.utils.types.MessageDbRow | int, NoneType, NoneType]:
@validate_call
def iterate(self,
	id:str|None=None, sender:str|None=None, recipient:str|None=None,
	processed:bool|None=None, solution:bool|None=None,
	time_after:int|None=None, time_before:int|None=None,
	limit:int=20, offset:int=0, _count_only:bool=False
) -> Generator[MessageDbRow|int, None, None]:
	"""
		Yield stored messages, newest first (ordered by `time`, then `count`).

		Args:
			id, sender, recipient, processed, solution: keep only rows whose
				column equals this value; `None` disables the filter.
			time_after, time_before: Unix timestamps; keep only rows strictly
				after / before this time. `None` or `0` disables the bound.
			limit, offset: page of rows to return (SQL `LIMIT` / `OFFSET`).
			_count_only: instead of rows, yield a single `int`, the number of
				matching rows (`limit` and `offset` are not applied).

		All rows of the page are fetched before the first one is yielded, so
		the generator never holds `db_lock` while suspended.
	"""
	where = []
	params = {
		"lim": limit,
		"off": offset
	}

	# Column names come only from this fixed tuple; values are always bound
	# as parameters, never formatted into the SQL text.
	for value, column in (
			(id, 'id'),
			(sender, 'sender'), (recipient, 'recipient'),
			(processed, 'processed'), (solution, 'solution')
		):
		if value is not None:
			where.append('{} = :{}'.format(column, column))
			params[column] = value

	# NOTE(review): `time` defaults to SQLite's CURRENT_TIMESTAMP, which is UTC,
	# while `datetime.fromtimestamp` formats in local time. This matches the
	# conversion in `_create_row_object`, but both are only exact when the
	# process runs in UTC -- confirm.
	if time_after:
		where.append("time > :t_after")
		params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)

	if time_before:
		where.append("time < :t_before")
		params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)

	where_clause = ("WHERE " + " AND ".join(where)) if where else ""

	# Query and fetch under the lock, yield only after releasing it: the
	# caller may suspend or abandon this generator at any point. There is
	# deliberately no `with self.db:` -- a read needs no commit, and leaving
	# that context (e.g. via GeneratorExit) would commit or roll back another
	# thread's pending write on the shared connection.
	with self.db_lock:
		if _count_only:
			total = self.db.execute(
				"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
				params
			).fetchone()['count']
		else:
			rows = self.db.execute(
				"SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
				params
			).fetchall()

	if _count_only:
		yield total
	else:
		for row in rows:
			yield self._create_row_object(row, allow_lazy=True)
def len(self, **kwargs) -> int:
def len(self, **kwargs) -> int:
	"""
		See `DB.iterate` for possible values of `kwargs`.

		Returns the number of matching messages: `iterate` is called in
		count mode (`_count_only=True`, overriding any value passed in)
		and its single yielded value is returned.
	"""
	count_kwargs = {**kwargs, '_count_only': True}
	counter = self.iterate(**count_kwargs)
	return next(counter)

See DB.iterate for possible values of kwargs.

def by_count(self, count: int) -> ums.utils.types.MessageDbRow | None:
def by_count(self, count:int) -> MessageDbRow|None:
	"""
		Return the message stored under `count`.

		Returns `None` if no such row exists, or if reading or converting it
		fails (best effort).
	"""
	try:
		# no `with self.db:` -- a read must not commit/roll back other writes
		with self.db_lock:
			row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
		if row is None:
			return None
		return self._create_row_object(row)
	except (sqlite3.Error, ValueError, TypeError):
		# database errors and failed conversions (pydantic's ValidationError
		# and strptime errors are ValueErrors) are reported as "not found"
		return None