ums.management.db

  1# Agenten Plattform
  2#
  3# (c) 2024 Magnus Bender
  4# 	Institute of Humanities-Centered Artificial Intelligence (CHAI)
  5# 	Universitaet Hamburg
  6# 	https://www.chai.uni-hamburg.de/~bender
  7#  
  8# source code released under the terms of GNU Public License Version 3
  9# https://www.gnu.org/licenses/gpl-3.0.txt
 10
 11import os
 12import sqlite3, atexit
 13
 14from datetime import datetime
 15from threading import Lock
 16from typing import Generator
 17
 18from pydantic import validate_call, ValidationError
 19
 20from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow
 21
 22class DB():
 23
 24	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 25	
 26	def __init__(self):
 27		self.db = sqlite3.connect(
 28			os.path.join(PERSIST_PATH, 'messages.db'),
 29			check_same_thread=False,
 30			autocommit=False
 31		)
 32		self.db.row_factory = sqlite3.Row
 33		atexit.register(lambda db : db.close(), self.db)
 34
 35		self.db_lock = Lock()
 36
 37		self._assure_tables()
 38
 39	def _assure_tables(self):
 40		self.db_lock.acquire()
 41		with self.db:
 42			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 43				count INTEGER PRIMARY KEY AUTOINCREMENT,
 44				id TEXT, 
 45				sender TEXT,
 46			 	recipient TEXT,
 47				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 48				json BLOB,
 49				processed BOOL DEFAULT FALSE,
 50				solution BOOL DEFAULT NULL
 51			)""")
 52		self.db_lock.release()
 53
 54	@validate_call
 55	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
 56		self.db_lock.acquire()
 57		with self.db:
 58			self.db.execute(
 59				"""INSERT INTO Messages (
 60						id, sender, recipient, json, processed
 61					) VALUES (
 62						:id, :sender, :recipient, :json, :processed
 63				)""", {
 64					"id" : message.id,
 65					"sender" : sender,
 66					"recipient" : recipient,
 67					"json" : message.model_dump_json(),
 68					"processed" : processed
 69				})
 70			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
 71		self.db_lock.release()
 72
 73		return new_count['last']
 74
 75	@validate_call
 76	def set_processed(self, count:int, processed:bool=True) -> bool:
 77		self.db_lock.acquire()
 78		with self.db:
 79			try:
 80				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
 81				return True
 82			except:
 83				return False
 84			finally:
 85				self.db_lock.release()
 86
 87	@validate_call
 88	def set_solution(self, count:int, solution:bool) -> bool:
 89		self.db_lock.acquire()
 90		with self.db:
 91			try:
 92				self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
 93				return True
 94			except:
 95				return False
 96			finally:
 97				self.db_lock.release()
 98
 99	def __iter__(self) -> Generator[MessageDbRow, None, None]:
100		yield from self.iterate()
101
102	@validate_call
103	def iterate(self,
104		id:str|None=None, sender:str|None=None, recipient:str|None=None,
105		processed:bool|None=None, solution:bool|None=None,
106		time_after:int|None=None, time_before:int|None=None,
107		limit:int=20, offset:int=0, _count_only:bool=False
108	) -> Generator[MessageDbRow|int, None, None]:
109		
110		where = []
111		params = {
112			"lim": limit,
113			"off": offset
114		}
115
116		for v,n in (
117				(id,'id'),
118				(sender,'sender'), (recipient,'recipient'),
119				(processed,'processed'), (solution,'solution')
120			):
121			if not v is None:
122				where.append('{} = :{}'.format(n,n))
123				params[n] = v
124
125		if time_after:
126			where.append("time > :t_after")
127			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
128
129		if time_before:
130			where.append("time < :t_before")
131			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
132
133		if len(where) > 0:
134			where_clause = "WHERE " + (' AND '.join(where))
135		else:
136			where_clause = ""
137
138		with self.db:
139			if _count_only:
140				count = self.db.execute(
141					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
142					params
143				).fetchone()
144
145				yield count['count']
146			else:
147				for row in self.db.execute(
148					"SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
149					params
150				):
151					yield self._create_row_object(row, allow_lazy=True)
152
153	def __len__(self) -> int:
154		return self.len()
155
156	def len(self, **kwargs) -> int:
157		"""
158			See `DB.iterate` for possible values of `kwargs`.
159		"""
160		kwargs['_count_only'] = True
161		return next(self.iterate(**kwargs))
162
163	def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
164		try:
165			message = AgentMessage.model_validate_json(
166					row['json'],
167					context={"require_file_exists": not allow_lazy}
168				)
169		except ValidationError as e:
170			if allow_lazy:
171				message = AgentMessage(
172					id="error",
173					riddle={"context":str(e),"question":"Failed to load from Database!"}
174				) 
175			else:
176				raise e
177		
178		return MessageDbRow(
179			count=row['count'],
180			sender=row['sender'],
181			recipient=row['recipient'],
182			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
183			message=message,
184			processed=row['processed'],
185			solution=row['solution']
186		)
187
188	def by_count(self, count:int) -> MessageDbRow|None:
189		with self.db:
190			try:
191				return self._create_row_object(
192					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
193				)
194			except:
195				return None
class DB:
 23class DB():
 24
 25	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 26	
 27	def __init__(self):
 28		self.db = sqlite3.connect(
 29			os.path.join(PERSIST_PATH, 'messages.db'),
 30			check_same_thread=False,
 31			autocommit=False
 32		)
 33		self.db.row_factory = sqlite3.Row
 34		atexit.register(lambda db : db.close(), self.db)
 35
 36		self.db_lock = Lock()
 37
 38		self._assure_tables()
 39
 40	def _assure_tables(self):
 41		self.db_lock.acquire()
 42		with self.db:
 43			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 44				count INTEGER PRIMARY KEY AUTOINCREMENT,
 45				id TEXT, 
 46				sender TEXT,
 47			 	recipient TEXT,
 48				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 49				json BLOB,
 50				processed BOOL DEFAULT FALSE,
 51				solution BOOL DEFAULT NULL
 52			)""")
 53		self.db_lock.release()
 54
 55	@validate_call
 56	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
 57		self.db_lock.acquire()
 58		with self.db:
 59			self.db.execute(
 60				"""INSERT INTO Messages (
 61						id, sender, recipient, json, processed
 62					) VALUES (
 63						:id, :sender, :recipient, :json, :processed
 64				)""", {
 65					"id" : message.id,
 66					"sender" : sender,
 67					"recipient" : recipient,
 68					"json" : message.model_dump_json(),
 69					"processed" : processed
 70				})
 71			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
 72		self.db_lock.release()
 73
 74		return new_count['last']
 75
 76	@validate_call
 77	def set_processed(self, count:int, processed:bool=True) -> bool:
 78		self.db_lock.acquire()
 79		with self.db:
 80			try:
 81				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
 82				return True
 83			except:
 84				return False
 85			finally:
 86				self.db_lock.release()
 87
 88	@validate_call
 89	def set_solution(self, count:int, solution:bool) -> bool:
 90		self.db_lock.acquire()
 91		with self.db:
 92			try:
 93				self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
 94				return True
 95			except:
 96				return False
 97			finally:
 98				self.db_lock.release()
 99
