# Agenten Plattform
#
# (c) 2024 Magnus Bender
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
# Universitaet Hamburg
# https://www.chai.uni-hamburg.de/~bender
#
# source code released under the terms of GNU Public License Version 3
# https://www.gnu.org/licenses/gpl-3.0.txt

import os
import re
from typing import Iterable, List

import requests
from fastapi import BackgroundTasks

from ums.management.db import DB
from ums.utils import AgentMessage, AgentResponse, logger, RiddleData, RiddleSolution


class MessageProcessor:
    """
    Store agent messages in the database and route them to the next agents.

    A message is sent to the process (extract), solve and gatekeeper
    (validate) agents, depending on which steps its status marks as required
    and not yet finished. When all steps are done but the riddle is not
    solved, the message is sent to the management again for another trial.

    All settings are read from environment variables once, when the class
    is created.
    """

    # number of trials after which an unsolved riddle is not tried again
    SOLUTION_MAX_TRIALS = int(os.environ.get('SOLUTION_MAX_TRIALS', 5))
    # a message whose `contacts` counter exceeds this value is not processed further
    MESSAGE_MAX_CONTACTS = int(os.environ.get('MESSAGE_MAX_CONTACTS', 100))

    # postpone a message until all its data items have an extraction / all
    # solve agents delivered a solution
    REQUIRE_FULL_EXTRACT = os.environ.get('REQUIRE_FULL_EXTRACT', 'false').lower() == 'true'
    REQUIRE_FULL_SOLVE = os.environ.get('REQUIRE_FULL_SOLVE', 'false').lower() == 'true'

    MANAGEMENT_URL = os.environ.get('MANAGEMENT_URL', 'http://127.0.0.1:80').strip().strip('/')

    # timeout (in seconds) for the HTTP requests sending messages
    REQUEST_TIMEOUT = float(os.environ.get('REQUEST_TIMEOUT', 60))

    # Comma separated lists of agent base urls. Empty entries are dropped:
    # `''.split(',')` is `['']`, so an unset variable would otherwise yield
    # one agent with an empty url instead of no agents at all.
    AGENTS_PROCESS = tuple(
        url
        for url in (
            s.strip().strip('/')
            for s in os.environ.get('AGENTS_PROCESS', '').split(',')
        )
        if url
    )
    AGENTS_SOLVE = tuple(
        url
        for url in (
            s.strip().strip('/')
            for s in os.environ.get('AGENTS_SOLVE', '').split(',')
        )
        if url
    )
    AGENTS_GATEKEEPER = tuple(
        url
        for url in (
            s.strip().strip('/')
            for s in os.environ.get('AGENTS_GATEKEEPER', '').split(',')
        )
        if url
    )

    def __init__(self, db: DB):
        """
        Create the processor and warn about agent lists that are empty.

        :param db: database used to store and look up messages
        """
        self.db = db
        self.management_name = self._get_name(self.MANAGEMENT_URL)

        if not self.AGENTS_PROCESS:
            logger.warning("Not Process Agent (AGENTS_PROCESS) found, this may be a problem!")
        if not self.AGENTS_SOLVE:
            logger.warning("Not Solve Agent (AGENTS_SOLVE) found, this may be a problem!")
        if not self.AGENTS_GATEKEEPER:
            logger.warning("Not Gatekeeper Agent (AGENTS_GATEKEEPER) found, this may be a problem!")

    def _get_name(self, url: str) -> str:
        """
        Return the host part of `url` (`http(s)://host[:port]`).

        :return: the host, or "unknown" if `url` does not have this form
        """
        m = re.match(r'^https?://([^:]*)(?::(\d+))?$', url)
        return "unknown" if m is None else m.group(1)

    def new_message(self,
        sender: str, receiver: str, message: AgentMessage,
        background_tasks: BackgroundTasks
    ) -> AgentResponse:
        """
        Add a new message to the database and schedule its processing.

        :param sender: name of the sender of the message
        :param receiver: name of the receiver of the message
        :param message: the message itself
        :param background_tasks: task queue `_process_message` is added to
        :return: response with the database count of the message; on any
            error a response with `count=-1`, `error=True` and the text of
            the exception as `error_msg` (nothing is raised)
        """
        try:
            db_count = self.db.add_message(sender, receiver, message)
            background_tasks.add_task(self._process_message, db_count)

            return AgentResponse(
                count=db_count,
                msg="Added message to queue"
            )
        except Exception as e:
            return AgentResponse(
                count=-1,
                error=True,
                error_msg=str(e)
            )

    def _process_message(self, count: int, ignore_processed: bool = False):
        """
        Do the next step for the message stored with database count `count`.

        The steps are checked in the order extract, solve, validate; the
        message is sent to the agents of the first step that is required but
        not finished. If no step is left, the solution state is stored and
        an unsolved message is tried again (see `_do_again`).

        :param count: database count of the message
        :param ignore_processed: also handle a message already marked as processed
        """
        db_message = self.db.by_count(count)

        if db_message.processed and not ignore_processed:
            # do not process processed messages again
            return

        # now message processed!
        self.db.set_processed(count=count, processed=True)

        message = db_message.message
        status = message.status

        # increment contacts counter
        message.contacts += 1
        if message.contacts > self.MESSAGE_MAX_CONTACTS:
            logger.warning(f"Message reached max number of contacts! {message.id}, {count}")
            return

        # check which step/ state the message requires the management to do
        # -> IF
        if status.extract.required and not status.extract.finished:
            # send to extract agents
            self._send_messages(self.AGENTS_PROCESS, message)
            return

        # combine different extractions in data items
        # will update items in `message.data`
        fully_extracted = self._add_extractions(message.id, message.data)
        if self.REQUIRE_FULL_EXTRACT and not fully_extracted:
            logger.warning(f"Postpone message, wait for full extract of items! {message.id}, {count}")
            return

        # -> EL IF
        if status.solve.required and not status.solve.finished:
            # send to solve agents
            self._send_messages(self.AGENTS_SOLVE, message)
            return

        # combine different solutions
        # will add solutions received before to `message.solution`
        fully_solved = self._add_solutions(message.id, message.solution, status.trial)
        if self.REQUIRE_FULL_SOLVE and not fully_solved:
            logger.warning(f"Postpone message, wait for all solutions of riddle! {message.id}, {count}")
            return

        # -> EL IF
        if status.validate.required and not status.validate.finished:
            # send to gatekeeper agents
            self._send_messages(self.AGENTS_GATEKEEPER, message)
            return

        # -> ELSE
        # all steps "done"

        # validate not required? (then solved will never be set to true, thus set it here)
        if not status.validate.required:
            status.solved = True

        if status.solved:
            # yay, message is solved
            self.db.set_solution(count=count, solution=True)
        else:
            # not solved, but all steps done
            self.db.set_solution(count=count, solution=False)

            # try again
            self._do_again(message)

    def _hash_solution(self, s: RiddleSolution) -> int:
        """Return a hash identifying a solution by its text, explanation and used data."""
        return hash((s.solution, s.explanation, tuple((d.file_plain, d.type) for d in s.used_data)))

    def _add_solutions(self, riddle_id: str, solution: List[RiddleSolution], trial: int) -> bool:
        """
        Add solutions of the same trial found in the database to `solution`.

        `solution` is extended in place; solutions already contained (same
        `_hash_solution`) are not added again. At most 250 database rows of
        the riddle are searched.

        :param riddle_id: id of the riddle to search solutions for
        :param solution: current solutions of the message (modified in place)
        :param trial: only solutions of messages with this trial are used
        :return: True if there is at least one solution per solve agent
        """
        required = len(self.AGENTS_SOLVE)

        # do not do anything, if all solutions available
        if len(solution) >= required:
            return True

        contained = set(self._hash_solution(s) for s in solution)

        # search db for solutions from before
        for row in self.db.iterate(
            id=riddle_id,
            limit=min(self.db.len(id=riddle_id), 250)
        ):
            # make sure to only use solutions from same "trial"
            if row.message.status.trial == trial:
                for s in row.message.solution:
                    h = self._hash_solution(s)
                    if h not in contained:
                        # add the 'new' solution
                        solution.append(s)
                        contained.add(h)

            # all solutions found ?
            if len(solution) >= required:
                break

        return len(solution) >= required

    def _hash_data(self, d: RiddleData) -> int:
        """Return a hash identifying a data item by its file, type and prompt."""
        return hash((d.file_plain, d.type, d.prompt))

    def _add_extractions(self, riddle_id: str, data: List[RiddleData]) -> bool:
        """
        Fill missing extractions of `data` with extractions found in the database.

        The items of `data` are updated in place. Extractions whose file
        starts with "missing:" are not used. At most 250 database rows of the
        riddle are searched.

        :param riddle_id: id of the riddle to search extractions for
        :param data: data items of the message (modified in place)
        :return: True if all items of `data` have an extraction afterwards
        """
        # get all the data items without extraction
        # (hash -> all indices, identical items may occur more than once)
        empty_data = {}
        for i, d in enumerate(data):
            if d.file_extracted is None:
                empty_data.setdefault(self._hash_data(d), []).append(i)

        # do not do anything if fully extracted
        if not empty_data:
            return True

        # search db for extractions already available
        for row in self.db.iterate(
            id=riddle_id,
            limit=min(self.db.len(id=riddle_id), 250)
        ):
            # check for required extraction
            for d in row.message.data:
                # already extracted ?
                # extraction file exists ?
                if d.file_extracted is None or d.file_extracted.startswith("missing:"):
                    continue

                # one of the items, we do not have extractions for ?
                # the same data item ?
                # (pop also removes it from the items we need extracted data for)
                indices = empty_data.pop(self._hash_data(d), None)
                if indices is not None:
                    # copy the reference to the extracted data
                    for i in indices:
                        data[i].file_extracted = d.file_extracted

            # break if all extractions found
            if not empty_data:
                break

        return not empty_data  # fully extracted

    def _do_again(self, message: AgentMessage):
        """
        Start the next trial for an unsolved message.

        Resets the `finished` flag of all required steps, increments the
        trial, moves the current solutions to `riddle.solutions_before` and
        sends the message to the management again. Only logs, if the maximum
        number of trials is reached.
        """
        if message.status.trial < self.SOLUTION_MAX_TRIALS:
            # try again, recycle message

            # require steps again
            if message.status.extract.required:
                message.status.extract.finished = False
            if message.status.solve.required:
                message.status.solve.finished = False
            if message.status.validate.required:
                message.status.validate.finished = False

            # increment trial
            message.status.trial += 1

            # append current solution(s) as old one(s)
            if len(message.solution) > 0:
                message.riddle.solutions_before.extend(
                    message.solution
                )
            # reset current solution
            message.solution = []

            # add the riddle as new to management
            self._send_message(self.MANAGEMENT_URL, message)

        else:
            logger.info(f"Unsolved riddle after max number of trials: {message.id}")

    def _send_messages(self, recipients: Iterable[str], message: AgentMessage) -> bool:
        """
        Send `message` to all `recipients`.

        :param recipients: base urls of the recipients
        :return: True if sending to every recipient succeeded
        """
        ok = True
        for r in recipients:
            # send first, then combine: a failed recipient must not prevent
            # the remaining recipients from being contacted
            ok = self._send_message(r, message) and ok
        return ok

    def _send_message(self, recipient: str, message: AgentMessage) -> bool:
        """
        Store `message` in the database and post it to `<recipient>/message`.

        The database entry is marked as processed only if the recipient
        answers with status code 200.

        :param recipient: base url of the recipient
        :return: True if the recipient answered with status code 200;
            False otherwise, including when the request itself fails
        """
        db_count = self.db.add_message(
            sender=self.management_name,
            recipient=self._get_name(recipient),
            message=message,
            processed=False
        )

        try:
            r = requests.post(
                "{}/message".format(recipient),
                data=message.model_dump_json(),
                headers={"accept" : "application/json", "content-type" : "application/json"},
                timeout=self.REQUEST_TIMEOUT
            )
        except requests.RequestException as e:
            # unreachable or hanging recipient, entry stays unprocessed in db
            logger.warning(f"Error sending message to: {recipient} {e}")
            return False

        if r.status_code == 200:
            self.db.set_processed(db_count, processed=True)
            return True
        else:
            logger.warning(f"Error sending message to: {recipient} {(r.text, r.headers)}")
            return False