ums.management.db

  1# Agenten Plattform
  2#
  3# (c) 2024 Magnus Bender
  4# 	Institute of Humanities-Centered Artificial Intelligence (CHAI)
  5# 	Universitaet Hamburg
  6# 	https://www.chai.uni-hamburg.de/~bender
  7#  
  8# source code released under the terms of GNU Public License Version 3
  9# https://www.gnu.org/licenses/gpl-3.0.txt
 10
 11import os
 12import sqlite3, atexit
 13
 14from datetime import datetime
 15from threading import Lock
 16from typing import Generator
 17
 18from pydantic import validate_call
 19
 20from ums.utils import PERSIST_PATH, AgentMessage, MessageDbRow
 21
 22class DB():
 23
 24	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 25	
 26	def __init__(self):
 27		self.db = sqlite3.connect(
 28			os.path.join(PERSIST_PATH, 'messages.db'),
 29			check_same_thread=False
 30		)
 31		self.db.row_factory = sqlite3.Row
 32		atexit.register(lambda db : db.close(), self.db)
 33
 34		self.db_lock = Lock()
 35
 36		self._assure_tables()
 37
 38	def _assure_tables(self):
 39		self.db_lock.acquire()
 40		with self.db:
 41			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 42				count INTEGER PRIMARY KEY AUTOINCREMENT,
 43				id TEXT, 
 44				sender TEXT,
 45			 	recipient TEXT,
 46				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 47				json BLOB,
 48				processed BOOL DEFAULT FALSE
 49			)""")
 50		self.db_lock.release()
 51
 52	@validate_call
 53	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
 54		self.db_lock.acquire()
 55		with self.db:
 56			self.db.execute(
 57				"""INSERT INTO Messages (
 58						id, sender, recipient, json, processed 
 59					) VALUES (
 60						:id, :sender, :recipient, :json, :processed 
 61				)""", {
 62					"id" : message.id,
 63					"sender" : sender,
 64					"recipient" : recipient,
 65					"json" : message.model_dump_json(),
 66					"processed" : processed
 67				})
 68			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
 69		self.db_lock.release()
 70
 71		return new_count['last']
 72
 73	@validate_call
 74	def set_processed(self, count:int, processed:bool=True) -> bool:
 75		self.db_lock.acquire()
 76		with self.db:
 77			try:
 78				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
 79				return True
 80			except:
 81				return False
 82			finally:
 83				self.db_lock.release()
 84
 85	def __iter__(self) -> Generator[MessageDbRow, None, None]:
 86		yield from self.iterate()
 87
 88	@validate_call
 89	def iterate(self,
 90		id:str|None=None, sender:str|None=None, recipient:str|None=None,
 91		processed:bool|None=None,
 92		time_after:int|None=None, time_before:int|None=None,
 93		limit:int=20, offset:int=0, _count_only:bool=False
 94	) -> Generator[MessageDbRow|int, None, None]:
 95		
 96		where = []
 97		params = {
 98			"lim": limit,
 99			"off": offset
100		}
101
102		for v,n in ((id,'id'), (sender,'sender'), (recipient,'recipient'), (processed,'processed')):
103			if not v is None:
104				where.append('{} = :{}'.format(n,n))
105				params[n] = v
106
107		if time_after:
108			where.append("time > :t_after")
109			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
110
111		if time_before:
112			where.append("time < :t_before")
113			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
114
115		if len(where) > 0:
116			where_clause = "WHERE " + (' AND '.join(where))
117		else:
118			where_clause = ""
119
120		with self.db:
121			if _count_only:
122				count = self.db.execute(
123					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
124					params
125				).fetchone()
126
127				yield count['count']
128			else:
129				for row in self.db.execute(
130					"SELECT * FROM Messages {} ORDER BY time DESC LIMIT :lim OFFSET :off".format(where_clause),
131					params
132				):
133					yield self._create_row_object(row)
134
135	def __len__(self) -> int:
136		return self.len()
137
138	def len(self, **kwargs) -> int:
139		"""
140			See `DB.iterate` for possible values of `kwargs`.
141		"""
142		kwargs['_count_only'] = True
143		return next(self.iterate(**kwargs))
144
145	def _create_row_object(self, row:sqlite3.Row) -> MessageDbRow:
146		return MessageDbRow(
147			count=row['count'],
148			sender=row['sender'],
149			recipient=row['recipient'],
150			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
151			message=AgentMessage.model_validate_json(row['json']),
152			processed=row['processed']
153		)
154
155	def by_count(self, count:int) -> MessageDbRow|None:
156		with self.db:
157			try:
158				return self._create_row_object(
159					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
160				)
161			except:
162				return None
163	
class DB:
 23class DB():
 24
 25	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 26	
 27	def __init__(self):
 28		self.db = sqlite3.connect(
 29			os.path.join(PERSIST_PATH, 'messages.db'),
 30			check_same_thread=False
 31		)
 32		self.db.row_factory = sqlite3.Row
 33		atexit.register(lambda db : db.close(), self.db)
 34
 35		self.db_lock = Lock()
 36
 37		self._assure_tables()
 38
 39	def _assure_tables(self):
 40		self.db_lock.acquire()
 41		with self.db:
 42			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 43				count INTEGER PRIMARY KEY AUTOINCREMENT,
 44				id TEXT, 
 45				sender TEXT,
 46			 	recipient TEXT,
 47				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 48				json BLOB,
 49				processed BOOL DEFAULT FALSE
 50			)""")
 51		self.db_lock.release()
 52
 53	@validate_call
 54	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
 55		self.db_lock.acquire()
 56		with self.db:
 57			self.db.execute(
 58				"""INSERT INTO Messages (
 59						id, sender, recipient, json, processed 
 60					) VALUES (
 61						:id, :sender, :recipient, :json, :processed 
 62				)""", {
 63					"id" : message.id,
 64					"sender" : sender,
 65					"recipient" : recipient,
 66					"json" : message.model_dump_json(),
 67					"processed" : processed
 68				})
 69			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
 70		self.db_lock.release()
 71
 72		return new_count['last']
 73
 74	@validate_call
 75	def set_processed(self, count:int, processed:bool=True) -> bool:
 76		self.db_lock.acquire()
 77		with self.db:
 78			try:
 79				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
 80				return True
 81			except:
 82				return False
 83			finally:
 84				self.db_lock.release()
 85
 86	def __iter__(self) -> Generator[MessageDbRow, None, None]:
 87		yield from self.iterate()
 88
 89	@validate_call
 90	def iterate(self,
 91		id:str|None=None, sender:str|None=None, recipient:str|None=None,
 92		processed:bool|None=None,
 93		time_after:int|None=None, time_before:int|None=None,
 94		limit:int=20, offset:int=0, _count_only:bool=False
 95	) -> Generator[MessageDbRow|int, None, None]:
 96		
 97		where = []
 98		params = {
 99			"lim": limit,
100			"off": offset
101		}
102
103		for v,n in ((id,'id'), (sender,'sender'), (recipient,'recipient'), (processed,'processed')):
104			if not v is None:
105				where.append('{} = :{}'.format(n,n))
106				params[n] = v
107
108		if time_after:
109			where.append("time > :t_after")
110			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
111
112		if time_before:
113			where.append("time < :t_before")
114			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
115
116		if len(where) > 0:
117			where_clause = "WHERE " + (' AND '.join(where))
118		else:
119			where_clause = ""
120
121		with self.db:
122			if _count_only:
123				count = self.db.execute(
124					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
125					params
126				).fetchone()
127
128				yield count['count']
129			else:
130				for row in self.db.execute(
131					"SELECT * FROM Messages {} ORDER BY time DESC LIMIT :lim OFFSET :off".format(where_clause),
132					params
133				):
134					yield self._create_row_object(row)
135
136	def __len__(self) -> int:
137		return self.len()
138
139	def len(self, **kwargs) -> int:
140		"""
141			See `DB.iterate` for possible values of `kwargs`.
142		"""
143		kwargs['_count_only'] = True
144		return next(self.iterate(**kwargs))
145
146	def _create_row_object(self, row:sqlite3.Row) -> MessageDbRow:
147		return MessageDbRow(
148			count=row['count'],
149			sender=row['sender'],
150			recipient=row['recipient'],
151			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
152			message=AgentMessage.model_validate_json(row['json']),
153			processed=row['processed']
154		)
155
156	def by_count(self, count:int) -> MessageDbRow|None:
157		with self.db:
158			try:
159				return self._create_row_object(
160					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
161				)
162			except:
163				return None
db
db_lock
@validate_call
def add_message( self, sender: str, recipient: str, message: ums.utils.types.AgentMessage, processed: bool = False) -> int:
53	@validate_call
54	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
55		self.db_lock.acquire()
56		with self.db:
57			self.db.execute(
58				"""INSERT INTO Messages (
59						id, sender, recipient, json, processed 
60					) VALUES (
61						:id, :sender, :recipient, :json, :processed 
62				)""", {
63					"id" : message.id,
64					"sender" : sender,
65					"recipient" : recipient,
66					"json" : message.model_dump_json(),
67					"processed" : processed
68				})
69			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
70		self.db_lock.release()
71
72		return new_count['last']
@validate_call
def set_processed(self, count: int, processed: bool = True) -> bool:
74	@validate_call
75	def set_processed(self, count:int, processed:bool=True) -> bool:
76		self.db_lock.acquire()
77		with self.db:
78			try:
79				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
80				return True
81			except:
82				return False
83			finally:
84				self.db_lock.release()
@validate_call
def iterate( self, id: str | None = None, sender: str | None = None, recipient: str | None = None, processed: bool | None = None, time_after: int | None = None, time_before: int | None = None, limit: int = 20, offset: int = 0, _count_only: bool = False) -> Generator[ums.utils.types.MessageDbRow | int, NoneType, NoneType]:
	@validate_call
	def iterate(self,
		id:str|None=None, sender:str|None=None, recipient:str|None=None,
		processed:bool|None=None,
		time_after:int|None=None, time_before:int|None=None,
		limit:int=20, offset:int=0, _count_only:bool=False
	) -> Generator[MessageDbRow|int, None, None]:
		"""
			Yield the stored messages, newest `time` first.

			`id`, `sender`, `recipient` and `processed` filter by equality when they
			are not `None`. `time_after` / `time_before` are timestamps in seconds
			and filter strictly (`>` / `<`); a value of `0` is treated like `None`.
			`limit` and `offset` select the page of rows.

			With `_count_only` a single integer is yielded instead: the number of
			rows matching the filters (`limit` and `offset` are not applied).
		"""
		where = []
		params = {
			"lim": limit,
			"off": offset
		}

		# column names are fixed here, only the values are user supplied (and bound as parameters)
		for value, column in ((id,'id'), (sender,'sender'), (recipient,'recipient'), (processed,'processed')):
			if value is not None:
				where.append('{} = :{}'.format(column, column))
				params[column] = value

		# NOTE(review): SQLite stores CURRENT_TIMESTAMP in UTC, `datetime.fromtimestamp`
		# formats in local time; the filter only matches real Unix timestamps when
		# the process runs in UTC -- confirm.
		if time_after:
			where.append("time > :t_after")
			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)

		if time_before:
			where.append("time < :t_before")
			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)

		if len(where) > 0:
			where_clause = "WHERE " + (' AND '.join(where))
		else:
			where_clause = ""

		if _count_only:
			with self.db_lock:
				count = self.db.execute(
					"SELECT COUNT(*) as count FROM Messages {}".format(where_clause),
					params
				).fetchone()

			yield count['count']
			return

		# Fetch the page while holding the lock and yield afterwards: no lock or
		# transaction context is kept open while the consumer works on a row, so
		# the consumer may call e.g. `set_processed` and an abandoned generator
		# can not roll back a write of another thread.
		with self.db_lock:
			rows = self.db.execute(
				"SELECT * FROM Messages {} ORDER BY time DESC LIMIT :lim OFFSET :off".format(where_clause),
				params
			).fetchall()

		for row in rows:
			yield self._create_row_object(row)
def len(self, **kwargs) -> int:
139	def len(self, **kwargs) -> int:
140		"""
141			See `DB.iterate` for possible values of `kwargs`.
142		"""
143		kwargs['_count_only'] = True
144		return next(self.iterate(**kwargs))

See DB.iterate for possible values of kwargs.

def by_count(self, count: int) -> ums.utils.types.MessageDbRow | None:
156	def by_count(self, count:int) -> MessageDbRow|None:
157		with self.db:
158			try:
159				return self._create_row_object(
160					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
161				)
162			except:
163				return None