add: agent template/prompt/tools
This commit is contained in:
6
app.py
6
app.py
@@ -2,11 +2,11 @@ import threading
|
||||
|
||||
from mcp_service import start_mcp
|
||||
from api_service import start_api
|
||||
|
||||
|
||||
|
||||
from utils.db_connections import init_db
|
||||
|
||||
if __name__ == '__main__':
|
||||
init_db()
|
||||
|
||||
t_mcp = threading.Thread(target=start_mcp, daemon=True)
|
||||
t_api = threading.Thread(target=start_api, daemon=True)
|
||||
t_mcp.start()
|
||||
|
||||
118
l1.py
Normal file
118
l1.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import asyncio
|
||||
from langchain.agents import create_react_agent, AgentExecutor
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_core.tools import tool
|
||||
from langchain.memory import MongoDBChatMessageHistory, ConversationBufferWindowMemory
|
||||
|
||||
from agents.agent_templates import AgentTemplate
|
||||
from agents.output_parser.MCPReactParser import MCPReactParser
|
||||
|
||||
|
||||
@tool
|
||||
def get_daily_string(salt: str):
|
||||
"""
|
||||
:param salt: a random string
|
||||
:return:get today's string
|
||||
"""
|
||||
return "assss"
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
temp1 = AgentTemplate()
|
||||
temp1.set_model("gpt-4.1-nano", "openai", "")
|
||||
temp1.set_mcp_server("terminal_tools", {
|
||||
'transport': 'sse',
|
||||
'headers': {
|
||||
"X-Api-Key": "bd303c2f2a515b4ce70c1f07ee85530c2"
|
||||
},
|
||||
'url': 'http://localhost:5050/sse'
|
||||
})
|
||||
|
||||
|
||||
mcp_client = MultiServerMCPClient({
|
||||
'terminal_tool':{
|
||||
'transport': 'sse',
|
||||
'headers': {
|
||||
"X-Api-Key": "bd303c2f2a515b4ce70c1f07ee85530c2"
|
||||
},
|
||||
'url': 'http://localhost:5050/sse'
|
||||
}
|
||||
})
|
||||
|
||||
tools = await mcp_client.get_tools()
|
||||
xtools = []
|
||||
for t in tools:
|
||||
if t.name == 'CreateTerminalSession':
|
||||
xtools.append(t)
|
||||
xtools.append(get_daily_string)
|
||||
|
||||
print(xtools)
|
||||
prompt = PromptTemplate.from_template("""
|
||||
|
||||
{system}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
Use the following format:
|
||||
|
||||
Question: the question you must answer
|
||||
If you want to use tools:
|
||||
Thought: always reason what to do
|
||||
Action: the action to take, must be one of [{tool_names}]
|
||||
Action Input: a **JSON object** with named arguments required by the tool
|
||||
Observation: the result of the action
|
||||
If no tool is needed:
|
||||
Thought: what you are thinking
|
||||
... (this Thought/Action/... can repeat N times)
|
||||
Final Answer: the final answer to the original question
|
||||
|
||||
Here is the conversation history:
|
||||
{chat_history}
|
||||
|
||||
User message: {user}
|
||||
{agent_scratchpad}
|
||||
""")
|
||||
op = MCPReactParser()
|
||||
agent = create_react_agent(model_openai, xtools, prompt, output_parser=op)
|
||||
message_history = MongoDBChatMessageHistory(
|
||||
connection_string=MONGO_CONNECTION_STRING,
|
||||
session_id=USER_SESSION_ID,
|
||||
database_name=MONGO_DB_NAME,
|
||||
collection_name=MONGO_COLLECTION_NAME
|
||||
)
|
||||
mongodb_memory = ConversationBufferWindowMemory(
|
||||
chat_memory=message_history,
|
||||
memory_key=MEMORY_KEY,
|
||||
k=10,
|
||||
return_messages=False,
|
||||
)
|
||||
agent_executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=xtools,
|
||||
memory=mongodb_memory,
|
||||
handle_parsing_errors=True,
|
||||
verbose=True,
|
||||
|
||||
)
|
||||
response = await agent_executor.ainvoke({
|
||||
'system': 'you are a helpful assistant',
|
||||
'user': 'hello, please create a terminal session with label \'x\' and show me the session id',
|
||||
})
|
||||
print(response)
|
||||
r2 = await agent_executor.ainvoke({
|
||||
'system': 'you are a helpful assistant',
|
||||
'user': 'hello, please show me today\'s daily string'
|
||||
})
|
||||
print(r2)
|
||||
#print(s1)
|
||||
#print(s2)
|
||||
#print(s3)
|
||||
#print(s4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,3 @@
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
managed_agents = {}
|
||||
|
||||
36
src/agents/agent_tasks/scan_file.py
Normal file
36
src/agents/agent_tasks/scan_file.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from langchain_core.tools import tool
|
||||
|
||||
scan_file_task_msg = """
|
||||
Your task is to scan a file and generate knowledge abstract for it
|
||||
code base and file path should be provided by user msg, if not, raise an error
|
||||
|
||||
The workflow is:
|
||||
1. determine if abstract of file exists in db
|
||||
- if exists, check if the abstract is outdated
|
||||
- if not outdated, call res_tool_scan_result tool with status "success" and finish the workflow
|
||||
2. determine the type of the file(source code, config file, binary lib/executable)
|
||||
- if the file is binary
|
||||
- try find documentation of the binary in readme file
|
||||
- try execute it with arguments like `--help` in a sandbox
|
||||
- if you still can not determine the usage of the binary, call res_tool_scan_result tool with status "unknown" and finish the workflow
|
||||
- if the file is source code/config file
|
||||
- breakdown code by scopes, blocks into segments with start line and end line
|
||||
- describe usage and/or importance of each segment with tool upload_result_to_db
|
||||
- call res_tool_scan_result tool with status "success" and finish the workflow
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@tool
|
||||
def res_tool_scan_result(session_id: str, status: str):
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'status': status
|
||||
}
|
||||
|
||||
|
||||
@tool
|
||||
def upload_result_to_db(codebase: str, file_path: str, segments):
|
||||
pass
|
||||
|
||||
94
src/agents/agent_templates/__init__.py
Normal file
94
src/agents/agent_templates/__init__.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import uuid
|
||||
import os
|
||||
|
||||
from langchain.agents import create_react_agent, AgentExecutor
|
||||
from langchain.memory import ConversationBufferWindowMemory
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
|
||||
|
||||
from agents.output_parser.MCPReactParser import MCPReactParser
|
||||
from agents.prompts import general_prompt
|
||||
from utils.model_connections import PROVIDER_API_KEYS
|
||||
|
||||
|
||||
class AgentTemplate:
|
||||
def __init__(self, **kwargs):
|
||||
self.model = kwargs.get("model", None)
|
||||
self.mcp_servers = kwargs.get("mcp_servers", {})
|
||||
self.builtin_tools = kwargs.get("builtin_tools", [])
|
||||
self.prompt_template = kwargs.get("prompt_template", general_prompt)
|
||||
self.output_parser=kwargs.get("output_parser", MCPReactParser())
|
||||
self.history_window_size = kwargs.get("history_window_size", 20)
|
||||
self.template_params = kwargs.get("template_params", {})
|
||||
|
||||
def set_model(self, model, provider):
|
||||
self.model = init_chat_model(
|
||||
model,
|
||||
model_provider=provider,
|
||||
api_key=PROVIDER_API_KEYS[provider]
|
||||
)
|
||||
|
||||
def set_mcp_servers(self, mcp_servers):
|
||||
self.mcp_servers = mcp_servers
|
||||
|
||||
def set_mcp_server(self, name, mcp_server):
|
||||
self.mcp_servers[name] = mcp_server
|
||||
|
||||
def set_builtin_tools(self, builtin_tools):
|
||||
self.builtin_tools = builtin_tools
|
||||
|
||||
def add_builtin_tool(self, tool):
|
||||
self.builtin_tools.append(tool)
|
||||
|
||||
def set_prompt_template(self, prompt_template):
|
||||
self.prompt_template = prompt_template
|
||||
|
||||
def set_history_window_size(self, history_window_size):
|
||||
self.history_window_size = history_window_size
|
||||
|
||||
def set_template_params(self, template_params):
|
||||
self.template_params = template_params
|
||||
|
||||
def add_template_param(self, name, template_param):
|
||||
self.template_params[name] = template_param
|
||||
async def async_get_instance(self, session_id = str(uuid.uuid4())):
|
||||
mcp_client = MultiServerMCPClient(self.mcp_servers)
|
||||
tools = await mcp_client.get_tools() + self.builtin_tools
|
||||
prompt_template = self.prompt_template
|
||||
for param in self.template_params.keys():
|
||||
prompt_template = prompt_template.replace(f'{{{param}}}', self.template_params[param])
|
||||
|
||||
prompt = PromptTemplate.from_tools(prompt_template)
|
||||
history = MongoDBChatMessageHistory(
|
||||
connection_string=os.getenv("MONGODB_CONNECTION_STRING", ""),
|
||||
session_id=session_id,
|
||||
database_name="ckb",
|
||||
collection_name="session_history"
|
||||
)
|
||||
|
||||
memory = ConversationBufferWindowMemory(
|
||||
chat_memory=history,
|
||||
memory_key="chat_memory",
|
||||
k=self.history_window_size,
|
||||
return_messages=False,
|
||||
)
|
||||
|
||||
agent = create_react_agent(
|
||||
self.model,
|
||||
tools,
|
||||
prompt,
|
||||
output_parser=self.output_parser
|
||||
)
|
||||
|
||||
res = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
memory=memory,
|
||||
handle_parsing_errors=True,
|
||||
verbose=True
|
||||
)
|
||||
return res
|
||||
73
src/agents/output_parser/MCPReactParser.py
Normal file
73
src/agents/output_parser/MCPReactParser.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import re
|
||||
import json
|
||||
from langchain.agents import AgentOutputParser
|
||||
from langchain_core.outputs import Generation, ChatGeneration
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from typing import ClassVar, Any, Union
|
||||
|
||||
class MCPReactParser(AgentOutputParser):
|
||||
FINAL_ANSWER_ACTION: ClassVar[str] = "Final Answer"
|
||||
|
||||
def parser_result(self, result: list[Generation], *, partial: bool = False) -> Any:
|
||||
if not result or not isinstance(result[0], (Generation, ChatGeneration)):
|
||||
raise ValueError("Expected a single Generation or ChatGeneration")
|
||||
text = result[0].text
|
||||
return self.parse(text)
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
includes_final_answer = self.Final_Answer_ACTION in text
|
||||
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
|
||||
if action_match:
|
||||
action = action_match.group(1).strip()
|
||||
action_input_raw = action_match.group(2).strip()
|
||||
|
||||
try:
|
||||
cleaned_input_str = action_input_raw
|
||||
if cleaned_input_str.startswith("```json"):
|
||||
cleaned_input_str = cleaned_input_str[len("```json"):].strip()
|
||||
elif cleaned_input_str.startswith("```"):
|
||||
cleaned_input_str = cleaned_input_str[:-len("```")].strip()
|
||||
if cleaned_input_str.endswith("```"):
|
||||
cleaned_input_str = cleaned_input_str[:-len("```")].strip()
|
||||
|
||||
if not cleaned_input_str:
|
||||
raise json.JSONDecodeError("Action Input is empty after stripping markdown.", cleaned_input_str, 0)
|
||||
|
||||
parsed_tool_input = json.loads(cleaned_input_str)
|
||||
return AgentAction(tool=action, tool_input=parsed_tool_input, log = text)
|
||||
except json.JSONDecodeError as e:
|
||||
observation_msg = (
|
||||
f"The Action Input for '{action}' was not valid JSON. "
|
||||
f"Please provide a correctly formatted JSON object (e.g., double quotes for keys and string values, no trailing commas). "
|
||||
f"Error: {e}. Input received: '{action_input_raw}'"
|
||||
)
|
||||
raise OutputParserException(
|
||||
f"Malformed JSON in Action Input for action '{action}': {action_input_raw}. Error: {e}",
|
||||
observation=observation_msg,
|
||||
llm_output=text,
|
||||
send_to_llm=True,
|
||||
)
|
||||
|
||||
elif includes_final_answer:
|
||||
output_part = text.split(self.FINAL_ANSWER_ACTION, 1)[1].strip()
|
||||
return AgentFinish({"output": output_part}, text)
|
||||
|
||||
else:
|
||||
if not re.search(r"Action\s*\d*\s*:", text, re.DOTALL):
|
||||
observation_msg = "Invalid LLM output. The 'Action:' keyword was not found. Your response must include 'Action: <tool_name>' followed by 'Action Input: <JSON_args>', or 'Final Answer: <text>'."
|
||||
else:
|
||||
observation_msg = "Invalid LLM output. 'Action:' was found, but 'Action Input:' was missing or not correctly formatted on a new line after 'Action:'. Ensure 'Action Input:' is followed by a valid JSON object."
|
||||
raise OutputParserException(
|
||||
f"Could not parse LLM output: `{text}`",
|
||||
observation=observation_msg,
|
||||
llm_output=text,
|
||||
send_to_llm=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "mcp_react_parser"
|
||||
|
||||
43
src/agents/prompts/__init__.py
Normal file
43
src/agents/prompts/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
general_prompt = """
|
||||
You are {agent_role}
|
||||
|
||||
Your current task is
|
||||
|
||||
{task_description}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
Use the following format:
|
||||
|
||||
Question: the question you must answer
|
||||
If you want to use tools:
|
||||
Thought: always reason what to do
|
||||
Action: the action to take, must be one of [{tool_names}]
|
||||
Action Input: a **JSON object** with named arguments required by the tool
|
||||
Observation: the result of the action
|
||||
If no tool is needed:
|
||||
Thought: what you are thinking
|
||||
... (the Thought or Thought/Action/... can repeat multiple times)
|
||||
|
||||
If you have a tool whose name starts with required_response_, you have to call it before final answer
|
||||
Thought: required response
|
||||
Action: required_response_...
|
||||
Action Input: ...
|
||||
Observation: ...
|
||||
|
||||
Final Answer: the final answer to the original question
|
||||
|
||||
Here is the conversation history:
|
||||
{chat_history}
|
||||
|
||||
User message: {user_msg}
|
||||
{agent_scratchpad}
|
||||
"""
|
||||
|
||||
def general_prompt_param_builder(agent_role, task_description):
|
||||
return {
|
||||
'agent_role': agent_role,
|
||||
'task_description': task_description,
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
|
||||
|
||||
general_sys_msg = """
|
||||
You are a {role}
|
||||
|
||||
Your task is {task}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
If you have any tool whose name starts with res_tool_
|
||||
You should call that tool right before the final answer
|
||||
e.g.
|
||||
Thought: calling mandatory res_tool
|
||||
Action: res_tool_general_response
|
||||
Action Input: ...
|
||||
Observation: ...
|
||||
Final Answer: ...
|
||||
|
||||
|
||||
Use the following format:
|
||||
```
|
||||
Question: the question you must answer
|
||||
|
||||
If you want to use tools:
|
||||
Thought: always reason what to do
|
||||
Action: the action to take, must be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
If no tool is needed:
|
||||
Thought: what you are thinking
|
||||
|
||||
(the Thought or Thought/Action/... can repeat multiple times))
|
||||
|
||||
Final Answer: Final response to the user message
|
||||
```
|
||||
The user message is {user_message}
|
||||
|
||||
"""
|
||||
@@ -3,6 +3,7 @@ from langchain_core.tools import tool
|
||||
@tool
|
||||
def res_tool_general_response(session_id: str, response: str):
|
||||
return {
|
||||
'type': 'response',
|
||||
'session_id': session_id,
|
||||
'response': response
|
||||
}
|
||||
9
src/agents/tools/throw_error.py
Normal file
9
src/agents/tools/throw_error.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from langchain_core.tools import tool
|
||||
@tool
|
||||
def throw_error(session_id: str, error_type: str, error_description: str):
|
||||
return {
|
||||
'type': 'error',
|
||||
'session_id': session_id,
|
||||
'error_type': str,
|
||||
'error_description': error_description
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
from mcp_service import mcp
|
||||
|
||||
|
||||
@mcp.prompt()
|
||||
def scan_file():
|
||||
return """
|
||||
|
||||
"""
|
||||
@@ -12,16 +12,14 @@ def init_db():
|
||||
if _client is None:
|
||||
with _lock:
|
||||
if _client is None:
|
||||
uri = os.getenv('MONGODB_URI', 'mongodb://localhost:27017')
|
||||
db_name = 'ckb'
|
||||
uri = os.getenv('MONGO_CONNECTION_STRING', 'mongodb://localhost:27017')
|
||||
db_name = os.getenv('CKB_DB_NAME', 'ckb')
|
||||
max_pool = 100
|
||||
|
||||
_client = MongoClient(uri, maxPoolSize=max_pool)
|
||||
|
||||
if db_name not in _client.list_database_names():
|
||||
tmp = _client[db_name].create_collection('_init')
|
||||
_client[db_name].drop_collection('_init')
|
||||
|
||||
_client[db_name].create_collection('session_history')
|
||||
_db = _client[db_name]
|
||||
return _db
|
||||
|
||||
|
||||
Reference in New Issue
Block a user