100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
# Agenten Plattform
#
# (c) 2024 Magnus Bender
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
# Universitaet Hamburg
# https://www.chai.uni-hamburg.de/~bender
#
# source code released under the terms of GNU General Public License Version 3
# https://www.gnu.org/licenses/gpl-3.0.txt
|
|
|
|
import os, importlib
|
|
from typing import List
|
|
|
|
import requests
|
|
from fastapi import BackgroundTasks
|
|
|
|
from ums.agent.agent import BasicAgent, AgentCapability, ExtractAgent, SolveAgent, GatekeeperAgent
|
|
from ums.utils import AgentMessage, AgentResponse, logger
|
|
|
|
class MessageProcessor:
    """Dispatch incoming messages to the configured agents.

    Agent classes are loaded from the ``"<module>:<variable>"`` reference in
    ``AGENTS_LIST`` and grouped by their ``agent_capability()``. Each message is
    handed to the agents of the first processing stage (extract, solve,
    validate) that is required but not yet finished. Agents report back through
    ``_send_message``, which POSTs to the management service.

    ``MANAGEMENT_URL`` and ``AGENTS_LIST`` are read from the environment once,
    when the class is defined.
    """

    # Base URL of the management service; surrounding whitespace and slashes removed.
    MANAGEMENT_URL = os.environ.get('MANAGEMENT_URL', 'http://127.0.0.1:80').strip().strip('/')
    # "<module>:<variable>" naming the list of agent classes to load.
    AGENTS_LIST = os.environ.get('AGENTS_LIST', 'ums.example.example:AGENT_CLASSES').strip()

    # Seconds to wait for the management service before a send is abandoned.
    # Without a timeout, a stalled connection would block the background task forever.
    SEND_TIMEOUT = 60

    def __init__(self, disable_messages: bool = False):
        """Load the configured agent classes and group them by capability.

        Args:
            disable_messages: if True, ``_send_message`` prints outgoing messages
                to stdout instead of POSTing them to the management service.

        Raises:
            ValueError: if ``AGENTS_LIST`` is not of the form ``"<module>:<variable>"``.
            ImportError: if the module named in ``AGENTS_LIST`` cannot be imported.
            AttributeError: if that module has no such variable.
        """
        self.counts = 0  # number of messages received so far
        self.disable_messages = disable_messages

        parts = self.AGENTS_LIST.split(':')
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"AGENTS_LIST must have the form '<module>:<variable>', got {self.AGENTS_LIST!r}"
            )
        module_name, var_name = parts
        agents_module = importlib.import_module(module_name)

        # NOTE: these lists hold agent *classes*, not instances; new_message schedules
        # each class to be called with (message, send_message) in a background task.
        self.agent_classes: List[BasicAgent] = getattr(agents_module, var_name)
        self.extract_agents: List[ExtractAgent] = self._agents_with_capability(AgentCapability.EXTRACT)
        self.solve_agents: List[SolveAgent] = self._agents_with_capability(AgentCapability.SOLVE)
        self.gatekeeper_agents: List[GatekeeperAgent] = self._agents_with_capability(AgentCapability.GATEKEEPER)

    def _agents_with_capability(self, capability: AgentCapability) -> list:
        """Return the loaded agent classes whose ``agent_capability()`` equals *capability*."""
        return [ac for ac in self.agent_classes if ac.agent_capability() == capability]

    def _select_agents(self, message: AgentMessage) -> list:
        """Return the agent classes that should receive *message*.

        Only the first stage that is required but not yet finished is considered,
        checked in the order extract -> solve -> validate. For the extract stage,
        only agents whose ``extract_type()`` occurs among the types of
        ``message.data`` are selected.
        """
        status = message.status
        if status.extract.required and not status.extract.finished:
            if not self.extract_agents:
                return []
            data_types = {d.type for d in message.data}
            return [ac for ac in self.extract_agents if ac.extract_type() in data_types]
        if status.solve.required and not status.solve.finished:
            return list(self.solve_agents)
        if status.validate.required and not status.validate.finished:
            # validation is handled by the gatekeeper agents
            return list(self.gatekeeper_agents)
        return []

    def new_message(self, message: AgentMessage, background_tasks: BackgroundTasks) -> AgentResponse:
        """Queue *message* for every agent responsible for its current stage.

        Each selected agent class is added to *background_tasks* and will be
        called with ``(message, self._send_message)``.

        Returns:
            An ``AgentResponse`` whose ``count`` is this message's sequence number.
            ``error`` is set (with ``error_msg``) when no agent was queued.
        """
        agents = self._select_agents(message)
        for agent_class in agents:
            background_tasks.add_task(agent_class, message, self._send_message)
        enqueued = bool(agents)

        status_msg = "Added to queue" if enqueued else "No agent found to queue message."
        logger.debug(f"{status_msg} ID: {message.id} Count: {self.counts}")

        count = self.counts
        self.counts += 1
        return AgentResponse(
            count=count,
            msg="Added to queue" if enqueued else "",
            error=not enqueued,
            error_msg=None if enqueued else "No agent found to queue message.",
        )

    def _send_message(self, message: AgentMessage) -> bool:
        """Send *message* to the management service.

        POSTs the JSON-serialized message to ``{MANAGEMENT_URL}/message``.

        Returns:
            True if management answered with HTTP 200. False on any other status,
            on a network error or timeout, or when messages are disabled; in the
            disabled case the message is only printed to stdout.
        """
        if self.disable_messages:
            print("\tMessages disabled: Requested to send message to management:")
            print(message.model_dump_json(indent=2))
            return False

        try:
            r = requests.post(
                f"{self.MANAGEMENT_URL}/message",
                data=message.model_dump_json(),
                headers={"accept": "application/json", "content-type": "application/json"},
                timeout=self.SEND_TIMEOUT,
            )
        except requests.RequestException as e:
            # report failure through the bool contract instead of crashing the background task
            logger.warning(f"Error sending message to management! {e!r}")
            return False

        if r.status_code == 200:
            return True

        logger.warning(f"Error sending message to management! {(r.text, r.headers)}")
        return False