add: operations

This commit is contained in:
h z
2025-05-28 16:34:20 +01:00
parent 8933b13000
commit 9779630028
25 changed files with 776 additions and 231 deletions

3
app.py
View File

@@ -3,10 +3,11 @@ import threading
from mcp_service import start_mcp from mcp_service import start_mcp
from api_service import start_api from api_service import start_api
from utils.db_connections import init_db from utils.db_connections import init_db
from utils.ssh_connections import init_workspace
if __name__ == '__main__': if __name__ == '__main__':
init_db() init_db()
init_workspace()
t_mcp = threading.Thread(target=start_mcp, daemon=True) t_mcp = threading.Thread(target=start_mcp, daemon=True)
t_api = threading.Thread(target=start_api, daemon=True) t_api = threading.Thread(target=start_api, daemon=True)
t_mcp.start() t_mcp.start()

107
l1.py
View File

@@ -1,13 +1,11 @@
import asyncio import asyncio
from langchain.agents import create_react_agent, AgentExecutor
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain.memory import MongoDBChatMessageHistory, ConversationBufferWindowMemory
from agents.agent_tasks.scan_file import tools
from agents.agent_templates import AgentTemplate from agents.agent_templates import AgentTemplate
from agents.output_parser.MCPReactParser import MCPReactParser from utils.db_connections import init_db
from agents.prompts import general_prompt_param_builder
from utils.ssh_connections.ssh_operations import is_file_exist
@tool @tool
@@ -21,8 +19,9 @@ def get_daily_string(salt: str):
async def main(): async def main():
init_db()
temp1 = AgentTemplate() temp1 = AgentTemplate()
temp1.set_model("gpt-4.1-nano", "openai", "") temp1.set_model("gpt-4.1-nano", "openai")
temp1.set_mcp_server("terminal_tools", { temp1.set_mcp_server("terminal_tools", {
'transport': 'sse', 'transport': 'sse',
'headers': { 'headers': {
@@ -30,89 +29,21 @@ async def main():
}, },
'url': 'http://localhost:5050/sse' 'url': 'http://localhost:5050/sse'
}) })
temp1.set_builtin_tools(tools)
temp1.set_template_params(general_prompt_param_builder(
agent_role="assistant",
task_description="test, call update_code_result_to_db tool with arbitrary valid input"
))
agent = await temp1.async_get_instance()
mcp_client = MultiServerMCPClient({ r1 = await agent.ainvoke({
'terminal_tool':{ 'user_msg': 'test'
'transport': 'sse',
'headers': {
"X-Api-Key": "bd303c2f2a515b4ce70c1f07ee85530c2"
},
'url': 'http://localhost:5050/sse'
}
}) })
# for t in tools:
tools = await mcp_client.get_tools() # if t.name == 'update_code_result_to_db':
xtools = [] # print(t)
for t in tools: print(r1)
if t.name == 'CreateTerminalSession': print(is_file_exist("./Guide"))
xtools.append(t)
xtools.append(get_daily_string)
print(xtools)
prompt = PromptTemplate.from_template("""
{system}
You have access to the following tools:
{tools}
Use the following format:
Question: the question you must answer
If you want to use tools:
Thought: always reason what to do
Action: the action to take, must be one of [{tool_names}]
Action Input: a **JSON object** with named arguments required by the tool
Observation: the result of the action
If no tool is needed:
Thought: what you are thinking
... (this Thought/Action/... can repeat N times)
Final Answer: the final answer to the original question
Here is the conversation history:
{chat_history}
User message: {user}
{agent_scratchpad}
""")
op = MCPReactParser()
agent = create_react_agent(model_openai, xtools, prompt, output_parser=op)
message_history = MongoDBChatMessageHistory(
connection_string=MONGO_CONNECTION_STRING,
session_id=USER_SESSION_ID,
database_name=MONGO_DB_NAME,
collection_name=MONGO_COLLECTION_NAME
)
mongodb_memory = ConversationBufferWindowMemory(
chat_memory=message_history,
memory_key=MEMORY_KEY,
k=10,
return_messages=False,
)
agent_executor = AgentExecutor(
agent=agent,
tools=xtools,
memory=mongodb_memory,
handle_parsing_errors=True,
verbose=True,
)
response = await agent_executor.ainvoke({
'system': 'you are a helpful assistant',
'user': 'hello, please create a terminal session with label \'x\' and show me the session id',
})
print(response)
r2 = await agent_executor.ainvoke({
'system': 'you are a helpful assistant',
'user': 'hello, please show me today\'s daily string'
})
print(r2)
#print(s1)
#print(s2)
#print(s3)
#print(s4)
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,12 +1,19 @@
from langchain_core.tools import tool from langchain_core.tools import tool
from typing import Dict, Optional, List, TypedDict
from utils import operations
from utils.db_connections import db_operations
from utils.ssh_connections.ssh_operations import read_file_content, get_file_md5
from db_models import CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool
from db_models.embedded_models import CodeSegment
scan_file_task_msg = """ scan_file_task_description = """
Your task is to scan a file and generate knowledge abstract for it Your task is to understand a file and generate knowledge abstract for it
code base and file path should be provided by user msg, if not, raise an error code base and file path should be provided by user msg, if not, raise an error
The workflow is: The workflow is:
1. determine if abstract of file exists in db 1. determine if abstract of file exists in db
- if exists, check if the abstract is outdated - if exists, check if the abstract is outdated
- if outdated, update the abstract
- if not outdated, call res_tool_scan_result tool with status "success" and finish the workflow - if not outdated, call res_tool_scan_result tool with status "success" and finish the workflow
2. determine the type of the file(source code, config file, binary lib/executable) 2. determine the type of the file(source code, config file, binary lib/executable)
- if the file is binary - if the file is binary
@@ -17,20 +24,290 @@ The workflow is:
- breakdown code by scopes, blocks into segments with start line and end line - breakdown code by scopes, blocks into segments with start line and end line
- describe usage and/or importance of each segment with tool upload_result_to_db - describe usage and/or importance of each segment with tool upload_result_to_db
- call res_tool_scan_result tool with status "success" and finish the workflow - call res_tool_scan_result tool with status "success" and finish the workflow
""" """
class CodeSegmentDict(TypedDict):
    """Plain-dict shape of one analyzed code segment, as passed to the
    update_code_result_to_db tool."""
    # Start line of the segment — TODO confirm 1-based vs 0-based.
    line_start: int
    # End line of the segment (inclusive, presumably — verify against caller).
    line_end: int
    # Natural-language summary of what the segment does.
    abstract: str
    # Related file links referenced by this segment.
    links: List[str]
