feat: add redis support for task state management
This commit is contained in:
@@ -1,35 +1,96 @@
|
||||
# State Management
|
||||
# This module is responsible for managing the state of the application.
|
||||
import math
|
||||
|
||||
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
|
||||
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
|
||||
|
||||
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
|
||||
# If your application is single-node, you can use memory to store the state.
|
||||
|
||||
import ast
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
import redis
|
||||
from app.config import config
|
||||
from app.models import const
|
||||
from app.utils import utils
|
||||
|
||||
_tasks = {}
|
||||
|
||||
|
||||
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||
"""
|
||||
Set the state of the task.
|
||||
"""
|
||||
progress = int(progress)
|
||||
if progress > 100:
|
||||
progress = 100
|
||||
# Base class for state management
|
||||
class BaseState(ABC):
|
||||
|
||||
_tasks[task_id] = {
|
||||
"state": state,
|
||||
"progress": progress,
|
||||
**kwargs,
|
||||
}
|
||||
@abstractmethod
|
||||
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
|
||||
pass
|
||||
|
||||
def get_task(task_id: str):
|
||||
"""
|
||||
Get the state of the task.
|
||||
"""
|
||||
return _tasks.get(task_id, None)
|
||||
@abstractmethod
|
||||
def get_task(self, task_id: str):
|
||||
pass
|
||||
|
||||
|
||||
# Memory state management
|
||||
class MemoryState(BaseState):
|
||||
|
||||
def __init__(self):
|
||||
self._tasks = {}
|
||||
|
||||
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||
progress = int(progress)
|
||||
if progress > 100:
|
||||
progress = 100
|
||||
|
||||
self._tasks[task_id] = {
|
||||
"state": state,
|
||||
"progress": progress,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def get_task(self, task_id: str):
|
||||
return self._tasks.get(task_id, None)
|
||||
|
||||
|
||||
# Redis state management
|
||||
class RedisState(BaseState):
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0):
|
||||
self._redis = redis.StrictRedis(host=host, port=port, db=db)
|
||||
|
||||
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||
progress = int(progress)
|
||||
if progress > 100:
|
||||
progress = 100
|
||||
|
||||
fields = {
|
||||
"state": state,
|
||||
"progress": progress,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
for field, value in fields.items():
|
||||
self._redis.hset(task_id, field, str(value))
|
||||
|
||||
def get_task(self, task_id: str):
|
||||
task_data = self._redis.hgetall(task_id)
|
||||
if not task_data:
|
||||
return None
|
||||
|
||||
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_original_type(value):
|
||||
"""
|
||||
Convert the value from byte string to its original data type.
|
||||
You can extend this method to handle other data types as needed.
|
||||
"""
|
||||
value_str = value.decode('utf-8')
|
||||
|
||||
try:
|
||||
# try to convert byte string array to list
|
||||
return ast.literal_eval(value_str)
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
if value_str.isdigit():
|
||||
return int(value_str)
|
||||
# Add more conversions here if needed
|
||||
return value_str
|
||||
|
||||
|
||||
# Global state
|
||||
_enable_redis = config.app.get("enable_redis", False)
|
||||
_redis_host = config.app.get("redis_host", "localhost")
|
||||
_redis_port = config.app.get("redis_port", 6379)
|
||||
_redis_db = config.app.get("redis_db", 0)
|
||||
|
||||
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()
|
||||
|
||||
Reference in New Issue
Block a user