Agent should work
This commit is contained in:
95
ums/agent/process.py
Normal file
95
ums/agent/process.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Agenten Plattform
|
||||
#
|
||||
# (c) 2024 Magnus Bender
|
||||
# Institute of Humanities-Centered Artificial Intelligence (CHAI)
|
||||
# Universitaet Hamburg
|
||||
# https://www.chai.uni-hamburg.de/~bender
|
||||
#
|
||||
# source code released under the terms of GNU Public License Version 3
|
||||
# https://www.gnu.org/licenses/gpl-3.0.txt
|
||||
|
||||
import os, importlib
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
from ums.agent.agent import BasicAgent, AgentCapability, ExtractAgent, SolveAgent, GatekeeperAgent
|
||||
from ums.utils import AgentMessage, AgentResponse, logger
|
||||
|
||||
class MessageProcessor():
|
||||
|
||||
MANAGEMENT_URL = os.environ.get('MANAGEMENT_URL', 'http://127.0.0.1:80').strip().strip('/')
|
||||
AGENTS_LIST = os.environ.get('AGENTS_LIST', 'ums.example.example:AGENT_CLASSES').strip()
|
||||
|
||||
def __init__(self):
|
||||
self.counts = 0
|
||||
|
||||
module_name, var_name = self.AGENTS_LIST.split(':')
|
||||
agents_module = importlib.import_module(module_name)
|
||||
|
||||
self.agent_classes:List[BasicAgent] = getattr(agents_module, var_name)
|
||||
self.extract_agents:List[ExtractAgent] = list(filter(
|
||||
lambda ac: ac.agent_capability() == AgentCapability.EXTRACT,
|
||||
self.agent_classes
|
||||
))
|
||||
self.solve_agents:List[SolveAgent] = list(filter(
|
||||
lambda ac: ac.agent_capability() == AgentCapability.SOLVE,
|
||||
self.agent_classes
|
||||
))
|
||||
self.gatekeeper_agents:List[GatekeeperAgent] = list(filter(
|
||||
lambda ac: ac.agent_capability() == AgentCapability.GATEKEEPER,
|
||||
self.agent_classes
|
||||
))
|
||||
|
||||
def new_message(self, message:AgentMessage, background_tasks: BackgroundTasks) -> AgentResponse:
|
||||
enqueued = False
|
||||
|
||||
if message.status.extract.required and not message.status.extract.finished:
|
||||
# send to extract agents
|
||||
if len(self.extract_agents) > 0:
|
||||
data_types = set( d.type for d in message.data )
|
||||
for ac in self.extract_agents:
|
||||
if ac.extract_type() in data_types:
|
||||
background_tasks.add_task(ac, message, self._send_message)
|
||||
enqueued = True
|
||||
|
||||
elif message.status.solve.required and not message.status.solve.finished:
|
||||
# send to solve agents
|
||||
if len(self.solve_agents) > 0:
|
||||
for sa in self.solve_agents:
|
||||
background_tasks.add_task(sa, message, self._send_message)
|
||||
enqueued = True
|
||||
|
||||
elif message.status.validate.required and not message.status.validate.finished:
|
||||
# send to solve agents
|
||||
if len(self.gatekeeper_agents) > 0:
|
||||
for ga in self.gatekeeper_agents:
|
||||
background_tasks.add_task(ga, message, self._send_message)
|
||||
enqueued = True
|
||||
|
||||
logger.debug(
|
||||
("Added to queue" if enqueued else "No agent found to queue message.") +
|
||||
f"ID: {message.id} Count: {self.counts}"
|
||||
)
|
||||
|
||||
self.counts += 1
|
||||
return AgentResponse(
|
||||
count=self.counts-1,
|
||||
msg="Added to queue" if enqueued else "",
|
||||
error=not enqueued,
|
||||
error_msg=None if enqueued else "No agent found to queue message."
|
||||
)
|
||||
|
||||
def _send_message(self, message:AgentMessage) -> bool:
|
||||
r = requests.post(
|
||||
"{}/message".format(self.MANAGEMENT_URL),
|
||||
data=message.model_dump_json(),
|
||||
headers={"accept" : "application/json", "content-type" : "application/json"}
|
||||
)
|
||||
|
||||
if r.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Error sending message to management! {(r.text, r.headers)}")
|
||||
return False
|
Reference in New Issue
Block a user