diff --git a/app.py b/app.py index 5e5b796..ecec029 100644 --- a/app.py +++ b/app.py @@ -3,10 +3,11 @@ import threading from mcp_service import start_mcp from api_service import start_api from utils.db_connections import init_db +from utils.ssh_connections import init_workspace if __name__ == '__main__': init_db() - + init_workspace() t_mcp = threading.Thread(target=start_mcp, daemon=True) t_api = threading.Thread(target=start_api, daemon=True) t_mcp.start() diff --git a/l1.py b/l1.py index 27bff83..f64d0c2 100644 --- a/l1.py +++ b/l1.py @@ -1,13 +1,11 @@ import asyncio -from langchain.agents import create_react_agent, AgentExecutor -from langchain.chat_models import init_chat_model -from langchain_core.prompts import PromptTemplate -from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_core.tools import tool -from langchain.memory import MongoDBChatMessageHistory, ConversationBufferWindowMemory +from agents.agent_tasks.scan_file import tools from agents.agent_templates import AgentTemplate -from agents.output_parser.MCPReactParser import MCPReactParser +from utils.db_connections import init_db +from agents.prompts import general_prompt_param_builder +from utils.ssh_connections.ssh_operations import is_file_exist @tool @@ -21,8 +19,9 @@ def get_daily_string(salt: str): async def main(): + init_db() temp1 = AgentTemplate() - temp1.set_model("gpt-4.1-nano", "openai", "") + temp1.set_model("gpt-4.1-nano", "openai") temp1.set_mcp_server("terminal_tools", { 'transport': 'sse', 'headers': { @@ -30,89 +29,21 @@ async def main(): }, 'url': 'http://localhost:5050/sse' }) + temp1.set_builtin_tools(tools) + temp1.set_template_params(general_prompt_param_builder( + agent_role="assistant", + task_description="test, call update_code_result_to_db tool with arbitrary valid input" + )) - - mcp_client = MultiServerMCPClient({ - 'terminal_tool':{ - 'transport': 'sse', - 'headers': { - "X-Api-Key": "bd303c2f2a515b4ce70c1f07ee85530c2" - }, - 'url': 'http://localhost:5050/sse' - } + agent = await temp1.async_get_instance() + r1 = await agent.ainvoke({ + 'user_msg': 'test' }) - - tools = await mcp_client.get_tools() - xtools = [] - for t in tools: - if t.name == 'CreateTerminalSession': - xtools.append(t) - xtools.append(get_daily_string) - - print(xtools) - prompt = PromptTemplate.from_template(""" - - {system} - - You have access to the following tools: - - {tools} - - Use the following format: - - Question: the question you must answer - If you want to use tools: - Thought: always reason what to do - Action: the action to take, must be one of [{tool_names}] - Action Input: a **JSON object** with named arguments required by the tool - Observation: the result of the action - If no tool is needed: - Thought: what you are thinking - ... (this Thought/Action/... can repeat N times) - Final Answer: the final answer to the original question - - Here is the conversation history: - {chat_history} - - User message: {user} - {agent_scratchpad} - """) - op = MCPReactParser() - agent = create_react_agent(model_openai, xtools, prompt, output_parser=op) - message_history = MongoDBChatMessageHistory( - connection_string=MONGO_CONNECTION_STRING, - session_id=USER_SESSION_ID, - database_name=MONGO_DB_NAME, - collection_name=MONGO_COLLECTION_NAME - ) - mongodb_memory = ConversationBufferWindowMemory( - chat_memory=message_history, - memory_key=MEMORY_KEY, - k=10, - return_messages=False, - ) - agent_executor = AgentExecutor( - agent=agent, - tools=xtools, - memory=mongodb_memory, - handle_parsing_errors=True, - verbose=True, - - ) - response = await agent_executor.ainvoke({ - 'system': 'you are a helpful assistant', - 'user': 'hello, please create a terminal session with label \'x\' and show me the session id', - }) - print(response) - r2 = await agent_executor.ainvoke({ - 'system': 'you are a helpful assistant', - 'user': 'hello, please show me today\'s daily string' - }) - print(r2) - #print(s1) - #print(s2) - #print(s3) - #print(s4) + # for t in tools: + # if t.name == 'update_code_result_to_db': + # print(t) + print(r1) + print(is_file_exist("./Guide")) if __name__ == '__main__': asyncio.run(main()) \ No newline at end of file diff --git a/src/agents/tools/res_tools/__init__.py b/src/agents/agent_tasks/create_codebase.py similarity index 100% rename from src/agents/tools/res_tools/__init__.py rename to src/agents/agent_tasks/create_codebase.py diff --git a/src/agents/agent_tasks/scan_file.py b/src/agents/agent_tasks/scan_file.py index 06bcbcf..2ef68bb 100644 --- a/src/agents/agent_tasks/scan_file.py +++ b/src/agents/agent_tasks/scan_file.py @@ -1,12 +1,19 @@ from langchain_core.tools import tool +from typing import Dict, Optional, List, TypedDict +from utils import operations +from utils.db_connections import db_operations +from utils.ssh_connections.ssh_operations import read_file_content, get_file_md5 +from db_models import CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool +from db_models.embedded_models import CodeSegment -scan_file_task_msg = """ -Your task is to scan a file and generate knowledge abstract for it +scan_file_task_description = """ +Your task is to understand a file and generate knowledge abstract for it code base and file path should be provided by user msg, if not, raise an error The workflow is: 1. determine if abstract of file exists in db - - if exists, check if the abstract is outdated + - if exists, check if the abstract is outdated + - if outdated, update the abstract - if not outdated, call res_tool_scan_result tool with status "success" and finish the workflow 2. determine the type of the file(source code, config file, binary lib/executable) - if the file is binary @@ -17,20 +24,290 @@ The workflow is: - breakdown code by scopes, blocks into segments with start line and end line - describe usage and/or importance of each segment with tool upload_result_to_db - call res_tool_scan_result tool with status "success" and finish the workflow - - """ +class CodeSegmentDict(TypedDict): + line_start: int + line_end: int + abstract: str + links: List[str] @tool -def res_tool_scan_result(session_id: str, status: str): +async def update_code_result_to_db( + codebase: str, + path: str, + md5: str, + abstract: str, + segments: List[CodeSegmentDict] +) -> bool: + """Update code file analysis results to database + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + md5: MD5 hash of the file + abstract: Overall summary of the file + segments: List of code segments, each containing: + - line_start: Starting line number + - line_end: Ending line number + - abstract: Summary of this segment + - links: List of related file links + + Returns: + bool: Whether the update was successful + + Raises: + ValueError: When the codebase does not exist + """ + codebase_obj = await db_operations.get_codebase(codebase) + if not codebase_obj: + raise ValueError(f"Codebase {codebase} not found") + code_segments = [ + CodeSegment( + line_start=seg["line_start"], + line_end=seg["line_end"], + abstract=seg["abstract"], + links=seg["links"] + ) + for seg in segments + ] + + code_file = CodeFile( + codebase=codebase_obj, + path=path, + md5=md5, + abstract=abstract, + segments=code_segments, + scanned=True + ) + + return await db_operations.save_model(code_file) + +@tool +async def update_config_result_to_db( + codebase: str, + path: str, + md5: str, + abstract: str +) -> bool: + """Update configuration file analysis results to database + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + md5: MD5 hash of the file + abstract: Configuration summary including main config items and usage description + + Returns: + bool: Whether the update was successful + + Raises: + ValueError: When the codebase does not exist + """ + codebase_obj = await db_operations.get_codebase(codebase) + if not codebase_obj: + raise ValueError(f"Codebase {codebase} not found") + + config_file = ConfigFile( + codebase=codebase_obj, + path=path, + md5=md5, + abstract=abstract, + scanned=True + ) + + return await db_operations.save_model(config_file) + +@tool +async def update_ignore_result_to_db( + codebase: str, + path: str, + md5: str +) -> bool: + """Update ignore file information to database + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + md5: MD5 hash of the file + + Returns: + bool: Whether the update was successful + + Raises: + ValueError: When the codebase does not exist + """ + codebase_obj = await db_operations.get_codebase(codebase) + if not codebase_obj: + raise ValueError(f"Codebase {codebase} not found") + + ignore_file = IgnoreFile( + codebase=codebase_obj, + path=path, + md5=md5 + ) + + return await db_operations.save_model(ignore_file) + +@tool +async def update_library_result_to_db( + codebase: str, + path: str, + abstract: str +) -> bool: + """Update binary library file analysis results to database + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + abstract: Library file summary including: + - Main functionality of the library + - Dependencies + - Usage scenarios + + Returns: + bool: Whether the update was successful + + Raises: + ValueError: When the codebase does not exist + """ + codebase_obj = await db_operations.get_codebase(codebase) + if not codebase_obj: + raise ValueError(f"Codebase {codebase} not found") + + library_file = BinaryLibrary( + codebase=codebase_obj, + path=path, + abstract=abstract + ) + + return await db_operations.save_model(library_file) + +@tool +async def update_tool_result_to_db( + codebase: str, + path: str, + abstract: str +) -> bool: + """Update binary tool file analysis results to database + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + abstract: Tool file summary including: + - Main functionality of the tool + - Command line argument descriptions + - Usage examples + + Returns: + bool: Whether the update was successful + + Raises: + ValueError: When the codebase does not exist + """ + codebase_obj = await db_operations.get_codebase(codebase) + if not codebase_obj: + raise ValueError(f"Codebase {codebase} not found") + + tool_file = BinaryTool( + codebase=codebase_obj, + path=path, + abstract=abstract + ) + + return await db_operations.save_model(tool_file) + +@tool +def required_response_scan_result(session_id: str, status: str) -> Dict[str, str]: + """Return scan result response + + Args: + session_id: ID of the scan session + status: Status of the scan + + Returns: + Dict[str, str]: Response containing session ID and status + """ return { 'session_id': session_id, 'status': status } +@tool +async def read_file(codebase: str, path: str) -> str: + """Read file content + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + + Returns: + str: File content + + Raises: + ValueError: When the codebase does not exist or file reading fails + """ + root_path = await operations.get_codebase_root(codebase) + if not root_path: + raise ValueError(f"Codebase {codebase} not found") + + full_path = f"{root_path}/{path}" + result = read_file_content(full_path) + + if result["status"] == "failure": + raise ValueError(f"Failed to read file: {result['result']}") + return result["result"] @tool -def upload_result_to_db(codebase: str, file_path: str, segments): - pass +async def is_outdated(codebase: str, path: str) -> bool: + """Check if file summary is outdated + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + + Returns: + bool: True if file does not exist or MD5 does not match + """ + stored_md5 = await db_operations.get_file_md5(codebase, path) + if not stored_md5: + return True + + root_path = await operations.get_codebase_root(codebase) + if not root_path: + raise ValueError(f"Codebase {codebase} not found") + + full_path = f"{root_path}/{path}" + result = get_file_md5(full_path) + + if result["status"] == "failure": + raise ValueError(f"Failed to get file MD5: {result['result']}") + return stored_md5 != result["result"] + +@tool +async def get_abstract(codebase: str, path: str) -> Optional[str]: + """Get file summary + + Args: + codebase: Name of the codebase + path: File path relative to codebase root + + Returns: + Optional[str]: File summary, None if it doesn't exist + """ + return await db_operations.get_file_document(codebase, path) + + + +tools = [ + update_code_result_to_db, + update_config_result_to_db, + update_ignore_result_to_db, + update_library_result_to_db, + update_tool_result_to_db, + required_response_scan_result, + read_file, + is_outdated, + get_abstract +] diff --git a/src/agents/agent_templates/__init__.py b/src/agents/agent_templates/__init__.py index 6c561eb..ebe1f33 100644 --- a/src/agents/agent_templates/__init__.py +++ b/src/agents/agent_templates/__init__.py @@ -59,12 +59,13 @@ class AgentTemplate: mcp_client = MultiServerMCPClient(self.mcp_servers) tools = await mcp_client.get_tools() + self.builtin_tools prompt_template = self.prompt_template + prompt_template = prompt_template.replace('{session_id}', session_id) for param in self.template_params.keys(): prompt_template = prompt_template.replace(f'{{{param}}}', self.template_params[param]) - prompt = PromptTemplate.from_tools(prompt_template) + prompt = PromptTemplate.from_template(prompt_template) history = MongoDBChatMessageHistory( - connection_string=os.getenv("MONGODB_CONNECTION_STRING", ""), + connection_string=os.getenv("MONGO_CONNECTION_STRING", None), session_id=session_id, database_name="ckb", collection_name="session_history" @@ -72,7 +73,7 @@ class AgentTemplate: memory = ConversationBufferWindowMemory( chat_memory=history, - memory_key="chat_memory", + memory_key="chat_history", k=self.history_window_size, return_messages=False, ) diff --git a/src/agents/agent_templates/task_agent_templates/__init__.py b/src/agents/agent_templates/task_agent_templates/__init__.py new file mode 100644 index 0000000..e79bfde --- /dev/null +++ b/src/agents/agent_templates/task_agent_templates/__init__.py @@ -0,0 +1,11 @@ +from agents.agent_tasks.scan_file import tools, scan_file_task_description +from agents.agent_templates import AgentTemplate +from agents.prompts import general_prompt_param_builder + +scan_file_agent_template = AgentTemplate() +scan_file_agent_template.set_model("gpt-4.1-nano", "openai") +scan_file_agent_template.set_builtin_tools(tools) +scan_file_agent_template.set_template_params(general_prompt_param_builder( + agent_role="assistant", + task_description=scan_file_task_description +)) diff --git a/src/agents/output_parser/MCPReactParser.py b/src/agents/output_parser/MCPReactParser.py index 811fb7c..d8b4d9e 100644 --- a/src/agents/output_parser/MCPReactParser.py +++ b/src/agents/output_parser/MCPReactParser.py @@ -16,7 +16,7 @@ class MCPReactParser(AgentOutputParser): return self.parse(text) def parse(self, text: str) -> Union[AgentAction, AgentFinish]: - includes_final_answer = self.Final_Answer_ACTION in text + includes_final_answer = self.FINAL_ANSWER_ACTION in text regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" action_match = re.search(regex, text, re.DOTALL) diff --git a/src/agents/prompts/__init__.py b/src/agents/prompts/__init__.py index 2612219..f508577 100644 --- a/src/agents/prompts/__init__.py +++ b/src/agents/prompts/__init__.py @@ -1,7 +1,6 @@ general_prompt = """ You are {agent_role} - - Your current task is +Session Id of current task is {session_id} {task_description} @@ -29,10 +28,11 @@ You are {agent_role} Final Answer: the final answer to the original question + User message: {user_msg} + Here is the conversation history: {chat_history} - User message: {user_msg} {agent_scratchpad} """ diff --git a/src/agents/tools/res_tools/res_tool_general_response.py b/src/agents/tools/res_tools/res_tool_general_response.py deleted file mode 100644 index 90a65c2..0000000 --- a/src/agents/tools/res_tools/res_tool_general_response.py +++ /dev/null @@ -1,9 +0,0 @@ -from langchain_core.tools import tool - -@tool -def res_tool_general_response(session_id: str, response: str): - return { - 'type': 'response', - 'session_id': session_id, - 'response': response - } \ No newline at end of file diff --git a/src/db_models/BinaryLibrary.py b/src/db_models/BinaryLibrary.py index c7782db..50bc5bf 100644 --- a/src/db_models/BinaryLibrary.py +++ b/src/db_models/BinaryLibrary.py @@ -1,9 +1,8 @@ -from odmantic import Model - -from db_models.embedded_models.Codebase import Codebase - +from odmantic import Model, ObjectId class BinaryLibrary(Model): - codebase: Codebase + codebase_id: ObjectId + type: str="library" path: str - abstract: str \ No newline at end of file + abstract: str + md5: str diff --git a/src/db_models/BinaryTool.py b/src/db_models/BinaryTool.py index 238b340..c4981b5 100644 --- a/src/db_models/BinaryTool.py +++ b/src/db_models/BinaryTool.py @@ -1,9 +1,9 @@ -from odmantic import Model - -from db_models.embedded_models.Codebase import Codebase +from odmantic import Model, ObjectId class BinaryTool(Model): - codebase: Codebase + codebase_id: ObjectId + type: str="tool" path: str abstract: str + md5: str \ No newline at end of file diff --git a/src/db_models/CodeFile.py b/src/db_models/CodeFile.py index 5d41cf1..1ccd81b 100644 --- a/src/db_models/CodeFile.py +++ b/src/db_models/CodeFile.py @@ -1,13 +1,12 @@ -from odmantic import Model +from odmantic import Model, ObjectId from typing import List from db_models.embedded_models.CodeSegment import CodeSegment -from db_models.embedded_models.Codebase import Codebase class CodeFile(Model): - codebase: Codebase - type: str + codebase_id: ObjectId + type: str="code" path: str md5: str abstract: str diff --git a/src/db_models/Codebase.py b/src/db_models/Codebase.py new file mode 100644 index 0000000..25dfc28 --- /dev/null +++ b/src/db_models/Codebase.py @@ -0,0 +1,8 @@ +from odmantic import Model + + +class Codebase(Model): + name: str + version: str + branch: str + repo: str diff --git a/src/db_models/ConfigFile.py b/src/db_models/ConfigFile.py index 915e93d..8e05c08 100644 --- a/src/db_models/ConfigFile.py +++ b/src/db_models/ConfigFile.py @@ -1,11 +1,9 @@ -from odmantic import Model - -from db_models.embedded_models.Codebase import Codebase +from odmantic import Model, ObjectId class ConfigFile(Model): - codebase: Codebase - type: str + codebase_id: ObjectId + type: str="config" path: str md5: str abstract: str diff --git a/src/db_models/Directory.py b/src/db_models/Directory.py index edf6229..b0f0544 100644 --- a/src/db_models/Directory.py +++ b/src/db_models/Directory.py @@ -1,10 +1,9 @@ -from odmantic import Model +from odmantic import Model, ObjectId -from db_models.embedded_models.Codebase import Codebase class Directory(Model): - codebase: Codebase + codebase_id: ObjectId path: str md5: str abstract: str diff --git a/src/db_models/Hotspot.py b/src/db_models/Hotspot.py index f2cd298..105e732 100644 --- a/src/db_models/Hotspot.py +++ b/src/db_models/Hotspot.py @@ -1,9 +1,9 @@ -from odmantic import Model +from odmantic import Model, ObjectId from typing import List -from db_models.embedded_models.Codebase import Codebase +from db_models.Codebase import Codebase class Hotspot(Model): - codebase: Codebase + codebase_id: ObjectId topic: str links: List[int] diff --git a/src/db_models/IgnoreFile.py b/src/db_models/IgnoreFile.py index 1343b9f..a403fed 100644 --- a/src/db_models/IgnoreFile.py +++ b/src/db_models/IgnoreFile.py @@ -1,9 +1,8 @@ -from odmantic import Model - -from db_models.embedded_models.Codebase import Codebase +from odmantic import Model, ObjectId class IgnoreFile(Model): - codebase: Codebase + codebase_id: ObjectId + type:str="ignore" path: str md5: str diff --git a/src/db_models/__init__.py b/src/db_models/__init__.py index e69de29..764775c 100644 --- a/src/db_models/__init__.py +++ b/src/db_models/__init__.py @@ -0,0 +1,6 @@ +from db_models.BinaryTool import BinaryTool +from db_models.BinaryLibrary import BinaryLibrary +from db_models.CodeFile import CodeFile +from db_models.ConfigFile import ConfigFile +from db_models.IgnoreFile import IgnoreFile +from db_models.Codebase import Codebase \ No newline at end of file diff --git a/src/db_models/embedded_models/Codebase.py b/src/db_models/embedded_models/Codebase.py deleted file mode 100644 index 1a8488b..0000000 --- a/src/db_models/embedded_models/Codebase.py +++ /dev/null @@ -1,9 +0,0 @@ -from odmantic import EmbeddedModel - - -class Codebase(EmbeddedModel): - name: str - version: str - branch: str - path: str - repo: str diff --git a/src/db_models/embedded_models/__init__.py b/src/db_models/embedded_models/__init__.py index e69de29..70a26c5 100644 --- a/src/db_models/embedded_models/__init__.py +++ b/src/db_models/embedded_models/__init__.py @@ -0,0 +1 @@ +from db_models.embedded_models.CodeSegment import CodeSegment \ No newline at end of file diff --git a/src/utils/db_connections/__init__.py b/src/utils/db_connections/__init__.py index 93ceed1..4c3db12 100644 --- a/src/utils/db_connections/__init__.py +++ b/src/utils/db_connections/__init__.py @@ -1,14 +1,18 @@ import os from threading import Lock from pymongo import MongoClient +from typing import Optional +from odmantic import AIOEngine +from motor.motor_asyncio import AsyncIOMotorClient _client = None _db = None _lock = Lock() +_engine: Optional[AIOEngine] = None def init_db(): - global _client, _db + global _client, _db, _engine if _client is None: with _lock: if _client is None: @@ -21,6 +25,11 @@ def init_db(): if db_name not in _client.list_database_names(): _client[db_name].create_collection('session_history') _db = _client[db_name] + + if _engine is None: + client = AsyncIOMotorClient(uri) + _engine = AIOEngine(client=client, database=db_name) + return _db @@ -34,3 +43,16 @@ def get_client(): if _client is None: init_db() return _client + + +def get_engine() -> AIOEngine: + if _engine is None: + raise RuntimeError("Database not initialized. Call init_db first.") + return _engine + + +def close_db(): + global _engine + if _engine is not None: + _engine.client.close() + _engine = None diff --git a/src/utils/db_connections/db_operations.py b/src/utils/db_connections/db_operations.py new file mode 100644 index 0000000..9855cf9 --- /dev/null +++ b/src/utils/db_connections/db_operations.py @@ -0,0 +1,124 @@ +from typing import Optional, Any, List +from db_models import CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool, Codebase +from utils.db_connections import get_engine +from utils.ssh_connections.ssh_operations import check_branch_exists, get_file_md5 as get_workspace_file_md5 + + +async def get_codebase(name: str) -> Optional[Codebase]: + engine = get_engine() + branch = 'main' + if '@' in name: + name, branch = name.split('@') + + code_base = await engine.find_one( + Codebase, + (Codebase.name == name) & (Codebase.branch == branch) + ) + + if not code_base: + code_base_main = await engine.find_one( + Codebase, + (Codebase.name == name) & (Codebase.branch == 'main') + ) + + if not code_base_main: + raise Exception(f"Codebase {name} not found") + + branch_check = check_branch_exists(code_base_main.repo, branch) + if branch_check.get("status") == "error": + raise Exception(branch_check.get("result", "Error checking branch existence")) + + if branch_check.get("status") == "success" and branch_check.get("result", False): + new_code_base = Codebase( + name=name, + version=code_base_main.version, + branch=branch, + repo=code_base_main.repo + ) + await engine.save(new_code_base) + return new_code_base + else: + raise Exception(f"Branch {branch} not found for codebase {name}") + + return code_base + + + +async def get_file_md5(codebase: str, path: str) -> Optional[str]: + doc = await get_file_document(codebase, path) + if doc and doc.hasattr('md5'): + return doc.md5 + else: + return None + +async def get_file_document(codebase: str, path: str): + engine = get_engine() + branch = 'main' + if '@' in codebase: + codebase, branch = codebase.split('@') + + cb = await get_codebase(f"{codebase}@{branch}") + if not cb: + return None + + all_files = [] + for model in [CodeFile, ConfigFile, BinaryLibrary, BinaryTool]: + files = await engine.find( + model, + model.path == path + ) + all_files.extend(files) + + if not all_files: + return None + + codebase_files = [file for file in all_files if file.codebase_id == cb.id] + + if codebase_files: + return codebase_files[0] + + workspace_path = f"~/workspace/{codebase}@{branch}/{path}" + workspace_md5_result = get_workspace_file_md5(workspace_path) + + if workspace_md5_result.get("status") == "success": + workspace_md5 = workspace_md5_result.get("result") + + for file in all_files: + if hasattr(file, 'md5') and file.md5 == workspace_md5: + new_file = type(file)( + codebase_id=cb.id, + path=file.path, + **{k: v for k, v in file.dict().items() + if k not in ['id', 'codebase_id', 'path']} + ) + await engine.save(new_file) + return new_file + + return None + +async def save_model(model: Any) -> bool: + try: + engine = get_engine() + await engine.save(model) + return True + except Exception as e: + print(f"Error saving model: {e}") + return False + +async def update_branch_documents(from_codebase_id: str, to_codebase_id: str): + engine = get_engine() + for model in [CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool]: + from_docs = await engine.find(model, model.codebase_id == from_codebase_id) + + for doc in from_docs: + to_docs = await engine.find( + model, + (model.path == doc.path) & (model.codebase_id == to_codebase_id) + ) + + if to_docs: + for to_doc in to_docs: + await engine.delete(to_doc) + + doc.codebase_id = to_codebase_id + await save_model(doc) diff --git a/src/utils/operations.py b/src/utils/operations.py new file mode 100644 index 0000000..7f4668b --- /dev/null +++ b/src/utils/operations.py @@ -0,0 +1,73 @@ +from typing import Optional + +from db_models import Codebase +from utils.db_connections.db_operations import get_codebase, save_model +from utils.ssh_connections.ssh_operations import ( + clone_repo, check_branch_exists, create_and_push_branch, + checkout_branch, merge_branches, push_changes, pull_changes, is_dir_exist +) +from utils.db_connections.db_operations import update_branch_documents + + +async def get_codebase_root(name: str) -> Optional[str]: + branch = 'main' + if '@' in name: + name, branch = name.split('@') + exist = is_dir_exist(f'{name}@{branch}') + if exist.get('status', 'error') == 'error': + raise Exception(exist.get('result', '')) + if not exist.get('result', False): + try: + cb = await get_codebase(f'{name}@{branch}') + clone_repo(cb.repo, branch) + except Exception as e: + raise Exception(f"Codebase {name}@{branch} not found") + return f'~/workspace/{name}@{branch}' + +async def create_branch(codebase: str, from_branch: str, new_branch: str): + cb = await get_codebase(f'{codebase}@{from_branch}') + + branch_check = check_branch_exists(cb.repo, new_branch) + if branch_check.get("status") == "error": + raise Exception(branch_check.get("result", "Error checking branch existence")) + + if branch_check.get("status") == "success" and branch_check.get("result", True): + raise Exception(f"Branch {new_branch} already exists for codebase {codebase}") + + create_result = create_and_push_branch(codebase, from_branch, new_branch) + if create_result.get("status") != "success": + raise Exception(f"Failed to create branch {new_branch}: {create_result.get('result', '')}") + + checkout_result = checkout_branch(codebase, from_branch) + if checkout_result.get("status") != "success": + raise Exception(f"Failed to checkout back to branch {from_branch}: {checkout_result.get('result', '')}") + + new_codebase = Codebase( + name=codebase, + version=cb.version, + branch=new_branch, + repo=cb.repo + ) + await save_model(new_codebase) + +async def merge_branch(codebase: str, from_branch: str, to_branch: str): + from_cb = await get_codebase(f'{codebase}@{from_branch}') + to_cb = await get_codebase(f'{codebase}@{to_branch}') + + if not from_cb or not to_cb: + raise Exception(f"One or both branches not found: {from_branch}, {to_branch}") + + merge_result = merge_branches(codebase, from_branch, to_branch) + if merge_result.get("status") != "success": + raise Exception(f"Failed to merge branch {from_branch} into {to_branch}: {merge_result.get('result', '')}") + + push_result = push_changes(codebase, to_branch) + if push_result.get("status") != "success": + raise Exception(f"Failed to push merged changes to remote: {push_result.get('result', '')}") + + pull_result = pull_changes(codebase, to_branch) + if pull_result.get("status") != "success": + raise Exception(f"Failed to pull latest changes from remote: {pull_result.get('result', '')}") + + await update_branch_documents(from_cb.id, to_cb.id) + return {"status": "success", "message": f"Successfully merged branch {from_branch} into {to_branch}"} diff --git a/src/utils/ssh_connections/__init__.py b/src/utils/ssh_connections/__init__.py index 5538186..e9c7960 100644 --- a/src/utils/ssh_connections/__init__.py +++ b/src/utils/ssh_connections/__init__.py @@ -1,18 +1,18 @@ import os import paramiko from threading import Lock -from typing import Tuple, Optional, List, Dict, Any -import json + +from utils.ssh_connections.ssh_operations import is_dir_exist, make_dir class SSHConnectionManager: _clients = {} _lock = Lock() - HOST = os.getenv('SSH_HOST', 'host.docker.internal') - USERNAME = os.getenv('SSH_USERNAME') - PORT = os.getenv('SSH_PORT', 22) - PASSWORD = os.getenv('SSH_PASSWORD') + HOST = os.getenv('CKB_SSH_HOST', 'host.docker.internal') + USERNAME = os.getenv('CKB_SSH_USERNAME') + PORT = os.getenv('CKB_SSH_PORT', 22) + PASSWORD = os.getenv('CKB_SSH_PASSWORD', None) @classmethod def get_client(cls, timeout=10): @@ -34,79 +34,8 @@ class SSHConnectionManager: return cls._clients[key] -def execute_command(command: str, timeout: int = 30) -> Tuple[int, str, str]: - client = SSHConnectionManager.get_client(timeout=timeout) - stdin, stdout, stderr = client.exec_command(command, timeout=timeout) - exit_code = stdout.channel.recv_exit_status() - return exit_code, stdout.read().decode('utf-8'), stderr.read().decode('utf-8') - - -def list_directory(path: str, include_ignore: bool = True) -> Dict[str, Any]: - try: - client = SSHConnectionManager.get_client() - sftp = client.open_sftp() - files = sftp.listdir_attr(path) - result = [] - for file in files: - if not include_ignore and file.filename.startswith('.'): - continue - result.append({ - 'name': file.filename, - 'size': file.st_size, - 'mode': file.st_mode, - 'mtime': file.st_mtime, - 'is_dir': file.st_mode & 0o40000 != 0 - }) - sftp.close() - return {"status": "success", "result": result} - except Exception as e: - return {"status": "failure", "result": str(e)} - - -def read_file_content(path: str) -> Dict[str, Any]: - try: - client = SSHConnectionManager.get_client() - sftp = client.open_sftp() - with sftp.open(path, 'r') as f: - content = f.read().decode('utf-8') - sftp.close() - return {"status": "success", "result": content} - except Exception as e: - return {"status": "failure", "result": str(e)} - - -def write_file_content(path: str, content: str) -> Dict[str, Any]: - try: - client = SSHConnectionManager.get_client() - sftp = client.open_sftp() - with sftp.open(path, 'w') as f: - f.write(content) - sftp.close() - return {"status": "success", "result": None} - except Exception as e: - return {"status": "failure", "result": str(e)} - - -def get_file_md5(path: str) -> Dict[str, Any]: - try: - exit_code, stdout, stderr = execute_command(f"md5sum {path}") - if exit_code == 0: - md5 = stdout.split()[0] - return {"status": "success", "result": md5} - return {"status": "failure", "result": stderr} - except Exception as e: - return {"status": "failure", "result": str(e)} - - -def execute_in_sandbox(command: str, timeout: int = 30) -> Dict[str, Any]: - try: - sandbox_cmd = f"docker run --rm --network none --memory=512m --cpus=1 alpine sh -c '{command}'" - exit_code, stdout, stderr = execute_command(sandbox_cmd, timeout) - if exit_code == 0: - return {"status": "success", "result": stdout} - return {"status": "failure", "result": stderr} - except Exception as e: - return {"status": "failure", "result": str(e)} +def init_workspace(): + make_dir('~/workspace') diff --git a/src/utils/ssh_connections/ssh_operations.py b/src/utils/ssh_connections/ssh_operations.py new file mode 100644 index 0000000..a0ff963 --- /dev/null +++ b/src/utils/ssh_connections/ssh_operations.py @@ -0,0 +1,185 @@ +from utils.ssh_connections import SSHConnectionManager +from typing import Tuple, Dict, Any + + +def execute_command(command: str, timeout: int = 30) -> Tuple[int, str, str]: + client = SSHConnectionManager.get_client(timeout=timeout) + stdin, stdout, stderr = client.exec_command(command, timeout=timeout) + exit_code = stdout.channel.recv_exit_status() + return exit_code, stdout.read().decode('utf-8'), stderr.read().decode('utf-8') + + +def _test(exit_code: int): + if exit_code == 0: + return {"status": "success", "result": True} + return {"status": "failure", "result": False} + +def make_dir(path: str) -> Dict[str, Any]: + try: + exit_code, stdout, stderr = execute_command(f"mkdir -p {path}") + return _test(exit_code) + except Exception as e: + return {"status": "error", "result": str(e)} + + +def is_file_exist(path: str) -> Dict[str, Any]: + try: + exit_code, stdout, stderr = execute_command(f"test -f {path}") + return _test(exit_code) + except Exception as e: + return {"status": "error", "result": str(e)} + + +def is_dir_exist(path: str) -> Dict[str, Any]: + try: + exit_code, stdout, stderr = execute_command(f"test -d {path}") + return _test(exit_code) + except Exception as e: + return {"status": "error", "result": str(e)} + + +def list_directory(path: str, include_ignore: bool = True) -> Dict[str, Any]: + try: + client = SSHConnectionManager.get_client() + sftp = client.open_sftp() + files = sftp.listdir_attr(path) + result = [] + for file in files: + if not include_ignore and file.filename.startswith('.'): + continue + result.append({ + 'name': file.filename, + 'size': file.st_size, + 'mode': file.st_mode, + 'mtime': file.st_mtime, + 'is_dir': file.st_mode & 0o40000 != 0 + }) + sftp.close() + return {"status": "success", "result": result} + except Exception as e: + return {"status": "failure", "result": str(e)} + + +def read_file_content(path: str) -> Dict[str, Any]: + try: + client = SSHConnectionManager.get_client() + sftp = client.open_sftp() + with sftp.open(path, 'r') as f: + content = f.read().decode('utf-8') + sftp.close() + return {"status": "success", "result": content} + except Exception as e: + return {"status": "failure", "result": str(e)} + + +def write_file_content(path: str, content: str) -> Dict[str, Any]: + try: + client = SSHConnectionManager.get_client() + sftp = client.open_sftp() + with sftp.open(path, 'w') as f: + f.write(content) + sftp.close() + return {"status": "success", "result": None} + except Exception as e: + return {"status": "failure", "result": str(e)} + + +def get_file_md5(path: str) -> Dict[str, Any]: + try: + exit_code, stdout, stderr = execute_command(f"md5sum {path}") + if exit_code == 0: + md5 = stdout.split()[0] + return {"status": "success", "result": md5} + return {"status": "failure", "result": stderr} + except Exception as e: + return {"status": "failure", "result": str(e)} + + +def execute_in_sandbox(command: str, timeout: int = 30) -> Dict[str, Any]: + try: + sandbox_cmd = f"docker run --rm --network none --memory=512m --cpus=1 alpine sh -c '{command}'" + exit_code, stdout, stderr = execute_command(sandbox_cmd, timeout) + if exit_code == 0: + return {"status": "success", "result": stdout} + return {"status": "failure", "result": stderr} + except Exception as e: + return {"status": "failure", "result": str(e)} + +def clone_repo(url: str, branch: str = "main") -> Dict[str, Any]: + try: + name = url.split("/")[-1].replace(".git", "")+f'@{branch}' + exit_code, stdout, stderr = execute_command(f"git clone -b {branch} {url} ~/workspace/{name}", timeout=7200) + return _test(exit_code) + except Exception as e: + return {"status": "error", "result": str(e)} + +def remove_codebase(codebase: str) -> Dict[str, Any]: + try: + if not '@' in codebase: + codebase = f'{codebase}@main' + exit_code, stdout, stdin = execute_command(f"rm -rf ~/workspace/{codebase}") + return _test(exit_code) + except Exception as e: + return {"status": "error", "result": str(e)} + +def check_branch_exists(repo_url: str, branch: str) -> Dict[str, Any]: + try: + command = f"git ls-remote --heads {repo_url} refs/heads/{branch} | wc -l" + exit_code, stdout, stderr = execute_command(command) + if exit_code == 0: + count = int(stdout.strip()) + return {"status": "success", "result": count > 0} + return {"status": "failure", "result": stderr} + except Exception as e: + return {"status": "error", "result": str(e)} + +def create_and_push_branch(codebase: str, from_branch: str, new_branch: str) -> Dict[str, Any]: + + try: + create_cmd = f"cd ~/workspace/{codebase}@{from_branch} && git checkout -b {new_branch} && git push origin {new_branch}" + exit_code, stdout, stderr = execute_command(create_cmd) + if exit_code != 0: + return {"status": "failure", "result": stderr} + return {"status": "success", "result": stdout} + except Exception as e: + return {"status": "error", "result": str(e)} + +def checkout_branch(codebase: str, branch: str) -> Dict[str, Any]: + try: + checkout_cmd = f"cd ~/workspace/{codebase}@{branch} && git checkout {branch}" + exit_code, stdout, stderr = execute_command(checkout_cmd) + if exit_code != 0: + return {"status": "failure", "result": stderr} + return {"status": "success", "result": stdout} + except Exception as e: + return {"status": "error", "result": str(e)} + +def merge_branches(codebase: str, from_branch: str, to_branch: str) -> Dict[str, Any]: + try: + merge_cmd = f"cd ~/workspace/{codebase}@{to_branch} && git fetch origin && git merge origin/{from_branch} -X theirs" + exit_code, stdout, stderr = execute_command(merge_cmd) + if exit_code != 0: + return {"status": "failure", "result": stderr} + return {"status": "success", "result": stdout} + except Exception as e: + return {"status": "error", "result": str(e)} + +def push_changes(codebase: str, branch: str) -> Dict[str, Any]: + try: + push_cmd = f"cd ~/workspace/{codebase}@{branch} && git push origin {branch}" + exit_code, stdout, stderr = execute_command(push_cmd) + if exit_code != 0: + return {"status": "failure", "result": stderr} + return {"status": "success", "result": stdout} + except Exception as e: + return {"status": "error", "result": str(e)} + +def pull_changes(codebase: str, branch: str) -> Dict[str, Any]: + try: + pull_cmd = f"cd ~/workspace/{codebase}@{branch} && git pull origin {branch}" + exit_code, stdout, stderr = execute_command(pull_cmd) + if exit_code != 0: + return {"status": "failure", "result": stderr} + return {"status": "success", "result": stdout} + except Exception as e: + return {"status": "error", "result": str(e)}