@tool @tool
def res_tool_scan_result(session_id: str, status: str): async def update_code_result_to_db(
codebase: str,
path: str,
md5: str,
abstract: str,
segments: List[CodeSegmentDict]
) -> bool:
"""Update code file analysis results to database
Args:
codebase: Name of the codebase
path: File path relative to codebase root
md5: MD5 hash of the file
abstract: Overall summary of the file
segments: List of code segments, each containing:
- line_start: Starting line number
- line_end: Ending line number
- abstract: Summary of this segment
- links: List of related file links
Returns:
bool: Whether the update was successful
Raises:
ValueError: When the codebase does not exist
"""
codebase_obj = await db_operations.get_codebase(codebase)
if not codebase_obj:
raise ValueError(f"Codebase {codebase} not found")
code_segments = [
CodeSegment(
line_start=seg["line_start"],
line_end=seg["line_end"],
abstract=seg["abstract"],
links=seg["links"]
)
for seg in segments
]
code_file = CodeFile(
codebase=codebase_obj,
path=path,
md5=md5,
abstract=abstract,
segments=code_segments,
scanned=True
)
return await db_operations.save_model(code_file)
@tool
async def update_config_result_to_db(
    codebase: str,
    path: str,
    md5: str,
    abstract: str
) -> bool:
    """Update configuration file analysis results to database
    Args:
        codebase: Name of the codebase
        path: File path relative to codebase root
        md5: MD5 hash of the file
        abstract: Configuration summary including main config items and usage description
    Returns:
        bool: Whether the update was successful
    Raises:
        ValueError: When the codebase does not exist
    """
    # Resolve the codebase first so a bad name fails fast.
    codebase_obj = await db_operations.get_codebase(codebase)
    if not codebase_obj:
        raise ValueError(f"Codebase {codebase} not found")
    # Consistency fix: the new ConfigFile model stores a `codebase_id: ObjectId`
    # reference, not an embedded `codebase` document. The model also no longer
    # declares a `scanned` field, so that kwarg is dropped.
    # NOTE(review): confirm `scanned` really was removed from ConfigFile.
    config_file = ConfigFile(
        codebase_id=codebase_obj.id,
        path=path,
        md5=md5,
        abstract=abstract
    )
    return await db_operations.save_model(config_file)
@tool
async def update_ignore_result_to_db(
    codebase: str,
    path: str,
    md5: str
) -> bool:
    """Update ignore file information to database
    Args:
        codebase: Name of the codebase
        path: File path relative to codebase root
        md5: MD5 hash of the file
    Returns:
        bool: Whether the update was successful
    Raises:
        ValueError: When the codebase does not exist
    """
    codebase_obj = await db_operations.get_codebase(codebase)
    if not codebase_obj:
        raise ValueError(f"Codebase {codebase} not found")
    # Consistency fix: IgnoreFile stores `codebase_id: ObjectId`, not an
    # embedded `codebase` document.
    ignore_file = IgnoreFile(
        codebase_id=codebase_obj.id,
        path=path,
        md5=md5
    )
    return await db_operations.save_model(ignore_file)
@tool
async def update_library_result_to_db(
    codebase: str,
    path: str,
    abstract: str,
    md5: str = ""
) -> bool:
    """Update binary library file analysis results to database
    Args:
        codebase: Name of the codebase
        path: File path relative to codebase root
        abstract: Library file summary including:
            - Main functionality of the library
            - Dependencies
            - Usage scenarios
        md5: MD5 hash of the file (optional; defaults to empty)
    Returns:
        bool: Whether the update was successful
    Raises:
        ValueError: When the codebase does not exist
    """
    codebase_obj = await db_operations.get_codebase(codebase)
    if not codebase_obj:
        raise ValueError(f"Codebase {codebase} not found")
    # Consistency fix: BinaryLibrary stores `codebase_id: ObjectId` and now
    # declares a required `md5` field; supply both (md5 exposed as an
    # optional, backward-compatible parameter).
    library_file = BinaryLibrary(
        codebase_id=codebase_obj.id,
        path=path,
        abstract=abstract,
        md5=md5
    )
    return await db_operations.save_model(library_file)
@tool
async def update_tool_result_to_db(
    codebase: str,
    path: str,
    abstract: str,
    md5: str = ""
) -> bool:
    """Update binary tool file analysis results to database
    Args:
        codebase: Name of the codebase
        path: File path relative to codebase root
        abstract: Tool file summary including:
            - Main functionality of the tool
            - Command line argument descriptions
            - Usage examples
        md5: MD5 hash of the file (optional; defaults to empty)
    Returns:
        bool: Whether the update was successful
    Raises:
        ValueError: When the codebase does not exist
    """
    codebase_obj = await db_operations.get_codebase(codebase)
    if not codebase_obj:
        raise ValueError(f"Codebase {codebase} not found")
    # Consistency fix: BinaryTool stores `codebase_id: ObjectId` and now
    # declares a required `md5` field; supply both (md5 exposed as an
    # optional, backward-compatible parameter).
    tool_file = BinaryTool(
        codebase_id=codebase_obj.id,
        path=path,
        abstract=abstract,
        md5=md5
    )
    return await db_operations.save_model(tool_file)