100	def __iter__(self) -> Generator[MessageDbRow, None, None]:
101		yield from self.iterate()
102
103	@validate_call
104	def iterate(self,
105		id:str|None=None, sender:str|None=None, recipient:str|None=None,
106		processed:bool|None=None, solution:bool|None=None,
107		time_after:int|None=None, time_before:int|None=None,
108		limit:int=20, offset:int=0, _count_only:bool=False
109	) -> Generator[MessageDbRow|int, None, None]:
110		
111		where = []
112		params = {
113			"lim": limit,
114			"off": offset
115		}
116
117		for v,n in (
118				(id,'id'),
119				(sender,'sender'), (recipient,'recipient'),
120				(processed,'processed'), (solution,'solution')
121			):
122			if not v is None:
123				where.append('{} = :{}'.format(n,n))
124				params[n] = v
125
126		if time_after:
127			where.append("time > :t_after")
128			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
129
130		if time_before:
131			where.append("time < :t_before")
132			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
133
134		if len(where) > 0:
135			where_clause = "WHERE " + (' AND '.join(where))
136		else:
137			where_clause = ""
138
139		with self.db:
140			if _count_only:
141				count = self.db.execute(
142					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
143					params
144				).fetchone()
145
146				yield count['count']
147			else:
148				for row in self.db.execute(
149					"SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
150					params
151				):
152					yield self._create_row_object(row, allow_lazy=True)
153
154	def __len__(self) -> int:
155		return self.len()
156
157	def len(self, **kwargs) -> int:
158		"""
159			See `DB.iterate` for possible values of `kwargs`.
160		"""
161		kwargs['_count_only'] = True
162		return next(self.iterate(**kwargs))
163
164	def _create_row_object(self, row:sqlite3.Row, allow_lazy:bool=True) -> MessageDbRow:
165		try:
166			message = AgentMessage.model_validate_json(
167					row['json'],
168					context={"require_file_exists": not allow_lazy}
169				)
170		except ValidationError as e:
171			if allow_lazy:
172				message = AgentMessage(
173					id="error",
174					riddle={"context":str(e),"question":"Failed to load from Database!"}
175				) 
176			else:
177				raise e
178		
179		return MessageDbRow(
180			count=row['count'],
181			sender=row['sender'],
182			recipient=row['recipient'],
183			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
184			message=message,
185			processed=row['processed'],
186			solution=row['solution']
187		)
188
189	def by_count(self, count:int) -> MessageDbRow|None:
190		with self.db:
191			try:
192				return self._create_row_object(
193					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
194				)
195			except:
196				return None
db
db_lock
@validate_call
def add_message( self, sender: str, recipient: str, message: ums.utils.types.AgentMessage, processed: bool = False) -> int:
@validate_call
def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False ) -> int:
	"""
		Store `message`, sent from `sender` to `recipient`, in the database.

		Returns the `count` (primary key) of the new row.
	"""
	# Lock first, transaction second: the commit happens before the lock is
	# released, and the lock is released even if the INSERT or commit fails.
	with self.db_lock, self.db:
		cursor = self.db.execute(
			"""INSERT INTO Messages (
						id, sender, recipient, json, processed
					) VALUES (
						:id, :sender, :recipient, :json, :processed
				)""", {
				"id" : message.id,
				"sender" : sender,
				"recipient" : recipient,
				"json" : message.model_dump_json(),
				"processed" : processed
			})
		# row id of exactly this INSERT, read while the lock is still held
		return cursor.lastrowid
