Add get_all_tasks() endpoint
This commit is contained in:
@@ -94,6 +94,22 @@ def create_task(
|
|||||||
task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}"
|
task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from fastapi import Query
|
||||||
|
|
||||||
|
@router.get("/tasks", response_model=TaskQueryResponse, summary="Get all tasks")
|
||||||
|
def get_all_tasks(request: Request, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
|
||||||
|
request_id = base.get_task_id(request)
|
||||||
|
tasks, total = sm.state.get_all_tasks(page, page_size)
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"tasks": tasks,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
}
|
||||||
|
return utils.get_response(200, response)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status"
|
"/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status"
|
||||||
|
|||||||
@@ -15,12 +15,23 @@ class BaseState(ABC):
|
|||||||
def get_task(self, task_id: str):
|
def get_task(self, task_id: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all_tasks(self, page: int, page_size: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Memory state management
|
# Memory state management
|
||||||
class MemoryState(BaseState):
|
class MemoryState(BaseState):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._tasks = {}
|
self._tasks = {}
|
||||||
|
|
||||||
|
def get_all_tasks(self, page: int, page_size: int):
|
||||||
|
start = (page - 1) * page_size
|
||||||
|
end = start + page_size
|
||||||
|
tasks = list(self._tasks.values())
|
||||||
|
total = len(tasks)
|
||||||
|
return tasks[start:end], total
|
||||||
|
|
||||||
def update_task(
|
def update_task(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -33,6 +44,7 @@ class MemoryState(BaseState):
|
|||||||
progress = 100
|
progress = 100
|
||||||
|
|
||||||
self._tasks[task_id] = {
|
self._tasks[task_id] = {
|
||||||
|
"task_id": task_id,
|
||||||
"state": state,
|
"state": state,
|
||||||
"progress": progress,
|
"progress": progress,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -53,6 +65,28 @@ class RedisState(BaseState):
|
|||||||
|
|
||||||
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
|
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
|
||||||
|
|
||||||
|
def get_all_tasks(self, page: int, page_size: int):
|
||||||
|
start = (page - 1) * page_size
|
||||||
|
end = start + page_size
|
||||||
|
tasks = []
|
||||||
|
cursor = 0
|
||||||
|
total = 0
|
||||||
|
while True:
|
||||||
|
cursor, keys = self._redis.scan(cursor, count=page_size)
|
||||||
|
total += len(keys)
|
||||||
|
if total > start:
|
||||||
|
for key in keys[max(0, start - total):end - total]:
|
||||||
|
task_data = self._redis.hgetall(key)
|
||||||
|
task = {
|
||||||
|
k.decode("utf-8"): self._convert_to_original_type(v) for k, v in task_data.items()
|
||||||
|
}
|
||||||
|
tasks.append(task)
|
||||||
|
if len(tasks) >= page_size:
|
||||||
|
break
|
||||||
|
if cursor == 0 or len(tasks) >= page_size:
|
||||||
|
break
|
||||||
|
return tasks, total
|
||||||
|
|
||||||
def update_task(
|
def update_task(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -65,6 +99,7 @@ class RedisState(BaseState):
|
|||||||
progress = 100
|
progress = 100
|
||||||
|
|
||||||
fields = {
|
fields = {
|
||||||
|
"task_id": task_id,
|
||||||
"state": state,
|
"state": state,
|
||||||
"progress": progress,
|
"progress": progress,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user