@tool
def required_response_scan_result(session_id: str, status: str) -> Dict[str, str]:
"""Return scan result response
Args:
session_id: ID of the scan session
status: Status of the scan
Returns:
Dict[str, str]: Response containing session ID and status
"""
return { return {
'session_id': session_id, 'session_id': session_id,
'status': status 'status': status
} }
@tool
async def read_file(codebase: str, path: str) -> str:
    """Read file content
    Args:
        codebase: Name of the codebase
        path: File path relative to codebase root
    Returns:
        str: File content
    Raises:
        ValueError: When the codebase does not exist or file reading fails
    """
    # Resolve the codebase's workspace root before touching the file.
    root = await operations.get_codebase_root(codebase)
    if not root:
        raise ValueError(f"Codebase {codebase} not found")
    response = read_file_content(f"{root}/{path}")
    if response["status"] == "failure":
        raise ValueError(f"Failed to read file: {response['result']}")
    return response["result"]
@tool @tool
def upload_result_to_db(codebase: str, file_path: str, segments): async def is_outdated(codebase: str, path: str) -> bool:
pass """Check if file summary is outdated
Args:
codebase: Name of the codebase
path: File path relative to codebase root
Returns:
bool: True if file does not exist or MD5 does not match
"""
stored_md5 = await db_operations.get_file_md5(codebase, path)
if not stored_md5:
return True
root_path = await operations.get_codebase_root(codebase)
if not root_path:
raise ValueError(f"Codebase {codebase} not found")
full_path = f"{root_path}/{path}"
result = get_file_md5(full_path)
if result["status"] == "failure":
raise ValueError(f"Failed to get file MD5: {result['result']}")
return stored_md5 != result["result"]
@tool
async def get_abstract(codebase: str, path: str) -> Optional[str]:
    """Get file summary
    Args:
        codebase: Name of the codebase
        path: File path relative to codebase root
    Returns:
        Optional[str]: File summary, None if it doesn't exist
    """
    # Delegate lookup to the db layer; None propagates when no document exists.
    document = await db_operations.get_file_document(codebase, path)
    return document
# Toolset exported to the scan-file agent; every entry is a langchain @tool.
tools = [
    update_code_result_to_db,
    update_config_result_to_db,
    update_ignore_result_to_db,
    update_library_result_to_db,
    update_tool_result_to_db,
    required_response_scan_result,
    read_file,
    is_outdated,
    get_abstract
]

View File

