ums.management.db

  1# Agenten Plattform
  2#
  3# (c) 2024 Magnus Bender
  4# 	Institute of Humanities-Centered Artificial Intelligence (CHAI)
  5# 	Universitaet Hamburg
  6# 	https://www.chai.uni-hamburg.de/~bender
  7#  
  8# source code released under the terms of GNU Public License Version 3
  9# https://www.gnu.org/licenses/gpl-3.0.txt
 10
 11import os
 12import sqlite3, atexit
 13
 14from datetime import datetime
 15from threading import Lock
 16from typing import Generator
 17
 18from pydantic import validate_call, BaseModel
 19
 20from ums.utils import PERSIST_PATH, AgentMessage
 21
 22class RowObject(BaseModel):
 23	"""
 24		Object representing a database row.
 25	"""
 26
 27	count : int
 28	"""
 29		The count (primary key) of the item.
 30	"""
 31
 32	sender : str
 33	"""
 34		The sender of the message.
 35	""" 
 36
 37	recipient : str
 38	"""
 39		The recipient of the message
 40	"""
 41
 42	time : int
 43	"""
 44		The time (unix timestamp) the message was received/ sent.
 45	"""
 46
 47	message : AgentMessage
 48	"""
 49		The message  received/ sent.
 50	"""
 51
 52	processed : bool
 53	"""
 54		Did the management process the message, i.e., did the tasks necessary for this message (mostly only relevant for received messages).
 55	"""
 56
 57class DB():
 58
 59	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 60	
 61	def __init__(self):
 62		self.db = sqlite3.connect(
 63			os.path.join(PERSIST_PATH, 'messages.db'),
 64			check_same_thread=False
 65		) 
 66		self.db.row_factory = sqlite3.Row
 67
 68		self.dblock = Lock()
 69		atexit.register(lambda db : db.close(), self.db)
 70
 71		self._assure_tables()
 72
 73	def _assure_tables(self):
 74		self.dblock.acquire()
 75		with self.db:
 76			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 77				count INTEGER PRIMARY KEY AUTOINCREMENT,
 78				id TEXT, 
 79				sender TEXT,
 80			 	recipient TEXT,
 81				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 82				json BLOB,
 83				processed BOOL DEFAULT FALSE
 84			)""")
 85		self.dblock.release()
 86
 87	@validate_call
 88	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
 89		self.dblock.acquire()
 90		with self.db:
 91			self.db.execute(
 92				"""INSERT INTO Messages (
 93						id, sender, recipient, json, processed 
 94					) VALUES (
 95						:id, :sender, :recipient, :json, :processed 
 96				)""", {
 97					"id" : message.id,
 98					"sender" : sender,
 99					"recipient" : recipient,
100					"json" : message.model_dump_json(),
101					"processed" : processed
102				})
103			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
104		self.dblock.release()
105
106		return new_count['last']
107
108	@validate_call
109	def set_processed(self, count:int, processed:bool=True) -> bool:
110		self.dblock.acquire()
111		with self.db:
112			try:
113				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
114				return True
115			except:
116				return False
117		self.dblock.release()
118
119	def __iter__(self) -> Generator[RowObject, None, None]:
120		yield from self.iterate()
121
122	@validate_call
123	def iterate(self,
124		id:str|None=None, sender:str|None=None, recipient:str|None=None,
125		processed:bool|None=None,
126		time_after:int|None=None, time_before:int|None=None,
127		limit:int=20, offset:int=0
128	) -> Generator[RowObject, None, None]:
129		
130		where = []
131		params = {
132			"lim": limit,
133			"off": offset
134		}
135
136		for v,n in ((id,'id'), (sender,'sender'), (recipient,'recipient'), (processed,'processed')):
137			if not v is None:
138				where.append('{} = :{}'.format(n,n))
139				params[n] = v
140
141		if time_after:
142			where.append("time > :t_after")
143			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
144
145		if time_before:
146			where.append("time < :t_before")
147			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
148
149		if len(where) > 0:
150			where_clause = "WHERE " + (' AND '.join(where))
151		else:
152			where_clause = ""
153
154		with self.db:
155			for row in self.db.execute(
156				"SELECT * FROM Messages {} LIMIT :lim OFFSET :off".format(where_clause),
157				params
158			):
159				yield self._create_row_object(row)
160
161	def _create_row_object(self, row:sqlite3.Row) -> RowObject:
162		return RowObject(
163			count=row['count'],
164			sender=row['sender'],
165			recipient=row['recipient'],
166			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
167			message=AgentMessage.model_construct(row['json']),
168			processed=row['processed']
169		)
170
171	def by_count(self, count:int) -> RowObject|None:
172		with self.db:
173			try:
174				return self._create_row_object(
175					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
176				)
177			except:
178				return None
179	
class RowObject(pydantic.main.BaseModel):
class RowObject(BaseModel):
	"""
		One stored message (a row of the `Messages` table) together with its metadata.
	"""

	count: int
	"""
		Primary key of the row; identifies the stored message.
	"""

	sender: str
	"""
		Who sent the message.
	"""

	recipient: str
	"""
		Who the message was addressed to.
	"""

	time: int
	"""
		When the message was received/ sent, as a unix timestamp.
	"""

	message: AgentMessage
	"""
		The message itself, as it was received/ sent.
	"""

	processed: bool
	"""
		Whether the management has already handled the message, i.e., carried out the tasks it requires (mostly relevant only for received messages).
	"""

Object representing a database row.

count: int

The count (primary key) of the item.

sender: str

The sender of the message.

recipient: str

The recipient of the message.

time: int

The time (Unix timestamp) at which the message was received/sent.

message: AgentMessage

The message received/sent.

processed: bool

Whether the management has processed the message, i.e., performed the tasks required for it (mostly relevant only for received messages).

Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
class DB:
 58class DB():
 59
 60	_DB_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
 61	
 62	def __init__(self):
 63		self.db = sqlite3.connect(
 64			os.path.join(PERSIST_PATH, 'messages.db'),
 65			check_same_thread=False
 66		) 
 67		self.db.row_factory = sqlite3.Row
 68
 69		self.dblock = Lock()
 70		atexit.register(lambda db : db.close(), self.db)
 71
 72		self._assure_tables()
 73
 74	def _assure_tables(self):
 75		self.dblock.acquire()
 76		with self.db:
 77			self.db.execute("""CREATE TABLE IF NOT EXISTS Messages ( 
 78				count INTEGER PRIMARY KEY AUTOINCREMENT,
 79				id TEXT, 
 80				sender TEXT,
 81			 	recipient TEXT,
 82				time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 83				json BLOB,
 84				processed BOOL DEFAULT FALSE
 85			)""")
 86		self.dblock.release()
 87
 88	@validate_call
 89	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
 90		self.dblock.acquire()
 91		with self.db:
 92			self.db.execute(
 93				"""INSERT INTO Messages (
 94						id, sender, recipient, json, processed 
 95					) VALUES (
 96						:id, :sender, :recipient, :json, :processed 
 97				)""", {
 98					"id" : message.id,
 99					"sender" : sender,
100					"recipient" : recipient,
101					"json" : message.model_dump_json(),
102					"processed" : processed
103				})
104			new_count = self.db.execute("SELECT LAST_INSERT_ROWID() as last").fetchone()
105		self.dblock.release()
106
107		return new_count['last']
108
109	@validate_call
110	def set_processed(self, count:int, processed:bool=True) -> bool:
111		self.dblock.acquire()
112		with self.db:
113			try:
114				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
115				return True
116			except:
117				return False
118		self.dblock.release()
119
120	def __iter__(self) -> Generator[RowObject, None, None]:
121		yield from self.iterate()
122
123	@validate_call
124	def iterate(self,
125		id:str|None=None, sender:str|None=None, recipient:str|None=None,
126		processed:bool|None=None,
127		time_after:int|None=None, time_before:int|None=None,
128		limit:int=20, offset:int=0
129	) -> Generator[RowObject, None, None]:
130		
131		where = []
132		params = {
133			"lim": limit,
134			"off": offset
135		}
136
137		for v,n in ((id,'id'), (sender,'sender'), (recipient,'recipient'), (processed,'processed')):
138			if not v is None:
139				where.append('{} = :{}'.format(n,n))
140				params[n] = v
141
142		if time_after:
143			where.append("time > :t_after")
144			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
145
146		if time_before:
147			where.append("time < :t_before")
148			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
149
150		if len(where) > 0:
151			where_clause = "WHERE " + (' AND '.join(where))
152		else:
153			where_clause = ""
154
155		with self.db:
156			for row in self.db.execute(
157				"SELECT * FROM Messages {} LIMIT :lim OFFSET :off".format(where_clause),
158				params
159			):
160				yield self._create_row_object(row)
161
162	def _create_row_object(self, row:sqlite3.Row) -> RowObject:
163		return RowObject(
164			count=row['count'],
165			sender=row['sender'],
166			recipient=row['recipient'],
167			time=int(datetime.strptime(row['time'], self._DB_TIME_FORMAT).timestamp()),
168			message=AgentMessage.model_construct(row['json']),
169			processed=row['processed']
170		)
171
172	def by_count(self, count:int) -> RowObject|None:
173		with self.db:
174			try:
175				return self._create_row_object(
176					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
177				)
178			except:
179				return None
db
dblock
@validate_call
def add_message( self, sender: str, recipient: str, message: ums.utils.types.AgentMessage, processed: bool = False) -> int:
	@validate_call
	def add_message(self, sender:str, recipient:str, message:AgentMessage, processed:bool=False) -> int:
		"""
			Store a message in the database.

			Args:
				sender: The sender of the message.
				recipient: The recipient of the message.
				message: The message; stored as its JSON serialization.
				processed: Initial value of the `processed` flag.

			Returns:
				The `count` (primary key) of the new row.
		"""
		# `with` releases the lock even if the INSERT raises.
		with self.dblock, self.db:
			cursor = self.db.execute(
				"""INSERT INTO Messages (
						id, sender, recipient, json, processed
					) VALUES (
						:id, :sender, :recipient, :json, :processed
				)""", {
					"id" : message.id,
					"sender" : sender,
					"recipient" : recipient,
					"json" : message.model_dump_json(),
					"processed" : processed
				})
			# `count` is the INTEGER PRIMARY KEY, i.e., the rowid of the inserted row.
			return cursor.lastrowid
@validate_call
def set_processed(self, count: int, processed: bool = True) -> bool:
109	@validate_call
110	def set_processed(self, count:int, processed:bool=True) -> bool:
111		self.dblock.acquire()
112		with self.db:
113			try:
114				self.db.execute("UPDATE Messages SET processed = ? WHERE count = ?", (processed, count))
115				return True
116			except:
117				return False
118		self.dblock.release()
@validate_call
def iterate( self, id: str | None = None, sender: str | None = None, recipient: str | None = None, processed: bool | None = None, time_after: int | None = None, time_before: int | None = None, limit: int = 20, offset: int = 0) -> Generator[RowObject, NoneType, NoneType]:
123	@validate_call
124	def iterate(self,
125		id:str|None=None, sender:str|None=None, recipient:str|None=None,
126		processed:bool|None=None,
127		time_after:int|None=None, time_before:int|None=None,
128		limit:int=20, offset:int=0
129	) -> Generator[RowObject, None, None]:
130		
131		where = []
132		params = {
133			"lim": limit,
134			"off": offset
135		}
136
137		for v,n in ((id,'id'), (sender,'sender'), (recipient,'recipient'), (processed,'processed')):
138			if not v is None:
139				where.append('{} = :{}'.format(n,n))
140				params[n] = v
141
142		if time_after:
143			where.append("time > :t_after")
144			params['t_after'] = datetime.fromtimestamp(time_after).strftime(self._DB_TIME_FORMAT)
145
146		if time_before:
147			where.append("time < :t_before")
148			params['t_before'] = datetime.fromtimestamp(time_before).strftime(self._DB_TIME_FORMAT)
149
150		if len(where) > 0:
151			where_clause = "WHERE " + (' AND '.join(where))
152		else:
153			where_clause = ""
154
155		with self.db:
156			for row in self.db.execute(
157				"SELECT * FROM Messages {} LIMIT :lim OFFSET :off".format(where_clause),
158				params
159			):
160				yield self._create_row_object(row)
def by_count(self, count: int) -> RowObject | None:
172	def by_count(self, count:int) -> RowObject|None:
173		with self.db:
174			try:
175				return self._create_row_object(
176					self.db.execute("SELECT * FROM Messages WHERE count = ?", (count,)).fetchone()
177				)
178			except:
179				return None