add: operations
This commit is contained in:
3
app.py
3
app.py
@@ -3,10 +3,11 @@ import threading
|
||||
from mcp_service import start_mcp
|
||||
from api_service import start_api
|
||||
from utils.db_connections import init_db
|
||||
from utils.ssh_connections import init_workspace
|
||||
|
||||
if __name__ == '__main__':
|
||||
init_db()
|
||||
|
||||
init_workspace()
|
||||
t_mcp = threading.Thread(target=start_mcp, daemon=True)
|
||||
t_api = threading.Thread(target=start_api, daemon=True)
|
||||
t_mcp.start()
|
||||
|
||||
107
l1.py
107
l1.py
@@ -1,13 +1,11 @@
|
||||
import asyncio
|
||||
from langchain.agents import create_react_agent, AgentExecutor
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_core.tools import tool
|
||||
from langchain.memory import MongoDBChatMessageHistory, ConversationBufferWindowMemory
|
||||
|
||||
from agents.agent_tasks.scan_file import tools
|
||||
from agents.agent_templates import AgentTemplate
|
||||
from agents.output_parser.MCPReactParser import MCPReactParser
|
||||
from utils.db_connections import init_db
|
||||
from agents.prompts import general_prompt_param_builder
|
||||
from utils.ssh_connections.ssh_operations import is_file_exist
|
||||
|
||||
|
||||
@tool
|
||||
@@ -21,8 +19,9 @@ def get_daily_string(salt: str):
|
||||
|
||||
async def main():
|
||||
|
||||
init_db()
|
||||
temp1 = AgentTemplate()
|
||||
temp1.set_model("gpt-4.1-nano", "openai", "")
|
||||
temp1.set_model("gpt-4.1-nano", "openai")
|
||||
temp1.set_mcp_server("terminal_tools", {
|
||||
'transport': 'sse',
|
||||
'headers': {
|
||||
@@ -30,89 +29,21 @@ async def main():
|
||||
},
|
||||
'url': 'http://localhost:5050/sse'
|
||||
})
|
||||
temp1.set_builtin_tools(tools)
|
||||
temp1.set_template_params(general_prompt_param_builder(
|
||||
agent_role="assistant",
|
||||
task_description="test, call update_code_result_to_db tool with arbitrary valid input"
|
||||
))
|
||||
|
||||
|
||||
mcp_client = MultiServerMCPClient({
|
||||
'terminal_tool':{
|
||||
'transport': 'sse',
|
||||
'headers': {
|
||||
"X-Api-Key": "bd303c2f2a515b4ce70c1f07ee85530c2"
|
||||
},
|
||||
'url': 'http://localhost:5050/sse'
|
||||
}
|
||||
agent = await temp1.async_get_instance()
|
||||
r1 = await agent.ainvoke({
|
||||
'user_msg': 'test'
|
||||
})
|
||||
|
||||
tools = await mcp_client.get_tools()
|
||||
xtools = []
|
||||
for t in tools:
|
||||
if t.name == 'CreateTerminalSession':
|
||||
xtools.append(t)
|
||||
xtools.append(get_daily_string)
|
||||
|
||||
print(xtools)
|
||||
prompt = PromptTemplate.from_template("""
|
||||
|
||||
{system}
|
||||
|
||||
You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
Use the following format:
|
||||
|
||||
Question: the question you must answer
|
||||
If you want to use tools:
|
||||
Thought: always reason what to do
|
||||
Action: the action to take, must be one of [{tool_names}]
|
||||
Action Input: a **JSON object** with named arguments required by the tool
|
||||
Observation: the result of the action
|
||||
If no tool is needed:
|
||||
Thought: what you are thinking
|
||||
... (this Thought/Action/... can repeat N times)
|
||||
Final Answer: the final answer to the original question
|
||||
|
||||
Here is the conversation history:
|
||||
{chat_history}
|
||||
|
||||
User message: {user}
|
||||
{agent_scratchpad}
|
||||
""")
|
||||
op = MCPReactParser()
|
||||
agent = create_react_agent(model_openai, xtools, prompt, output_parser=op)
|
||||
message_history = MongoDBChatMessageHistory(
|
||||
connection_string=MONGO_CONNECTION_STRING,
|
||||
session_id=USER_SESSION_ID,
|
||||
database_name=MONGO_DB_NAME,
|
||||
collection_name=MONGO_COLLECTION_NAME
|
||||
)
|
||||
mongodb_memory = ConversationBufferWindowMemory(
|
||||
chat_memory=message_history,
|
||||
memory_key=MEMORY_KEY,
|
||||
k=10,
|
||||
return_messages=False,
|
||||
)
|
||||
agent_executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=xtools,
|
||||
memory=mongodb_memory,
|
||||
handle_parsing_errors=True,
|
||||
verbose=True,
|
||||
|
||||
)
|
||||
response = await agent_executor.ainvoke({
|
||||
'system': 'you are a helpful assistant',
|
||||
'user': 'hello, please create a terminal session with label \'x\' and show me the session id',
|
||||
})
|
||||
print(response)
|
||||
r2 = await agent_executor.ainvoke({
|
||||
'system': 'you are a helpful assistant',
|
||||
'user': 'hello, please show me today\'s daily string'
|
||||
})
|
||||
print(r2)
|
||||
#print(s1)
|
||||
#print(s2)
|
||||
#print(s3)
|
||||
#print(s4)
|
||||
# for t in tools:
|
||||
# if t.name == 'update_code_result_to_db':
|
||||
# print(t)
|
||||
print(r1)
|
||||
print(is_file_exist("./Guide"))
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -1,12 +1,19 @@
|
||||
from langchain_core.tools import tool
|
||||
from typing import Dict, Optional, List, TypedDict
|
||||
from utils import operations
|
||||
from utils.db_connections import db_operations
|
||||
from utils.ssh_connections.ssh_operations import read_file_content, get_file_md5
|
||||
from db_models import CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool
|
||||
from db_models.embedded_models import CodeSegment
|
||||
|
||||
scan_file_task_msg = """
|
||||
Your task is to scan a file and generate knowledge abstract for it
|
||||
scan_file_task_description = """
|
||||
Your task is to understand a file and generate knowledge abstract for it
|
||||
code base and file path should be provided by user msg, if not, raise an error
|
||||
|
||||
The workflow is:
|
||||
1. determine if abstract of file exists in db
|
||||
- if exists, check if the abstract is outdated
|
||||
- if exists, check if the abstract is outdated
|
||||
- if outdated, update the abstract
|
||||
- if not outdated, call res_tool_scan_result tool with status "success" and finish the workflow
|
||||
2. determine the type of the file(source code, config file, binary lib/executable)
|
||||
- if the file is binary
|
||||
@@ -17,20 +24,290 @@ The workflow is:
|
||||
- breakdown code by scopes, blocks into segments with start line and end line
|
||||
- describe usage and/or importance of each segment with tool upload_result_to_db
|
||||
- call res_tool_scan_result tool with status "success" and finish the workflow
|
||||
|
||||
|
||||
"""
|
||||
|
||||
class CodeSegmentDict(TypedDict):
|
||||
line_start: int
|
||||
line_end: int
|
||||
abstract: str
|
||||
links: List[str]
|
||||
|
||||
@tool
|
||||
def res_tool_scan_result(session_id: str, status: str):
|
||||
async def update_code_result_to_db(
|
||||
codebase: str,
|
||||
path: str,
|
||||
md5: str,
|
||||
abstract: str,
|
||||
segments: List[CodeSegmentDict]
|
||||
) -> bool:
|
||||
"""Update code file analysis results to database
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
md5: MD5 hash of the file
|
||||
abstract: Overall summary of the file
|
||||
segments: List of code segments, each containing:
|
||||
- line_start: Starting line number
|
||||
- line_end: Ending line number
|
||||
- abstract: Summary of this segment
|
||||
- links: List of related file links
|
||||
|
||||
Returns:
|
||||
bool: Whether the update was successful
|
||||
|
||||
Raises:
|
||||
ValueError: When the codebase does not exist
|
||||
"""
|
||||
codebase_obj = await db_operations.get_codebase(codebase)
|
||||
if not codebase_obj:
|
||||
raise ValueError(f"Codebase {codebase} not found")
|
||||
code_segments = [
|
||||
CodeSegment(
|
||||
line_start=seg["line_start"],
|
||||
line_end=seg["line_end"],
|
||||
abstract=seg["abstract"],
|
||||
links=seg["links"]
|
||||
)
|
||||
for seg in segments
|
||||
]
|
||||
|
||||
code_file = CodeFile(
|
||||
codebase=codebase_obj,
|
||||
path=path,
|
||||
md5=md5,
|
||||
abstract=abstract,
|
||||
segments=code_segments,
|
||||
scanned=True
|
||||
)
|
||||
|
||||
return await db_operations.save_model(code_file)
|
||||
|
||||
@tool
|
||||
async def update_config_result_to_db(
|
||||
codebase: str,
|
||||
path: str,
|
||||
md5: str,
|
||||
abstract: str
|
||||
) -> bool:
|
||||
"""Update configuration file analysis results to database
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
md5: MD5 hash of the file
|
||||
abstract: Configuration summary including main config items and usage description
|
||||
|
||||
Returns:
|
||||
bool: Whether the update was successful
|
||||
|
||||
Raises:
|
||||
ValueError: When the codebase does not exist
|
||||
"""
|
||||
codebase_obj = await db_operations.get_codebase(codebase)
|
||||
if not codebase_obj:
|
||||
raise ValueError(f"Codebase {codebase} not found")
|
||||
|
||||
config_file = ConfigFile(
|
||||
codebase=codebase_obj,
|
||||
path=path,
|
||||
md5=md5,
|
||||
abstract=abstract,
|
||||
scanned=True
|
||||
)
|
||||
|
||||
return await db_operations.save_model(config_file)
|
||||
|
||||
@tool
|
||||
async def update_ignore_result_to_db(
|
||||
codebase: str,
|
||||
path: str,
|
||||
md5: str
|
||||
) -> bool:
|
||||
"""Update ignore file information to database
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
md5: MD5 hash of the file
|
||||
|
||||
Returns:
|
||||
bool: Whether the update was successful
|
||||
|
||||
Raises:
|
||||
ValueError: When the codebase does not exist
|
||||
"""
|
||||
codebase_obj = await db_operations.get_codebase(codebase)
|
||||
if not codebase_obj:
|
||||
raise ValueError(f"Codebase {codebase} not found")
|
||||
|
||||
ignore_file = IgnoreFile(
|
||||
codebase=codebase_obj,
|
||||
path=path,
|
||||
md5=md5
|
||||
)
|
||||
|
||||
return await db_operations.save_model(ignore_file)
|
||||
|
||||
@tool
|
||||
async def update_library_result_to_db(
|
||||
codebase: str,
|
||||
path: str,
|
||||
abstract: str
|
||||
) -> bool:
|
||||
"""Update binary library file analysis results to database
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
abstract: Library file summary including:
|
||||
- Main functionality of the library
|
||||
- Dependencies
|
||||
- Usage scenarios
|
||||
|
||||
Returns:
|
||||
bool: Whether the update was successful
|
||||
|
||||
Raises:
|
||||
ValueError: When the codebase does not exist
|
||||
"""
|
||||
codebase_obj = await db_operations.get_codebase(codebase)
|
||||
if not codebase_obj:
|
||||
raise ValueError(f"Codebase {codebase} not found")
|
||||
|
||||
library_file = BinaryLibrary(
|
||||
codebase=codebase_obj,
|
||||
path=path,
|
||||
abstract=abstract
|
||||
)
|
||||
|
||||
return await db_operations.save_model(library_file)
|
||||
|
||||
@tool
|
||||
async def update_tool_result_to_db(
|
||||
codebase: str,
|
||||
path: str,
|
||||
abstract: str
|
||||
) -> bool:
|
||||
"""Update binary tool file analysis results to database
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
abstract: Tool file summary including:
|
||||
- Main functionality of the tool
|
||||
- Command line argument descriptions
|
||||
- Usage examples
|
||||
|
||||
Returns:
|
||||
bool: Whether the update was successful
|
||||
|
||||
Raises:
|
||||
ValueError: When the codebase does not exist
|
||||
"""
|
||||
codebase_obj = await db_operations.get_codebase(codebase)
|
||||
if not codebase_obj:
|
||||
raise ValueError(f"Codebase {codebase} not found")
|
||||
|
||||
tool_file = BinaryTool(
|
||||
codebase=codebase_obj,
|
||||
path=path,
|
||||
abstract=abstract
|
||||
)
|
||||
|
||||
return await db_operations.save_model(tool_file)
|
||||
|
||||
@tool
|
||||
def required_response_scan_result(session_id: str, status: str) -> Dict[str, str]:
|
||||
"""Return scan result response
|
||||
|
||||
Args:
|
||||
session_id: ID of the scan session
|
||||
status: Status of the scan
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Response containing session ID and status
|
||||
"""
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'status': status
|
||||
}
|
||||
|
||||
@tool
|
||||
async def read_file(codebase: str, path: str) -> str:
|
||||
"""Read file content
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
|
||||
Returns:
|
||||
str: File content
|
||||
|
||||
Raises:
|
||||
ValueError: When the codebase does not exist or file reading fails
|
||||
"""
|
||||
root_path = await operations.get_codebase_root(codebase)
|
||||
if not root_path:
|
||||
raise ValueError(f"Codebase {codebase} not found")
|
||||
|
||||
full_path = f"{root_path}/{path}"
|
||||
result = read_file_content(full_path)
|
||||
|
||||
if result["status"] == "failure":
|
||||
raise ValueError(f"Failed to read file: {result['result']}")
|
||||
return result["result"]
|
||||
|
||||
@tool
|
||||
def upload_result_to_db(codebase: str, file_path: str, segments):
|
||||
pass
|
||||
async def is_outdated(codebase: str, path: str) -> bool:
|
||||
"""Check if file summary is outdated
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
|
||||
Returns:
|
||||
bool: True if file does not exist or MD5 does not match
|
||||
"""
|
||||
stored_md5 = await db_operations.get_file_md5(codebase, path)
|
||||
if not stored_md5:
|
||||
return True
|
||||
|
||||
root_path = await operations.get_codebase_root(codebase)
|
||||
if not root_path:
|
||||
raise ValueError(f"Codebase {codebase} not found")
|
||||
|
||||
full_path = f"{root_path}/{path}"
|
||||
result = get_file_md5(full_path)
|
||||
|
||||
if result["status"] == "failure":
|
||||
raise ValueError(f"Failed to get file MD5: {result['result']}")
|
||||
return stored_md5 != result["result"]
|
||||
|
||||
@tool
|
||||
async def get_abstract(codebase: str, path: str) -> Optional[str]:
|
||||
"""Get file summary
|
||||
|
||||
Args:
|
||||
codebase: Name of the codebase
|
||||
path: File path relative to codebase root
|
||||
|
||||
Returns:
|
||||
Optional[str]: File summary, None if it doesn't exist
|
||||
"""
|
||||
return await db_operations.get_file_document(codebase, path)
|
||||
|
||||
|
||||
|
||||
tools = [
|
||||
update_code_result_to_db,
|
||||
update_config_result_to_db,
|
||||
update_ignore_result_to_db,
|
||||
update_library_result_to_db,
|
||||
update_tool_result_to_db,
|
||||
required_response_scan_result,
|
||||
read_file,
|
||||
is_outdated,
|
||||
get_abstract
|
||||
]
|
||||
|
||||
|
||||
@@ -59,12 +59,13 @@ class AgentTemplate:
|
||||
mcp_client = MultiServerMCPClient(self.mcp_servers)
|
||||
tools = await mcp_client.get_tools() + self.builtin_tools
|
||||
prompt_template = self.prompt_template
|
||||
prompt_template = prompt_template.replace('{session_id}', session_id)
|
||||
for param in self.template_params.keys():
|
||||
prompt_template = prompt_template.replace(f'{{{param}}}', self.template_params[param])
|
||||
|
||||
prompt = PromptTemplate.from_tools(prompt_template)
|
||||
prompt = PromptTemplate.from_template(prompt_template)
|
||||
history = MongoDBChatMessageHistory(
|
||||
connection_string=os.getenv("MONGODB_CONNECTION_STRING", ""),
|
||||
connection_string=os.getenv("MONGO_CONNECTION_STRING", None),
|
||||
session_id=session_id,
|
||||
database_name="ckb",
|
||||
collection_name="session_history"
|
||||
@@ -72,7 +73,7 @@ class AgentTemplate:
|
||||
|
||||
memory = ConversationBufferWindowMemory(
|
||||
chat_memory=history,
|
||||
memory_key="chat_memory",
|
||||
memory_key="chat_history",
|
||||
k=self.history_window_size,
|
||||
return_messages=False,
|
||||
)
|
||||
|
||||
11
src/agents/agent_templates/task_agent_templates/__init__.py
Normal file
11
src/agents/agent_templates/task_agent_templates/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from agents.agent_tasks.scan_file import tools, scan_file_task_description
|
||||
from agents.agent_templates import AgentTemplate
|
||||
from agents.prompts import general_prompt_param_builder
|
||||
|
||||
scan_file_agent_template = AgentTemplate()
|
||||
scan_file_agent_template.set_model("gpt-4.1-nano", "openai")
|
||||
scan_file_agent_template.set_builtin_tools(tools)
|
||||
scan_file_agent_template.set_template_params(general_prompt_param_builder(
|
||||
agent_role="assistant",
|
||||
task_description=scan_file_task_description
|
||||
))
|
||||
@@ -16,7 +16,7 @@ class MCPReactParser(AgentOutputParser):
|
||||
return self.parse(text)
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
includes_final_answer = self.Final_Answer_ACTION in text
|
||||
includes_final_answer = self.FINAL_ANSWER_ACTION in text
|
||||
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
general_prompt = """
|
||||
You are {agent_role}
|
||||
|
||||
Your current task is
|
||||
Session Id of current task is {session_id}
|
||||
|
||||
{task_description}
|
||||
|
||||
@@ -29,10 +28,11 @@ You are {agent_role}
|
||||
|
||||
Final Answer: the final answer to the original question
|
||||
|
||||
User message: {user_msg}
|
||||
|
||||
Here is the conversation history:
|
||||
{chat_history}
|
||||
|
||||
User message: {user_msg}
|
||||
{agent_scratchpad}
|
||||
"""
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
def res_tool_general_response(session_id: str, response: str):
|
||||
return {
|
||||
'type': 'response',
|
||||
'session_id': session_id,
|
||||
'response': response
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
from odmantic import Model
|
||||
|
||||
from db_models.embedded_models.Codebase import Codebase
|
||||
|
||||
from odmantic import Model, ObjectId
|
||||
|
||||
class BinaryLibrary(Model):
|
||||
codebase: Codebase
|
||||
codebase_id: ObjectId
|
||||
type: str="library"
|
||||
path: str
|
||||
abstract: str
|
||||
abstract: str
|
||||
md5: str
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from odmantic import Model
|
||||
|
||||
from db_models.embedded_models.Codebase import Codebase
|
||||
from odmantic import Model, ObjectId
|
||||
|
||||
|
||||
class BinaryTool(Model):
|
||||
codebase: Codebase
|
||||
codebase_id: ObjectId
|
||||
type: str="tool"
|
||||
path: str
|
||||
abstract: str
|
||||
md5: str
|
||||
@@ -1,13 +1,12 @@
|
||||
from odmantic import Model
|
||||
from odmantic import Model, ObjectId
|
||||
from typing import List
|
||||
|
||||
from db_models.embedded_models.CodeSegment import CodeSegment
|
||||
from db_models.embedded_models.Codebase import Codebase
|
||||
|
||||
|
||||
class CodeFile(Model):
|
||||
codebase: Codebase
|
||||
type: str
|
||||
codebase_id: ObjectId
|
||||
type: str="code"
|
||||
path: str
|
||||
md5: str
|
||||
abstract: str
|
||||
|
||||
8
src/db_models/Codebase.py
Normal file
8
src/db_models/Codebase.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from odmantic import Model
|
||||
|
||||
|
||||
class Codebase(Model):
|
||||
name: str
|
||||
version: str
|
||||
branch: str
|
||||
repo: str
|
||||
@@ -1,11 +1,9 @@
|
||||
from odmantic import Model
|
||||
|
||||
from db_models.embedded_models.Codebase import Codebase
|
||||
from odmantic import Model, ObjectId
|
||||
|
||||
|
||||
class ConfigFile(Model):
|
||||
codebase: Codebase
|
||||
type: str
|
||||
codebase_id: ObjectId
|
||||
type: str="config"
|
||||
path: str
|
||||
md5: str
|
||||
abstract: str
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from odmantic import Model
|
||||
from odmantic import Model, ObjectId
|
||||
|
||||
from db_models.embedded_models.Codebase import Codebase
|
||||
|
||||
|
||||
class Directory(Model):
|
||||
codebase: Codebase
|
||||
codebase_id: ObjectId
|
||||
path: str
|
||||
md5: str
|
||||
abstract: str
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from odmantic import Model
|
||||
from odmantic import Model, ObjectId
|
||||
from typing import List
|
||||
from db_models.embedded_models.Codebase import Codebase
|
||||
from db_models.Codebase import Codebase
|
||||
|
||||
|
||||
class Hotspot(Model):
|
||||
codebase: Codebase
|
||||
codebase_id: ObjectId
|
||||
topic: str
|
||||
links: List[int]
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from odmantic import Model
|
||||
|
||||
from db_models.embedded_models.Codebase import Codebase
|
||||
from odmantic import Model, ObjectId
|
||||
|
||||
|
||||
class IgnoreFile(Model):
|
||||
codebase: Codebase
|
||||
codebase_id: ObjectId
|
||||
type:str="ignore"
|
||||
path: str
|
||||
md5: str
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
from db_models.BinaryTool import BinaryTool
|
||||
from db_models.BinaryLibrary import BinaryLibrary
|
||||
from db_models.CodeFile import CodeFile
|
||||
from db_models.ConfigFile import ConfigFile
|
||||
from db_models.IgnoreFile import IgnoreFile
|
||||
from db_models.Codebase import Codebase
|
||||
@@ -1,9 +0,0 @@
|
||||
from odmantic import EmbeddedModel
|
||||
|
||||
|
||||
class Codebase(EmbeddedModel):
|
||||
name: str
|
||||
version: str
|
||||
branch: str
|
||||
path: str
|
||||
repo: str
|
||||
@@ -0,0 +1 @@
|
||||
from db_models.embedded_models.CodeSegment import CodeSegment
|
||||
@@ -1,14 +1,18 @@
|
||||
import os
|
||||
from threading import Lock
|
||||
from pymongo import MongoClient
|
||||
from typing import Optional
|
||||
from odmantic import AIOEngine
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
_client = None
|
||||
_db = None
|
||||
_lock = Lock()
|
||||
_engine: Optional[AIOEngine] = None
|
||||
|
||||
|
||||
def init_db():
|
||||
global _client, _db
|
||||
global _client, _db, _engine
|
||||
if _client is None:
|
||||
with _lock:
|
||||
if _client is None:
|
||||
@@ -21,6 +25,11 @@ def init_db():
|
||||
if db_name not in _client.list_database_names():
|
||||
_client[db_name].create_collection('session_history')
|
||||
_db = _client[db_name]
|
||||
|
||||
if _engine is None:
|
||||
client = AsyncIOMotorClient(uri)
|
||||
_engine = AIOEngine(client=client, database=db_name)
|
||||
|
||||
return _db
|
||||
|
||||
|
||||
@@ -34,3 +43,16 @@ def get_client():
|
||||
if _client is None:
|
||||
init_db()
|
||||
return _client
|
||||
|
||||
|
||||
def get_engine() -> AIOEngine:
|
||||
if _engine is None:
|
||||
raise RuntimeError("Database not initialized. Call init_db first.")
|
||||
return _engine
|
||||
|
||||
|
||||
def close_db():
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
_engine.client.close()
|
||||
_engine = None
|
||||
|
||||
124
src/utils/db_connections/db_operations.py
Normal file
124
src/utils/db_connections/db_operations.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Optional, Any, List
|
||||
from db_models import CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool, Codebase
|
||||
from utils.db_connections import get_engine
|
||||
from utils.ssh_connections.ssh_operations import check_branch_exists, get_file_md5 as get_workspace_file_md5
|
||||
|
||||
|
||||
async def get_codebase(name: str) -> Optional[Codebase]:
|
||||
engine = get_engine()
|
||||
branch = 'main'
|
||||
if '@' in name:
|
||||
name, branch = name.split('@')
|
||||
|
||||
code_base = await engine.find_one(
|
||||
Codebase,
|
||||
(Codebase.name == name) & (Codebase.branch == branch)
|
||||
)
|
||||
|
||||
if not code_base:
|
||||
code_base_main = await engine.find_one(
|
||||
Codebase,
|
||||
(Codebase.name == name) & (Codebase.branch == 'main')
|
||||
)
|
||||
|
||||
if not code_base_main:
|
||||
raise Exception(f"Codebase {name} not found")
|
||||
|
||||
branch_check = check_branch_exists(code_base_main.repo, branch)
|
||||
if branch_check.get("status") == "error":
|
||||
raise Exception(branch_check.get("result", "Error checking branch existence"))
|
||||
|
||||
if branch_check.get("status") == "success" and branch_check.get("result", False):
|
||||
new_code_base = Codebase(
|
||||
name=name,
|
||||
version=code_base_main.version,
|
||||
branch=branch,
|
||||
repo=code_base_main.repo
|
||||
)
|
||||
await engine.save(new_code_base)
|
||||
return new_code_base
|
||||
else:
|
||||
raise Exception(f"Branch {branch} not found for codebase {name}")
|
||||
|
||||
return code_base
|
||||
|
||||
|
||||
|
||||
async def get_file_md5(codebase: str, path: str) -> Optional[str]:
|
||||
doc = await get_file_document(codebase, path)
|
||||
if doc and doc.hasattr('md5'):
|
||||
return doc.md5
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_file_document(codebase: str, path: str):
|
||||
engine = get_engine()
|
||||
branch = 'main'
|
||||
if '@' in codebase:
|
||||
codebase, branch = codebase.split('@')
|
||||
|
||||
cb = await get_codebase(f"{codebase}@{branch}")
|
||||
if not cb:
|
||||
return None
|
||||
|
||||
all_files = []
|
||||
for model in [CodeFile, ConfigFile, BinaryLibrary, BinaryTool]:
|
||||
files = await engine.find(
|
||||
model,
|
||||
model.path == path
|
||||
)
|
||||
all_files.extend(files)
|
||||
|
||||
if not all_files:
|
||||
return None
|
||||
|
||||
codebase_files = [file for file in all_files if file.codebase_id == cb.id]
|
||||
|
||||
if codebase_files:
|
||||
return codebase_files[0]
|
||||
|
||||
workspace_path = f"~/workspace/{codebase}@{branch}/{path}"
|
||||
workspace_md5_result = get_workspace_file_md5(workspace_path)
|
||||
|
||||
if workspace_md5_result.get("status") == "success":
|
||||
workspace_md5 = workspace_md5_result.get("result")
|
||||
|
||||
for file in all_files:
|
||||
if hasattr(file, 'md5') and file.md5 == workspace_md5:
|
||||
new_file = type(file)(
|
||||
codebase_id=cb.id,
|
||||
path=file.path,
|
||||
**{k: v for k, v in file.dict().items()
|
||||
if k not in ['id', 'codebase_id', 'path']}
|
||||
)
|
||||
await engine.save(new_file)
|
||||
return new_file
|
||||
|
||||
return None
|
||||
|
||||
async def save_model(model: Any) -> bool:
|
||||
try:
|
||||
engine = get_engine()
|
||||
await engine.save(model)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error saving model: {e}")
|
||||
return False
|
||||
|
||||
async def update_branch_documents(from_codebase_id: str, to_codebase_id: str):
|
||||
engine = get_engine()
|
||||
for model in [CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool]:
|
||||
from_docs = await engine.find(model, model.codebase_id == from_codebase_id)
|
||||
|
||||
for doc in from_docs:
|
||||
to_docs = await engine.find(
|
||||
model,
|
||||
(model.path == doc.path) & (model.codebase_id == to_codebase_id)
|
||||
)
|
||||
|
||||
if to_docs:
|
||||
for to_doc in to_docs:
|
||||
await engine.delete(to_doc)
|
||||
|
||||
doc.codebase_id = to_codebase_id
|
||||
await save_model(doc)
|
||||
73
src/utils/operations.py
Normal file
73
src/utils/operations.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import Optional
|
||||
|
||||
from db_models import Codebase
|
||||
from utils.db_connections.db_operations import get_codebase, save_model
|
||||
from utils.ssh_connections.ssh_operations import (
|
||||
clone_repo, check_branch_exists, create_and_push_branch,
|
||||
checkout_branch, merge_branches, push_changes, pull_changes, is_dir_exist
|
||||
)
|
||||
from utils.db_connections.db_operations import update_branch_documents
|
||||
|
||||
|
||||
async def get_codebase_root(name: str) -> Optional[str]:
|
||||
branch = 'main'
|
||||
if '@' in name:
|
||||
name, branch = name.split('@')
|
||||
exist = is_dir_exist(f'{name}@{branch}')
|
||||
if exist.get('status', 'error') == 'error':
|
||||
raise Exception(exist.get('result', ''))
|
||||
if not exist.get('result', False):
|
||||
try:
|
||||
cb = await get_codebase(f'{name}@{branch}')
|
||||
clone_repo(cb.repo, branch)
|
||||
except Exception as e:
|
||||
raise Exception(f"Codebase {name}@{branch} not found")
|
||||
return f'~/workspace/{name}@{branch}'
|
||||
|
||||
async def create_branch(codebase: str, from_branch: str, new_branch: str):
|
||||
cb = await get_codebase(f'{codebase}@{from_branch}')
|
||||
|
||||
branch_check = check_branch_exists(cb.repo, new_branch)
|
||||
if branch_check.get("status") == "error":
|
||||
raise Exception(branch_check.get("result", "Error checking branch existence"))
|
||||
|
||||
if branch_check.get("status") == "success" and branch_check.get("result", True):
|
||||
raise Exception(f"Branch {new_branch} already exists for codebase {codebase}")
|
||||
|
||||
create_result = create_and_push_branch(codebase, from_branch, new_branch)
|
||||
if create_result.get("status") != "success":
|
||||
raise Exception(f"Failed to create branch {new_branch}: {create_result.get('result', '')}")
|
||||
|
||||
checkout_result = checkout_branch(codebase, from_branch)
|
||||
if checkout_result.get("status") != "success":
|
||||
raise Exception(f"Failed to checkout back to branch {from_branch}: {checkout_result.get('result', '')}")
|
||||
|
||||
new_codebase = Codebase(
|
||||
name=codebase,
|
||||
version=cb.version,
|
||||
branch=new_branch,
|
||||
repo=cb.repo
|
||||
)
|
||||
await save_model(new_codebase)
|
||||
|
||||
async def merge_branch(codebase: str, from_branch: str, to_branch: str):
|
||||
from_cb = await get_codebase(f'{codebase}@{from_branch}')
|
||||
to_cb = await get_codebase(f'{codebase}@{to_branch}')
|
||||
|
||||
if not from_cb or not to_cb:
|
||||
raise Exception(f"One or both branches not found: {from_branch}, {to_branch}")
|
||||
|
||||
merge_result = merge_branches(codebase, from_branch, to_branch)
|
||||
if merge_result.get("status") != "success":
|
||||
raise Exception(f"Failed to merge branch {from_branch} into {to_branch}: {merge_result.get('result', '')}")
|
||||
|
||||
push_result = push_changes(codebase, to_branch)
|
||||
if push_result.get("status") != "success":
|
||||
raise Exception(f"Failed to push merged changes to remote: {push_result.get('result', '')}")
|
||||
|
||||
pull_result = pull_changes(codebase, to_branch)
|
||||
if pull_result.get("status") != "success":
|
||||
raise Exception(f"Failed to pull latest changes from remote: {pull_result.get('result', '')}")
|
||||
|
||||
await update_branch_documents(from_cb.id, to_cb.id)
|
||||
return {"status": "success", "message": f"Successfully merged branch {from_branch} into {to_branch}"}
|
||||
@@ -1,18 +1,18 @@
|
||||
import os
|
||||
import paramiko
|
||||
from threading import Lock
|
||||
from typing import Tuple, Optional, List, Dict, Any
|
||||
import json
|
||||
|
||||
from utils.ssh_connections.ssh_operations import is_dir_exist, make_dir
|
||||
|
||||
|
||||
class SSHConnectionManager:
|
||||
_clients = {}
|
||||
_lock = Lock()
|
||||
|
||||
HOST = os.getenv('SSH_HOST', 'host.docker.internal')
|
||||
USERNAME = os.getenv('SSH_USERNAME')
|
||||
PORT = os.getenv('SSH_PORT', 22)
|
||||
PASSWORD = os.getenv('SSH_PASSWORD')
|
||||
HOST = os.getenv('CKB_SSH_HOST', 'host.docker.internal')
|
||||
USERNAME = os.getenv('CKB_SSH_USERNAME')
|
||||
PORT = os.getenv('CKB_SSH_PORT', 22)
|
||||
PASSWORD = os.getenv('CKB_SSH_PASSWORD', None)
|
||||
|
||||
@classmethod
|
||||
def get_client(cls, timeout=10):
|
||||
@@ -34,79 +34,8 @@ class SSHConnectionManager:
|
||||
return cls._clients[key]
|
||||
|
||||
|
||||
def execute_command(command: str, timeout: int = 30) -> Tuple[int, str, str]:
|
||||
client = SSHConnectionManager.get_client(timeout=timeout)
|
||||
stdin, stdout, stderr = client.exec_command(command, timeout=timeout)
|
||||
exit_code = stdout.channel.recv_exit_status()
|
||||
return exit_code, stdout.read().decode('utf-8'), stderr.read().decode('utf-8')
|
||||
|
||||
|
||||
def list_directory(path: str, include_ignore: bool = True) -> Dict[str, Any]:
|
||||
try:
|
||||
client = SSHConnectionManager.get_client()
|
||||
sftp = client.open_sftp()
|
||||
files = sftp.listdir_attr(path)
|
||||
result = []
|
||||
for file in files:
|
||||
if not include_ignore and file.filename.startswith('.'):
|
||||
continue
|
||||
result.append({
|
||||
'name': file.filename,
|
||||
'size': file.st_size,
|
||||
'mode': file.st_mode,
|
||||
'mtime': file.st_mtime,
|
||||
'is_dir': file.st_mode & 0o40000 != 0
|
||||
})
|
||||
sftp.close()
|
||||
return {"status": "success", "result": result}
|
||||
except Exception as e:
|
||||
return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def read_file_content(path: str) -> Dict[str, Any]:
    """Read a remote file over SFTP and return its content decoded as UTF-8.

    Returns ``{"status": "success", "result": <text>}`` or
    ``{"status": "failure", "result": <error text>}`` on any error.
    """
    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            with sftp.open(path, 'r') as f:
                content = f.read().decode('utf-8')
        finally:
            # Fix: the original leaked the SFTP channel if open()/read() raised.
            sftp.close()
        return {"status": "success", "result": content}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def write_file_content(path: str, content: str) -> Dict[str, Any]:
    """Write *content* to a remote file over SFTP, truncating any existing file.

    Returns ``{"status": "success", "result": None}`` or
    ``{"status": "failure", "result": <error text>}`` on any error.
    """
    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            with sftp.open(path, 'w') as f:
                f.write(content)
        finally:
            # Fix: the original leaked the SFTP channel if open()/write() raised.
            sftp.close()
        return {"status": "success", "result": None}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def get_file_md5(path: str) -> Dict[str, Any]:
    """Compute the MD5 digest of a remote file via ``md5sum``."""
    # NOTE(review): *path* is interpolated unquoted into the shell command --
    # confirm callers only pass trusted paths (quoting would break '~' expansion).
    try:
        code, out, err = execute_command(f"md5sum {path}")
        if code != 0:
            return {"status": "failure", "result": err}
        # md5sum prints "<digest>  <path>"; the digest is the first token.
        return {"status": "success", "result": out.split()[0]}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def execute_in_sandbox(command: str, timeout: int = 30) -> Dict[str, Any]:
    """Run *command* in a throwaway Alpine container (no network, 512 MB, 1 CPU).

    Returns ``{"status": "success", "result": <stdout>}`` on exit code 0,
    otherwise ``{"status": "failure", "result": <stderr or error text>}``.
    """
    import shlex

    try:
        # Fix: the original wrapped the command in literal single quotes, so any
        # embedded ' broke the command (and allowed escaping the sandbox shell).
        # shlex.quote produces a safely quoted equivalent.
        sandbox_cmd = (
            "docker run --rm --network none --memory=512m --cpus=1 "
            f"alpine sh -c {shlex.quote(command)}"
        )
        exit_code, stdout, stderr = execute_command(sandbox_cmd, timeout)
        if exit_code == 0:
            return {"status": "success", "result": stdout}
        return {"status": "failure", "result": stderr}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
def init_workspace():
    """Ensure the remote ``~/workspace`` directory exists (mkdir -p; idempotent)."""
    make_dir('~/workspace')
|
||||
|
||||
|
||||
|
||||
|
||||
185
src/utils/ssh_connections/ssh_operations.py
Normal file
185
src/utils/ssh_connections/ssh_operations.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from utils.ssh_connections import SSHConnectionManager
|
||||
from typing import Tuple, Dict, Any
|
||||
|
||||
|
||||
def execute_command(command: str, timeout: int = 30) -> Tuple[int, str, str]:
    """Execute *command* remotely; return ``(exit_code, stdout_text, stderr_text)``."""
    client = SSHConnectionManager.get_client(timeout=timeout)
    _, stdout, stderr = client.exec_command(command, timeout=timeout)
    code = stdout.channel.recv_exit_status()  # waits for remote completion
    out_text = stdout.read().decode('utf-8')
    err_text = stderr.read().decode('utf-8')
    return code, out_text, err_text
|
||||
|
||||
|
||||
def _test(exit_code: int):
|
||||
if exit_code == 0:
|
||||
return {"status": "success", "result": True}
|
||||
return {"status": "failure", "result": False}
|
||||
|
||||
def make_dir(path: str) -> Dict[str, Any]:
    """Create *path* (and any missing parents) on the remote host; idempotent."""
    # NOTE(review): *path* is interpolated unquoted into the shell -- confirm
    # callers only pass trusted paths (quoting would break '~' expansion).
    try:
        result = execute_command(f"mkdir -p {path}")
        return _test(result[0])
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
|
||||
def is_file_exist(path: str) -> Dict[str, Any]:
    """Check whether *path* is a regular file on the remote host (``test -f``)."""
    try:
        code, _out, _err = execute_command(f"test -f {path}")
        return _test(code)
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
|
||||
def is_dir_exist(path: str) -> Dict[str, Any]:
    """Check whether *path* is a directory on the remote host (``test -d``)."""
    try:
        code, _out, _err = execute_command(f"test -d {path}")
        return _test(code)
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
|
||||
def list_directory(path: str, include_ignore: bool = True) -> Dict[str, Any]:
    """List the entries of a remote directory over SFTP.

    Args:
        path: Remote directory to list.
        include_ignore: When False, dotfile entries (names starting with '.')
            are skipped.

    Returns:
        ``{"status": "success", "result": [entry, ...]}`` where each entry
        carries name/size/mode/mtime/is_dir, or
        ``{"status": "failure", "result": <error text>}`` on any error.
    """
    import stat

    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            entries = []
            for attr in sftp.listdir_attr(path):
                if not include_ignore and attr.filename.startswith('.'):
                    continue
                entries.append({
                    'name': attr.filename,
                    'size': attr.st_size,
                    'mode': attr.st_mode,
                    'mtime': attr.st_mtime,
                    # Fix: `st_mode & 0o40000 != 0` also matched sockets
                    # (0o140000 contains the 0o040000 bit); S_ISDIR tests the
                    # full file-type field.
                    'is_dir': stat.S_ISDIR(attr.st_mode),
                })
        finally:
            # Fix: the original leaked the SFTP channel if listdir_attr raised.
            sftp.close()
        return {"status": "success", "result": entries}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def read_file_content(path: str) -> Dict[str, Any]:
    """Read a remote file over SFTP and return its content decoded as UTF-8.

    Returns ``{"status": "success", "result": <text>}`` or
    ``{"status": "failure", "result": <error text>}`` on any error.
    """
    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            with sftp.open(path, 'r') as f:
                content = f.read().decode('utf-8')
        finally:
            # Fix: the original leaked the SFTP channel if open()/read() raised.
            sftp.close()
        return {"status": "success", "result": content}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def write_file_content(path: str, content: str) -> Dict[str, Any]:
    """Write *content* to a remote file over SFTP, truncating any existing file.

    Returns ``{"status": "success", "result": None}`` or
    ``{"status": "failure", "result": <error text>}`` on any error.
    """
    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            with sftp.open(path, 'w') as f:
                f.write(content)
        finally:
            # Fix: the original leaked the SFTP channel if open()/write() raised.
            sftp.close()
        return {"status": "success", "result": None}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def get_file_md5(path: str) -> Dict[str, Any]:
    """Return the MD5 digest of a remote file, computed remotely via ``md5sum``."""
    # NOTE(review): *path* is interpolated unquoted into the shell command --
    # confirm callers only pass trusted paths.
    try:
        exit_code, stdout, stderr = execute_command(f"md5sum {path}")
        if exit_code == 0:
            # First whitespace-separated token of md5sum output is the digest.
            digest = stdout.split()[0]
            return {"status": "success", "result": digest}
        return {"status": "failure", "result": stderr}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
|
||||
def execute_in_sandbox(command: str, timeout: int = 30) -> Dict[str, Any]:
    """Run *command* in a throwaway Alpine container (no network, 512 MB, 1 CPU).

    Returns ``{"status": "success", "result": <stdout>}`` on exit code 0,
    otherwise ``{"status": "failure", "result": <stderr or error text>}``.
    """
    import shlex

    try:
        # Fix: the original wrapped the command in literal single quotes, so any
        # embedded ' broke the command (and allowed escaping the sandbox shell).
        # shlex.quote produces a safely quoted equivalent.
        sandbox_cmd = (
            "docker run --rm --network none --memory=512m --cpus=1 "
            f"alpine sh -c {shlex.quote(command)}"
        )
        exit_code, stdout, stderr = execute_command(sandbox_cmd, timeout)
        if exit_code == 0:
            return {"status": "success", "result": stdout}
        return {"status": "failure", "result": stderr}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
|
||||
|
||||
def clone_repo(url: str, branch: str = "main") -> Dict[str, Any]:
    """Clone *branch* of *url* into ``~/workspace/<repo>@<branch>`` on the remote host."""
    try:
        repo_name = url.split("/")[-1].replace(".git", "")
        target = f"{repo_name}@{branch}"
        # Long timeout (2h): large repositories can take a while to clone.
        code, _out, _err = execute_command(
            f"git clone -b {branch} {url} ~/workspace/{target}", timeout=7200
        )
        return _test(code)
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
def remove_codebase(codebase: str) -> Dict[str, Any]:
    """Delete ``~/workspace/<codebase>`` on the remote host.

    A bare repo name (no ``@branch`` suffix) defaults to ``@main``, matching
    the ``<repo>@<branch>`` layout produced by ``clone_repo``.
    """
    # NOTE(review): *codebase* is not validated; values containing '..' or
    # shell metacharacters would escape ~/workspace -- confirm callers sanitize.
    try:
        if '@' not in codebase:  # idiom fix: was `not '@' in codebase`
            codebase = f'{codebase}@main'
        # Fix: the third element returned by execute_command() is stderr,
        # not stdin as the original unpacking named it.
        exit_code, stdout, stderr = execute_command(f"rm -rf ~/workspace/{codebase}")
        return _test(exit_code)
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
def check_branch_exists(repo_url: str, branch: str) -> Dict[str, Any]:
    """Ask the remote git server whether *branch* exists on *repo_url*.

    Returns ``{"status": "success", "result": <bool>}`` on a clean
    ``git ls-remote`` run, otherwise failure/error dicts.
    """
    try:
        # ls-remote prints one line per matching head; count them.
        cmd = f"git ls-remote --heads {repo_url} refs/heads/{branch} | wc -l"
        code, out, err = execute_command(cmd)
        if code != 0:
            return {"status": "failure", "result": err}
        return {"status": "success", "result": int(out.strip()) > 0}
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
def create_and_push_branch(codebase: str, from_branch: str, new_branch: str) -> Dict[str, Any]:
    """Create *new_branch* off the *from_branch* checkout and push it to origin."""
    try:
        workdir = f"~/workspace/{codebase}@{from_branch}"
        create_cmd = f"cd {workdir} && git checkout -b {new_branch} && git push origin {new_branch}"
        code, out, err = execute_command(create_cmd)
        if code == 0:
            return {"status": "success", "result": out}
        return {"status": "failure", "result": err}
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
def checkout_branch(codebase: str, branch: str) -> Dict[str, Any]:
    """Run ``git checkout <branch>`` inside ``~/workspace/<codebase>@<branch>``."""
    # NOTE(review): the working directory already encodes *branch* in its
    # name -- presumably this re-checkout is a safety no-op; confirm intent.
    try:
        checkout_cmd = f"cd ~/workspace/{codebase}@{branch} && git checkout {branch}"
        code, out, err = execute_command(checkout_cmd)
        if code == 0:
            return {"status": "success", "result": out}
        return {"status": "failure", "result": err}
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
def merge_branches(codebase: str, from_branch: str, to_branch: str) -> Dict[str, Any]:
    """Merge ``origin/<from_branch>`` into the *to_branch* checkout.

    Conflicts are auto-resolved in favour of the incoming branch
    (``-X theirs``).
    """
    try:
        workdir = f"~/workspace/{codebase}@{to_branch}"
        merge_cmd = f"cd {workdir} && git fetch origin && git merge origin/{from_branch} -X theirs"
        code, out, err = execute_command(merge_cmd)
        if code == 0:
            return {"status": "success", "result": out}
        return {"status": "failure", "result": err}
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
def push_changes(codebase: str, branch: str) -> Dict[str, Any]:
    """Push local commits of the *branch* checkout to ``origin/<branch>``."""
    try:
        push_cmd = f"cd ~/workspace/{codebase}@{branch} && git push origin {branch}"
        code, out, err = execute_command(push_cmd)
        if code == 0:
            return {"status": "success", "result": out}
        return {"status": "failure", "result": err}
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
|
||||
def pull_changes(codebase: str, branch: str) -> Dict[str, Any]:
    """Pull ``origin/<branch>`` into the *branch* checkout."""
    try:
        pull_cmd = f"cd ~/workspace/{codebase}@{branch} && git pull origin {branch}"
        code, out, err = execute_command(pull_cmd)
        if code == 0:
            return {"status": "success", "result": out}
        return {"status": "failure", "result": err}
    except Exception as e:
        return {"status": "error", "result": str(e)}
|
||||
Reference in New Issue
Block a user