@@ -59,12 +59,13 @@ class AgentTemplate:
mcp_client = MultiServerMCPClient(self.mcp_servers) mcp_client = MultiServerMCPClient(self.mcp_servers)
tools = await mcp_client.get_tools() + self.builtin_tools tools = await mcp_client.get_tools() + self.builtin_tools
prompt_template = self.prompt_template prompt_template = self.prompt_template
prompt_template = prompt_template.replace('{session_id}', session_id)
for param in self.template_params.keys(): for param in self.template_params.keys():
prompt_template = prompt_template.replace(f'{{{param}}}', self.template_params[param]) prompt_template = prompt_template.replace(f'{{{param}}}', self.template_params[param])
prompt = PromptTemplate.from_tools(prompt_template) prompt = PromptTemplate.from_template(prompt_template)
history = MongoDBChatMessageHistory( history = MongoDBChatMessageHistory(
connection_string=os.getenv("MONGODB_CONNECTION_STRING", ""), connection_string=os.getenv("MONGO_CONNECTION_STRING", None),
session_id=session_id, session_id=session_id,
database_name="ckb", database_name="ckb",
collection_name="session_history" collection_name="session_history"
@@ -72,7 +73,7 @@ class AgentTemplate:
memory = ConversationBufferWindowMemory( memory = ConversationBufferWindowMemory(
chat_memory=history, chat_memory=history,
memory_key="chat_memory", memory_key="chat_history",
k=self.history_window_size, k=self.history_window_size,
return_messages=False, return_messages=False,
) )

View File

@@ -0,0 +1,11 @@
from agents.agent_tasks.scan_file import tools, scan_file_task_description
from agents.agent_templates import AgentTemplate
from agents.prompts import general_prompt_param_builder
# Pre-configured agent template for the file-scanning task: binds the
# scan_file toolset and task description to a lightweight OpenAI model.
scan_file_agent_template = AgentTemplate()
scan_file_agent_template.set_model("gpt-4.1-nano", "openai")
scan_file_agent_template.set_builtin_tools(tools)
scan_file_agent_template.set_template_params(general_prompt_param_builder(
    agent_role="assistant",
    task_description=scan_file_task_description
))

View File

@@ -16,7 +16,7 @@ class MCPReactParser(AgentOutputParser):
return self.parse(text) return self.parse(text)
def parse(self, text: str) -> Union[AgentAction, AgentFinish]: def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
includes_final_answer = self.Final_Answer_ACTION in text includes_final_answer = self.FINAL_ANSWER_ACTION in text
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)" regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
action_match = re.search(regex, text, re.DOTALL) action_match = re.search(regex, text, re.DOTALL)

View File

@@ -1,7 +1,6 @@
general_prompt = """ general_prompt = """
You are {agent_role} You are {agent_role}
Session Id of current task is {session_id}
Your current task is
{task_description} {task_description}
@@ -29,10 +28,11 @@ You are {agent_role}
Final Answer: the final answer to the original question Final Answer: the final answer to the original question
User message: {user_msg}
Here is the conversation history: Here is the conversation history:
{chat_history} {chat_history}
User message: {user_msg}
{agent_scratchpad} {agent_scratchpad}
""" """

View File

@@ -1,9 +0,0 @@
from langchain_core.tools import tool
@tool
def res_tool_general_response(session_id: str, response: str):
return {
'type': 'response',
'session_id': session_id,
'response': response
}

View File

@@ -1,9 +1,8 @@
from odmantic import Model from odmantic import Model, ObjectId
from db_models.embedded_models.Codebase import Codebase
class BinaryLibrary(Model): class BinaryLibrary(Model):
codebase: Codebase codebase_id: ObjectId
type: str="library"
path: str path: str
abstract: str abstract: str
md5: str

View File

@@ -1,9 +1,9 @@
from odmantic import Model from odmantic import Model, ObjectId
from db_models.embedded_models.Codebase import Codebase
class BinaryTool(Model): class BinaryTool(Model):
codebase: Codebase codebase_id: ObjectId
type: str="tool"
path: str path: str
abstract: str abstract: str
md5: str

View File

@@ -1,13 +1,12 @@
from odmantic import Model from odmantic import Model, ObjectId
from typing import List from typing import List
from db_models.embedded_models.CodeSegment import CodeSegment from db_models.embedded_models.CodeSegment import CodeSegment
from db_models.embedded_models.Codebase import Codebase
class CodeFile(Model): class CodeFile(Model):
codebase: Codebase codebase_id: ObjectId
type: str type: str="code"
path: str path: str
md5: str md5: str
abstract: str abstract: str

View File

@@ -0,0 +1,8 @@
from odmantic import Model
class Codebase(Model):
    """A registered codebase checkout; (name, branch) acts as the lookup key."""
    # Codebase identifier; queried together with `branch`.
    name: str
    # Version tag recorded at registration — TODO confirm semantics.
    version: str
    # Git branch this document tracks (e.g. "main").
    branch: str
    # Remote repository URL used for clone/branch operations.
    repo: str

View File

@@ -1,11 +1,9 @@
from odmantic import Model from odmantic import Model, ObjectId
from db_models.embedded_models.Codebase import Codebase
class ConfigFile(Model): class ConfigFile(Model):
codebase: Codebase codebase_id: ObjectId
type: str type: str="config"
path: str path: str
md5: str md5: str
abstract: str abstract: str

View File

@@ -1,10 +1,9 @@
from odmantic import Model from odmantic import Model, ObjectId
from db_models.embedded_models.Codebase import Codebase
class Directory(Model): class Directory(Model):
codebase: Codebase codebase_id: ObjectId
path: str path: str
md5: str md5: str
abstract: str abstract: str

View File

@@ -1,9 +1,9 @@
from odmantic import Model from odmantic import Model, ObjectId
from typing import List from typing import List
from db_models.embedded_models.Codebase import Codebase from db_models.Codebase import Codebase
class Hotspot(Model): class Hotspot(Model):
codebase: Codebase codebase_id: ObjectId
topic: str topic: str
links: List[int] links: List[int]

View File

@@ -1,9 +1,8 @@
from odmantic import Model from odmantic import Model, ObjectId
from db_models.embedded_models.Codebase import Codebase
class IgnoreFile(Model): class IgnoreFile(Model):
codebase: Codebase codebase_id: ObjectId
type:str="ignore"
path: str path: str
md5: str md5: str

View File

@@ -0,0 +1,6 @@
from db_models.BinaryTool import BinaryTool
from db_models.BinaryLibrary import BinaryLibrary
from db_models.CodeFile import CodeFile
from db_models.ConfigFile import ConfigFile
from db_models.IgnoreFile import IgnoreFile
from db_models.Codebase import Codebase

View File

@@ -1,9 +0,0 @@
from odmantic import EmbeddedModel
class Codebase(EmbeddedModel):
name: str
version: str
branch: str
path: str
repo: str

View File

@@ -0,0 +1 @@
from db_models.embedded_models.CodeSegment import CodeSegment

View File

@@ -1,14 +1,18 @@
import os import os
from threading import Lock from threading import Lock
from pymongo import MongoClient from pymongo import MongoClient
from typing import Optional
from odmantic import AIOEngine
from motor.motor_asyncio import AsyncIOMotorClient
_client = None _client = None
_db = None _db = None
_lock = Lock() _lock = Lock()
_engine: Optional[AIOEngine] = None
def init_db(): def init_db():
global _client, _db global _client, _db, _engine
if _client is None: if _client is None:
with _lock: with _lock:
if _client is None: if _client is None:
@@ -21,6 +25,11 @@ def init_db():
if db_name not in _client.list_database_names(): if db_name not in _client.list_database_names():
_client[db_name].create_collection('session_history') _client[db_name].create_collection('session_history')
_db = _client[db_name] _db = _client[db_name]
if _engine is None:
client = AsyncIOMotorClient(uri)
_engine = AIOEngine(client=client, database=db_name)
return _db return _db
@@ -34,3 +43,16 @@ def get_client():
if _client is None: if _client is None:
init_db() init_db()
return _client return _client
def get_engine() -> AIOEngine:
if _engine is None:
raise RuntimeError("Database not initialized. Call init_db first.")
return _engine
def close_db():
global _engine
if _engine is not None:
_engine.client.close()
_engine = None

View File

@@ -0,0 +1,124 @@
from typing import Optional, Any, List
from db_models import CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool, Codebase
from utils.db_connections import get_engine
from utils.ssh_connections.ssh_operations import check_branch_exists, get_file_md5 as get_workspace_file_md5
async def get_codebase(name: str) -> Optional[Codebase]:
    """Look up a Codebase document by "name" or "name@branch".

    If the requested branch has no document yet but the branch exists on the
    remote repo (discovered via the codebase's 'main' record), a new Codebase
    document is created for that branch and returned.

    Raises:
        Exception: when the codebase (or requested branch) cannot be found,
            or the remote branch check errors out.
    """
    engine = get_engine()
    branch = 'main'
    if '@' in name:
        # NOTE(review): assumes at most one '@' — a second '@' makes this
        # unpack raise ValueError; confirm caller input.
        name, branch = name.split('@')
    code_base = await engine.find_one(
        Codebase,
        (Codebase.name == name) & (Codebase.branch == branch)
    )
    if not code_base:
        # Fall back to the 'main' record to learn the repo URL.
        code_base_main = await engine.find_one(
            Codebase,
            (Codebase.name == name) & (Codebase.branch == 'main')
        )
        if not code_base_main:
            raise Exception(f"Codebase {name} not found")
        branch_check = check_branch_exists(code_base_main.repo, branch)
        if branch_check.get("status") == "error":
            raise Exception(branch_check.get("result", "Error checking branch existence"))
        if branch_check.get("status") == "success" and branch_check.get("result", False):
            # Branch exists remotely but has no DB record yet — register it.
            new_code_base = Codebase(
                name=name,
                version=code_base_main.version,
                branch=branch,
                repo=code_base_main.repo
            )
            await engine.save(new_code_base)
            return new_code_base
        else:
            raise Exception(f"Branch {branch} not found for codebase {name}")
    return code_base
async def get_file_md5(codebase: str, path: str) -> Optional[str]:
    """Return the stored MD5 for a file, or None when no document exists or
    the document has no md5 field.

    Args:
        codebase: Codebase name, optionally "name@branch".
        path: File path relative to the codebase root.
    """
    doc = await get_file_document(codebase, path)
    # Bug fix: the original called `doc.hasattr('md5')` — model instances have
    # no `hasattr` method, so every hit raised AttributeError. Use the builtin.
    if doc is not None and hasattr(doc, 'md5'):
        return doc.md5
    return None
async def get_file_document(codebase: str, path: str):
    """Find the DB document for a file within a codebase/branch.

    Searches across all file model types. If the requested branch has no
    record for the path but another branch holds a document whose md5 matches
    the file currently in the workspace, that document is cloned onto this
    branch and returned.

    Returns:
        The matching model instance, or None.
    """
    engine = get_engine()
    branch = 'main'
    if '@' in codebase:
        codebase, branch = codebase.split('@')
    cb = await get_codebase(f"{codebase}@{branch}")
    if not cb:
        return None
    # Collect every candidate document for this path across all file types.
    all_files = []
    for model in [CodeFile, ConfigFile, BinaryLibrary, BinaryTool]:
        files = await engine.find(
            model,
            model.path == path
        )
        all_files.extend(files)
    if not all_files:
        return None
    # Prefer a document already owned by the requested branch.
    codebase_files = [file for file in all_files if file.codebase_id == cb.id]
    if codebase_files:
        return codebase_files[0]
    # Otherwise compare the workspace file's md5 against sibling-branch
    # documents; on a match, clone that document for this branch.
    workspace_path = f"~/workspace/{codebase}@{branch}/{path}"
    workspace_md5_result = get_workspace_file_md5(workspace_path)
    if workspace_md5_result.get("status") == "success":
        workspace_md5 = workspace_md5_result.get("result")
        for file in all_files:
            if hasattr(file, 'md5') and file.md5 == workspace_md5:
                # Copy all fields except identity/ownership, rebinding the new
                # document to the current branch's codebase id.
                new_file = type(file)(
                    codebase_id=cb.id,
                    path=file.path,
                    **{k: v for k, v in file.dict().items()
                       if k not in ['id', 'codebase_id', 'path']}
                )
                await engine.save(new_file)
                return new_file
    return None
async def save_model(model: Any) -> bool:
    """Persist *model* via the shared engine; return True on success,
    False (with a printed error) on any failure."""
    try:
        await get_engine().save(model)
    except Exception as exc:
        print(f"Error saving model: {exc}")
        return False
    return True
async def update_branch_documents(from_codebase_id: str, to_codebase_id: str):
    """Migrate every file document from one branch's codebase id to another,
    deleting any document at the same path already on the target branch."""
    engine = get_engine()
    model_types = [CodeFile, ConfigFile, IgnoreFile, BinaryLibrary, BinaryTool]
    for model in model_types:
        source_docs = await engine.find(model, model.codebase_id == from_codebase_id)
        for source_doc in source_docs:
            # Remove target-branch duplicates at the same path before moving.
            duplicates = await engine.find(
                model,
                (model.path == source_doc.path) & (model.codebase_id == to_codebase_id)
            )
            for duplicate in duplicates:
                await engine.delete(duplicate)
            source_doc.codebase_id = to_codebase_id
            await save_model(source_doc)

73
src/utils/operations.py Normal file
View File

@@ -0,0 +1,73 @@
from typing import Optional
from db_models import Codebase
from utils.db_connections.db_operations import get_codebase, save_model
from utils.ssh_connections.ssh_operations import (
clone_repo, check_branch_exists, create_and_push_branch,
checkout_branch, merge_branches, push_changes, pull_changes, is_dir_exist
)
from utils.db_connections.db_operations import update_branch_documents
async def get_codebase_root(name: str) -> Optional[str]:
    """Return the workspace directory for a codebase, cloning it on first use.

    Args:
        name: Codebase name, optionally "name@branch" (branch defaults to main).

    Returns:
        The remote workspace path, e.g. "~/workspace/name@branch".

    Raises:
        Exception: if the existence check errors, or the codebase cannot be
            resolved/cloned.
    """
    branch = 'main'
    if '@' in name:
        name, branch = name.split('@')
    # Bug fix: probe the actual clone location. The original tested the bare
    # relative path "{name}@{branch}" (relative to the SSH login dir), while
    # clone_repo clones into ~/workspace/... and that is the path returned.
    exist = is_dir_exist(f'~/workspace/{name}@{branch}')
    if exist.get('status', 'error') == 'error':
        raise Exception(exist.get('result', ''))
    if not exist.get('result', False):
        try:
            cb = await get_codebase(f'{name}@{branch}')
            clone_repo(cb.repo, branch)
        except Exception as e:
            # Chain the original error so a clone failure isn't silently
            # reported as a plain "not found".
            raise Exception(f"Codebase {name}@{branch} not found") from e
    return f'~/workspace/{name}@{branch}'
async def create_branch(codebase: str, from_branch: str, new_branch: str):
    """Create *new_branch* from *from_branch* and register it in the DB.

    Raises:
        Exception: if the branch already exists on the remote, creation/push
            fails, or checking the source branch back out fails.
    """
    cb = await get_codebase(f'{codebase}@{from_branch}')
    branch_check = check_branch_exists(cb.repo, new_branch)
    if branch_check.get("status") == "error":
        raise Exception(branch_check.get("result", "Error checking branch existence"))
    # NOTE(review): the default True means "assume it exists" when the result
    # key is missing — conservative against double-creation; confirm intent.
    if branch_check.get("status") == "success" and branch_check.get("result", True):
        raise Exception(f"Branch {new_branch} already exists for codebase {codebase}")
    create_result = create_and_push_branch(codebase, from_branch, new_branch)
    if create_result.get("status") != "success":
        raise Exception(f"Failed to create branch {new_branch}: {create_result.get('result', '')}")
    # Leave the workspace checked out on the branch we started from.
    checkout_result = checkout_branch(codebase, from_branch)
    if checkout_result.get("status") != "success":
        raise Exception(f"Failed to checkout back to branch {from_branch}: {checkout_result.get('result', '')}")
    new_codebase = Codebase(
        name=codebase,
        version=cb.version,
        branch=new_branch,
        repo=cb.repo
    )
    await save_model(new_codebase)
async def merge_branch(codebase: str, from_branch: str, to_branch: str):
    """Merge from_branch into to_branch, sync the remote and workspace, then
    migrate the source branch's DB documents to the target branch."""
    from_cb = await get_codebase(f'{codebase}@{from_branch}')
    to_cb = await get_codebase(f'{codebase}@{to_branch}')
    if not (from_cb and to_cb):
        raise Exception(f"One or both branches not found: {from_branch}, {to_branch}")

    def _require_success(step_result, message):
        # Raise with the step's error detail whenever an SSH git step fails.
        if step_result.get("status") != "success":
            raise Exception(f"{message}: {step_result.get('result', '')}")

    _require_success(merge_branches(codebase, from_branch, to_branch),
                     f"Failed to merge branch {from_branch} into {to_branch}")
    _require_success(push_changes(codebase, to_branch),
                     "Failed to push merged changes to remote")
    _require_success(pull_changes(codebase, to_branch),
                     "Failed to pull latest changes from remote")
    await update_branch_documents(from_cb.id, to_cb.id)
    return {"status": "success", "message": f"Successfully merged branch {from_branch} into {to_branch}"}

View File

@@ -1,18 +1,18 @@
import os import os
import paramiko import paramiko
from threading import Lock from threading import Lock
from typing import Tuple, Optional, List, Dict, Any
import json from utils.ssh_connections.ssh_operations import is_dir_exist, make_dir
class SSHConnectionManager: class SSHConnectionManager:
_clients = {} _clients = {}
_lock = Lock() _lock = Lock()
HOST = os.getenv('SSH_HOST', 'host.docker.internal') HOST = os.getenv('CKB_SSH_HOST', 'host.docker.internal')
USERNAME = os.getenv('SSH_USERNAME') USERNAME = os.getenv('CKB_SSH_USERNAME')
PORT = os.getenv('SSH_PORT', 22) PORT = os.getenv('CKB_SSH_PORT', 22)
PASSWORD = os.getenv('SSH_PASSWORD') PASSWORD = os.getenv('CKB_SSH_PASSWORD', None)
@classmethod @classmethod
def get_client(cls, timeout=10): def get_client(cls, timeout=10):
@@ -34,79 +34,8 @@ class SSHConnectionManager:
return cls._clients[key] return cls._clients[key]
def execute_command(command: str, timeout: int = 30) -> Tuple[int, str, str]: def init_workspace():
client = SSHConnectionManager.get_client(timeout=timeout) make_dir('~/workspace')
stdin, stdout, stderr = client.exec_command(command, timeout=timeout)
exit_code = stdout.channel.recv_exit_status()
return exit_code, stdout.read().decode('utf-8'), stderr.read().decode('utf-8')
def list_directory(path: str, include_ignore: bool = True) -> Dict[str, Any]:
try:
client = SSHConnectionManager.get_client()
sftp = client.open_sftp()
files = sftp.listdir_attr(path)
result = []
for file in files:
if not include_ignore and file.filename.startswith('.'):
continue
result.append({
'name': file.filename,
'size': file.st_size,
'mode': file.st_mode,
'mtime': file.st_mtime,
'is_dir': file.st_mode & 0o40000 != 0
})
sftp.close()
return {"status": "success", "result": result}
except Exception as e:
return {"status": "failure", "result": str(e)}
def read_file_content(path: str) -> Dict[str, Any]:
try:
client = SSHConnectionManager.get_client()
sftp = client.open_sftp()
with sftp.open(path, 'r') as f:
content = f.read().decode('utf-8')
sftp.close()
return {"status": "success", "result": content}
except Exception as e:
return {"status": "failure", "result": str(e)}
def write_file_content(path: str, content: str) -> Dict[str, Any]:
try:
client = SSHConnectionManager.get_client()
sftp = client.open_sftp()
with sftp.open(path, 'w') as f:
f.write(content)
sftp.close()
return {"status": "success", "result": None}
except Exception as e:
return {"status": "failure", "result": str(e)}
def get_file_md5(path: str) -> Dict[str, Any]:
try:
exit_code, stdout, stderr = execute_command(f"md5sum {path}")
if exit_code == 0:
md5 = stdout.split()[0]
return {"status": "success", "result": md5}
return {"status": "failure", "result": stderr}
except Exception as e:
return {"status": "failure", "result": str(e)}
def execute_in_sandbox(command: str, timeout: int = 30) -> Dict[str, Any]:
try:
sandbox_cmd = f"docker run --rm --network none --memory=512m --cpus=1 alpine sh -c '{command}'"
exit_code, stdout, stderr = execute_command(sandbox_cmd, timeout)
if exit_code == 0:
return {"status": "success", "result": stdout}
return {"status": "failure", "result": stderr}
except Exception as e:
return {"status": "failure", "result": str(e)}

View File

@@ -0,0 +1,185 @@
from utils.ssh_connections import SSHConnectionManager
from typing import Tuple, Dict, Any
def execute_command(command: str, timeout: int = 30) -> Tuple[int, str, str]:
    """Run *command* on the remote host over the shared SSH connection.

    Blocks until the remote process exits.

    Returns:
        (exit_code, stdout, stderr) with both streams decoded as UTF-8.
    """
    client = SSHConnectionManager.get_client(timeout=timeout)
    stdin, stdout, stderr = client.exec_command(command, timeout=timeout)
    # recv_exit_status() blocks until the remote process terminates.
    exit_code = stdout.channel.recv_exit_status()
    return exit_code, stdout.read().decode('utf-8'), stderr.read().decode('utf-8')
def _test(exit_code: int):
if exit_code == 0:
return {"status": "success", "result": True}
return {"status": "failure", "result": False}
def make_dir(path: str) -> Dict[str, Any]:
    """Create *path* (including parents) on the remote host via `mkdir -p`."""
    try:
        status_code, _stdout, _stderr = execute_command(f"mkdir -p {path}")
    except Exception as exc:
        return {"status": "error", "result": str(exc)}
    return _test(status_code)
def is_file_exist(path: str) -> Dict[str, Any]:
    """Return success/True when *path* is a regular file on the remote host."""
    try:
        status_code, _stdout, _stderr = execute_command(f"test -f {path}")
    except Exception as exc:
        return {"status": "error", "result": str(exc)}
    return _test(status_code)
def is_dir_exist(path: str) -> Dict[str, Any]:
    """Return success/True when *path* is a directory on the remote host."""
    try:
        status_code, _stdout, _stderr = execute_command(f"test -d {path}")
    except Exception as exc:
        return {"status": "error", "result": str(exc)}
    return _test(status_code)
def list_directory(path: str, include_ignore: bool = True) -> Dict[str, Any]:
    """List remote directory entries via SFTP.

    Args:
        path: Remote directory to list.
        include_ignore: when False, entries starting with '.' are skipped.

    Returns:
        {"status": "success", "result": [entry, ...]} where each entry has
        name, size, mode, mtime and an is_dir flag; or
        {"status": "failure", "result": <error text>}.
    """
    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            result = [
                {
                    'name': file.filename,
                    'size': file.st_size,
                    'mode': file.st_mode,
                    'mtime': file.st_mtime,
                    # Directory bit (S_IFDIR) in the remote stat mode.
                    'is_dir': file.st_mode & 0o40000 != 0
                }
                for file in sftp.listdir_attr(path)
                if include_ignore or not file.filename.startswith('.')
            ]
        finally:
            # Leak fix: the original only closed the SFTP session on the
            # success path, leaking the channel when listdir_attr raised.
            sftp.close()
        return {"status": "success", "result": result}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
def read_file_content(path: str) -> Dict[str, Any]:
    """Read a remote file over SFTP and return it decoded as UTF-8.

    Returns:
        {"status": "success", "result": <text>} or
        {"status": "failure", "result": <error text>}.
    """
    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            with sftp.open(path, 'r') as f:
                content = f.read().decode('utf-8')
        finally:
            # Leak fix: close the SFTP session even when open/read raises;
            # the original leaked the channel on failure.
            sftp.close()
        return {"status": "success", "result": content}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
def write_file_content(path: str, content: str) -> Dict[str, Any]:
    """Write *content* to remote file *path* via SFTP, replacing any existing file."""
    try:
        client = SSHConnectionManager.get_client()
        sftp = client.open_sftp()
        try:
            with sftp.open(path, 'w') as f:
                f.write(content)
        finally:
            sftp.close()  # previously leaked when open/write raised
        return {"status": "success", "result": None}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
def get_file_md5(path: str) -> Dict[str, Any]:
    """Compute the MD5 digest of remote file *path* using ``md5sum``.

    On success the result is the hex digest string; on failure, stderr text.
    """
    try:
        # shlex.quote guards against spaces and shell metacharacters in path.
        exit_code, stdout, stderr = execute_command(f"md5sum {shlex.quote(path)}")
        if exit_code == 0:
            # md5sum prints "<digest>  <path>"; keep only the digest.
            md5 = stdout.split()[0]
            return {"status": "success", "result": md5}
        return {"status": "failure", "result": stderr}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
def execute_in_sandbox(command: str, timeout: int = 30) -> Dict[str, Any]:
    """Run *command* inside an isolated Alpine container on the remote host.

    The container has no network, 512 MB memory and 1 CPU. Returns stdout on
    success, stderr on failure.
    """
    try:
        # shlex.quote instead of literal single quotes: the old f"... '{command}'"
        # broke (and allowed shell escape) whenever command contained a quote.
        sandbox_cmd = (
            "docker run --rm --network none --memory=512m --cpus=1 "
            f"alpine sh -c {shlex.quote(command)}"
        )
        exit_code, stdout, stderr = execute_command(sandbox_cmd, timeout)
        if exit_code == 0:
            return {"status": "success", "result": stdout}
        return {"status": "failure", "result": stderr}
    except Exception as e:
        return {"status": "failure", "result": str(e)}
def clone_repo(url: str, branch: str = "main") -> Dict[str, Any]:
    """Clone *branch* of the repo at *url* into ``~/workspace/<name>@<branch>``.

    Uses a long timeout (2h) to accommodate large repositories.
    """
    try:
        name = url.split("/")[-1].replace(".git", "") + f'@{branch}'
        # Quote user-supplied pieces; keep the bare ~/ so the shell expands it.
        exit_code, _stdout, _stderr = execute_command(
            f"git clone -b {shlex.quote(branch)} {shlex.quote(url)} "
            f"~/workspace/{shlex.quote(name)}",
            timeout=7200,
        )
        return _test(exit_code)
    except Exception as e:
        return {"status": "error", "result": str(e)}
def remove_codebase(codebase: str) -> Dict[str, Any]:
    """Delete ``~/workspace/<codebase>`` on the remote host.

    A codebase without an explicit ``@branch`` suffix defaults to ``@main``,
    matching the naming scheme used by ``clone_repo``.
    """
    try:
        if '@' not in codebase:
            codebase = f'{codebase}@main'
        # Fixed: third value is stderr (was mis-unpacked as "stdin").
        # Quoting is essential here — an unquoted name with spaces or
        # metacharacters would make rm -rf hit unintended paths.
        exit_code, _stdout, _stderr = execute_command(
            f"rm -rf ~/workspace/{shlex.quote(codebase)}"
        )
        return _test(exit_code)
    except Exception as e:
        return {"status": "error", "result": str(e)}
def check_branch_exists(repo_url: str, branch: str) -> Dict[str, Any]:
    """Check whether *branch* exists on the remote repository *repo_url*.

    Counts matching refs via ``git ls-remote``; result is True when found.
    """
    try:
        # Quote user-supplied values against shell metacharacters.
        command = (
            f"git ls-remote --heads {shlex.quote(repo_url)} "
            f"{shlex.quote(f'refs/heads/{branch}')} | wc -l"
        )
        exit_code, stdout, stderr = execute_command(command)
        if exit_code == 0:
            count = int(stdout.strip())
            return {"status": "success", "result": count > 0}
        return {"status": "failure", "result": stderr}
    except Exception as e:
        return {"status": "error", "result": str(e)}
def create_and_push_branch(codebase: str, from_branch: str, new_branch: str) -> Dict[str, Any]:
    """Create *new_branch* from the *from_branch* checkout of *codebase* and push it."""
    try:
        # Quote the workdir and branch names against shell metacharacters;
        # the bare ~/ is kept unquoted so the shell expands it.
        workdir = shlex.quote(f"{codebase}@{from_branch}")
        branch_q = shlex.quote(new_branch)
        create_cmd = (
            f"cd ~/workspace/{workdir} && "
            f"git checkout -b {branch_q} && git push origin {branch_q}"
        )
        exit_code, stdout, stderr = execute_command(create_cmd)
        if exit_code != 0:
            return {"status": "failure", "result": stderr}
        return {"status": "success", "result": stdout}
    except Exception as e:
        return {"status": "error", "result": str(e)}
def checkout_branch(codebase: str, branch: str) -> Dict[str, Any]:
    """Check out *branch* inside the ``<codebase>@<branch>`` working copy."""
    try:
        # Quote the workdir and branch name against shell metacharacters.
        workdir = shlex.quote(f"{codebase}@{branch}")
        checkout_cmd = f"cd ~/workspace/{workdir} && git checkout {shlex.quote(branch)}"
        exit_code, stdout, stderr = execute_command(checkout_cmd)
        if exit_code != 0:
            return {"status": "failure", "result": stderr}
        return {"status": "success", "result": stdout}
    except Exception as e:
        return {"status": "error", "result": str(e)}
def merge_branches(codebase: str, from_branch: str, to_branch: str) -> Dict[str, Any]:
    """Merge ``origin/<from_branch>`` into the *to_branch* checkout of *codebase*.

    Conflicts are auto-resolved in favour of the incoming branch (``-X theirs``).
    """
    try:
        # Quote the workdir and branch name against shell metacharacters.
        workdir = shlex.quote(f"{codebase}@{to_branch}")
        merge_cmd = (
            f"cd ~/workspace/{workdir} && git fetch origin && "
            f"git merge {shlex.quote(f'origin/{from_branch}')} -X theirs"
        )
        exit_code, stdout, stderr = execute_command(merge_cmd)
        if exit_code != 0:
            return {"status": "failure", "result": stderr}
        return {"status": "success", "result": stdout}
    except Exception as e:
        return {"status": "error", "result": str(e)}
def push_changes(codebase: str, branch: str) -> Dict[str, Any]:
    """Push local commits of the *branch* checkout of *codebase* to origin."""
    try:
        # Quote the workdir and branch name against shell metacharacters.
        workdir = shlex.quote(f"{codebase}@{branch}")
        push_cmd = f"cd ~/workspace/{workdir} && git push origin {shlex.quote(branch)}"
        exit_code, stdout, stderr = execute_command(push_cmd)
        if exit_code != 0:
            return {"status": "failure", "result": stderr}
        return {"status": "success", "result": stdout}
    except Exception as e:
        return {"status": "error", "result": str(e)}
def pull_changes(codebase: str, branch: str) -> Dict[str, Any]:
    """Pull the latest commits from origin into the *branch* checkout of *codebase*."""
    try:
        # Quote the workdir and branch name against shell metacharacters.
        workdir = shlex.quote(f"{codebase}@{branch}")
        pull_cmd = f"cd ~/workspace/{workdir} && git pull origin {shlex.quote(branch)}"
        exit_code, stdout, stderr = execute_command(pull_cmd)
        if exit_code != 0:
            return {"status": "failure", "result": stderr}
        return {"status": "success", "result": stdout}
    except Exception as e:
        return {"status": "error", "result": str(e)}