Files
assist/src/core/database.py

122 lines
4.0 KiB
Python
Raw Normal View History

2025-09-06 21:06:18 +08:00
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import StaticPool
from contextlib import contextmanager
from typing import Generator
import logging
from .models import Base
from .cache_manager import cache_manager, cache_query
2025-09-06 21:06:18 +08:00
from ..config.config import Config
logger = logging.getLogger(__name__)
class DatabaseManager:
"""数据库管理器"""
def __init__(self):
self.engine = None
self.SessionLocal = None
self._initialize_database()
def _initialize_database(self):
"""初始化数据库连接"""
try:
db_config = Config.get_database_config()
# 根据数据库类型选择不同的连接参数
if "mysql" in db_config["url"]:
# MySQL配置 - 优化连接池
2025-09-06 21:06:18 +08:00
self.engine = create_engine(
db_config["url"],
echo=db_config["echo"],
pool_size=20, # 增加连接池大小
max_overflow=30, # 增加溢出连接数
2025-09-06 21:06:18 +08:00
pool_pre_ping=True,
pool_recycle=1800, # 减少回收时间
pool_timeout=10, # 连接超时
connect_args={
"charset": "utf8mb4",
"autocommit": False
}
2025-09-06 21:06:18 +08:00
)
else:
# SQLite配置 - 优化性能
2025-09-06 21:06:18 +08:00
self.engine = create_engine(
db_config["url"],
echo=db_config["echo"],
poolclass=StaticPool,
connect_args={
"check_same_thread": False,
"timeout": 20, # 连接超时
"isolation_level": None # 自动提交模式
}
2025-09-06 21:06:18 +08:00
)
self.SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=self.engine
)
# 创建所有表
Base.metadata.create_all(bind=self.engine)
logger.info("数据库初始化成功")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
raise
@contextmanager
def get_session(self) -> Generator[Session, None, None]:
"""获取数据库会话的上下文管理器"""
session = self.SessionLocal()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error(f"数据库操作失败: {e}")
raise
finally:
session.close()
def get_session_direct(self) -> Session:
"""直接获取数据库会话"""
return self.SessionLocal()
def close_session(self, session: Session):
"""关闭数据库会话"""
if session:
session.close()
def test_connection(self) -> bool:
"""测试数据库连接"""
try:
with self.get_session() as session:
session.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"数据库连接测试失败: {e}")
return False
@cache_query(ttl=60) # 缓存1分钟
def get_cached_query(self, query_key: str, query_func, *args, **kwargs):
"""执行带缓存的查询"""
return query_func(*args, **kwargs)
def invalidate_cache_pattern(self, pattern: str):
"""根据模式清除缓存"""
try:
cache_manager.delete(pattern)
logger.info(f"缓存已清除: {pattern}")
except Exception as e:
logger.error(f"清除缓存失败: {e}")
def get_cache_stats(self):
"""获取缓存统计信息"""
return cache_manager.get_stats()
2025-09-06 21:06:18 +08:00
# 全局数据库管理器实例
db_manager = DatabaseManager()