Format project code

This commit is contained in:
yyhhyyyyyy
2024-07-24 14:59:06 +08:00
parent bbd4e94941
commit 905841965a
18 changed files with 410 additions and 214 deletions

View File

@@ -6,7 +6,6 @@ from app.models import const
# Base class for state management
class BaseState(ABC):
@abstractmethod
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
pass
@@ -18,11 +17,16 @@ class BaseState(ABC):
# Memory state management
class MemoryState(BaseState):
def __init__(self):
self._tasks = {}
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress)
if progress > 100:
progress = 100
@@ -43,12 +47,18 @@ class MemoryState(BaseState):
# Redis state management
class RedisState(BaseState):
def __init__(self, host='localhost', port=6379, db=0, password=None):
def __init__(self, host="localhost", port=6379, db=0, password=None):
import redis
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
def update_task(
self,
task_id: str,
state: int = const.TASK_STATE_PROCESSING,
progress: int = 0,
**kwargs,
):
progress = int(progress)
if progress > 100:
progress = 100
@@ -67,7 +77,10 @@ class RedisState(BaseState):
if not task_data:
return None
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
task = {
key.decode("utf-8"): self._convert_to_original_type(value)
for key, value in task_data.items()
}
return task
def delete_task(self, task_id: str):
@@ -79,7 +92,7 @@ class RedisState(BaseState):
Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed.
"""
value_str = value.decode('utf-8')
value_str = value.decode("utf-8")
try:
# try to convert byte string array to list
@@ -100,4 +113,10 @@ _redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password) if _enable_redis else MemoryState()
state = (
RedisState(
host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password
)
if _enable_redis
else MemoryState()
)