@validate_call
def set_processed(self, count: int, processed: bool = True) -> bool:
@validate_call
def set_processed(self, count:int, processed:bool=True) -> bool:
	"""
		Set the `processed` flag of the message with primary key `count`.

		Returns `False` if the database reported an error (nothing is changed then),
		else `True` -- also if no message with this `count` exists.
	"""
	with self.db_lock:
		try:
			# commit (or rollback) happens here, while the lock is still held
			with self.db:
				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
		except sqlite3.Error:
			return False
	return True
@validate_call
def set_solution(self, count: int, solution: bool) -> bool:
@validate_call
def set_solution(self, count:int, solution:bool) -> bool:
	"""
		Set the `solution` flag of the message with primary key `count`.

		Returns `False` if the database reported an error (nothing is changed then),
		else `True` -- also if no message with this `count` exists.
	"""
	with self.db_lock:
		try:
			# commit (or rollback) happens here, while the lock is still held
			with self.db:
				self.db.execute("UPDATE Messages SET solution = ? WHERE count = ?", (solution, count))
		except sqlite3.Error:
			return False
	return True
@validate_call
def iterate( self, id: str | None = None, sender: str | None = None, recipient: str | None = None, processed: bool | None = None, solution: bool | None = None, time_after: int | None = None, time_before: int | None = None, limit: int = 20, offset: int = 0, _count_only: bool = False) -> Generator[ums.utils.types.MessageDbRow | int, NoneType, NoneType]:
@validate_call
def iterate(self,
	id:str|None=None, sender:str|None=None, recipient:str|None=None,
	processed:bool|None=None, solution:bool|None=None,
	time_after:int|None=None, time_before:int|None=None,
	limit:int=20, offset:int=0, _count_only:bool=False
) -> Generator[MessageDbRow|int, None, None]:
	"""
		Iterate over stored messages, newest first.

		All given filters are combined with AND; `None` disables a filter.
		`time_after` / `time_before` are Unix timestamps used as exclusive bounds
		(`0` disables them, too). At most `limit` rows are yielded, after skipping
		`offset` rows.

		If `_count_only` is set, a single int is yielded instead: the number of
		matching rows (`limit` and `offset` do not apply to it).

		All matching rows are fetched before the first one is yielded, so neither
		the lock nor a transaction is held while the caller consumes the rows;
		the caller may therefore write to the database inside the loop.
	"""
	from datetime import timezone

	where = []
	params = {
		"lim": limit,
		"off": offset
	}

	# column names are fixed literals, only the values are bound as parameters
	for value, column in (
			(id,'id'),
			(sender,'sender'), (recipient,'recipient'),
			(processed,'processed'), (solution,'solution')
		):
		if value is not None:
			where.append('{} = :{}'.format(column, column))
			params[column] = value

	# `time` is stored as UTC text (SQLite CURRENT_TIMESTAMP), so the bounds are formatted in UTC too
	if time_after:
		where.append("time > :t_after")
		params['t_after'] = datetime.fromtimestamp(time_after, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

	if time_before:
		where.append("time < :t_before")
		params['t_before'] = datetime.fromtimestamp(time_before, tz=timezone.utc).strftime(self._DB_TIME_FORMAT)

	where_clause = ("WHERE " + ' AND '.join(where)) if where else ""

	with self.db_lock, self.db:
		if _count_only:
			result = self.db.execute(
				"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
				params
			).fetchone()['count']
		else:
			result = self.db.execute(
				"SELECT * FROM Messages {} ORDER BY time DESC, count DESC LIMIT :lim OFFSET :off".format(where_clause),
				params
			).fetchall()

	# yield only after lock and transaction are released
	if _count_only:
		yield result
	else:
		for row in result:
			yield self._create_row_object(row, allow_lazy=True)
def len(self, **kwargs) -> int:
def len(self, **kwargs) -> int:
	"""
		Return how many messages match the filters given as `kwargs`.

		See `DB.iterate` for possible values of `kwargs`.
	"""
	# force counting mode, whatever the caller passed for `_count_only`
	counter = self.iterate(**{**kwargs, '_count_only': True})
	return next(counter)

See DB.iterate for the keyword arguments accepted in kwargs.

def by_count(self, count: int) -> ums.utils.types.MessageDbRow | None:
def by_count(self, count:int) -> MessageDbRow|None:
	"""
		Return the message with primary key `count`, or `None` if no such message
		exists or it could not be loaded.
	"""
	try:
		# fetch under lock + own transaction; build the row object afterwards
		with self.db_lock, self.db:
			row = self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
		if row is None:
			return None
		return self._create_row_object(row)
	except (sqlite3.Error, ValidationError, ValueError, TypeError):
		return None