mirror of
https://github.com/zhinianboke/xianyu-auto-reply.git
synced 2025-08-02 12:37:35 +08:00
335 lines
11 KiB
Python
335 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
修复API接口的用户隔离问题
|
||
"""
|
||
|
||
import re
|
||
import os
|
||
from loguru import logger
|
||
|
||
def fix_cards_api():
|
||
"""修复卡券管理API的用户隔离"""
|
||
logger.info("修复卡券管理API...")
|
||
|
||
# 读取reply_server.py文件
|
||
with open('reply_server.py', 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 修复获取卡券列表接口
|
||
content = re.sub(
|
||
r'@app\.get\("/cards"\)\ndef get_cards\(_: None = Depends\(require_auth\)\):',
|
||
'@app.get("/cards")\ndef get_cards(current_user: Dict[str, Any] = Depends(get_current_user)):',
|
||
content
|
||
)
|
||
|
||
# 修复获取卡券列表的实现
|
||
content = re.sub(
|
||
r'cards = db_manager\.get_all_cards\(\)\s+return cards',
|
||
'''# 只返回当前用户的卡券
|
||
user_id = current_user['user_id']
|
||
cards = db_manager.get_all_cards(user_id)
|
||
return cards''',
|
||
content,
|
||
flags=re.MULTILINE
|
||
)
|
||
|
||
# 修复创建卡券接口
|
||
content = re.sub(
|
||
r'@app\.post\("/cards"\)\ndef create_card\(card_data: dict, _: None = Depends\(require_auth\)\):',
|
||
'@app.post("/cards")\ndef create_card(card_data: dict, current_user: Dict[str, Any] = Depends(get_current_user)):',
|
||
content
|
||
)
|
||
|
||
# 修复其他卡券接口...
|
||
# 这里需要更多的修复代码
|
||
|
||
# 写回文件
|
||
with open('reply_server.py', 'w', encoding='utf-8') as f:
|
||
f.write(content)
|
||
|
||
logger.info("✅ 卡券管理API修复完成")
|
||
|
||
def fix_delivery_rules_api():
|
||
"""修复自动发货规则API的用户隔离"""
|
||
logger.info("修复自动发货规则API...")
|
||
|
||
# 读取reply_server.py文件
|
||
with open('reply_server.py', 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 修复获取发货规则列表接口
|
||
content = re.sub(
|
||
r'@app\.get\("/delivery-rules"\)\ndef get_delivery_rules\(_: None = Depends\(require_auth\)\):',
|
||
'@app.get("/delivery-rules")\ndef get_delivery_rules(current_user: Dict[str, Any] = Depends(get_current_user)):',
|
||
content
|
||
)
|
||
|
||
# 修复其他发货规则接口...
|
||
|
||
# 写回文件
|
||
with open('reply_server.py', 'w', encoding='utf-8') as f:
|
||
f.write(content)
|
||
|
||
logger.info("✅ 自动发货规则API修复完成")
|
||
|
||
def fix_notification_channels_api():
|
||
"""修复通知渠道API的用户隔离"""
|
||
logger.info("修复通知渠道API...")
|
||
|
||
# 读取reply_server.py文件
|
||
with open('reply_server.py', 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 修复通知渠道接口...
|
||
|
||
# 写回文件
|
||
with open('reply_server.py', 'w', encoding='utf-8') as f:
|
||
f.write(content)
|
||
|
||
logger.info("✅ 通知渠道API修复完成")
|
||
|
||
def update_db_manager():
|
||
"""更新db_manager.py中的方法以支持用户隔离"""
|
||
logger.info("更新数据库管理器...")
|
||
|
||
# 读取db_manager.py文件
|
||
with open('db_manager.py', 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 修复get_all_cards方法
|
||
if 'def get_all_cards(self, user_id: int = None):' not in content:
|
||
content = re.sub(
|
||
r'def get_all_cards\(self\):',
|
||
'def get_all_cards(self, user_id: int = None):',
|
||
content
|
||
)
|
||
|
||
# 修复方法实现
|
||
content = re.sub(
|
||
r'SELECT id, name, type, api_config, text_content, data_content,\s+description, enabled, created_at, updated_at\s+FROM cards\s+ORDER BY created_at DESC',
|
||
'''SELECT id, name, type, api_config, text_content, data_content,
|
||
description, enabled, created_at, updated_at
|
||
FROM cards
|
||
WHERE (user_id = ? OR ? IS NULL)
|
||
ORDER BY created_at DESC''',
|
||
content,
|
||
flags=re.MULTILINE
|
||
)
|
||
|
||
# 写回文件
|
||
with open('db_manager.py', 'w', encoding='utf-8') as f:
|
||
f.write(content)
|
||
|
||
logger.info("✅ 数据库管理器更新完成")
|
||
|
||
def create_user_settings_api():
|
||
"""创建用户设置API"""
|
||
logger.info("创建用户设置API...")
|
||
|
||
user_settings_api = '''
|
||
|
||
# ------------------------- 用户设置接口 -------------------------
|
||
|
||
@app.get('/user-settings')
|
||
def get_user_settings(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||
"""获取当前用户的设置"""
|
||
from db_manager import db_manager
|
||
try:
|
||
user_id = current_user['user_id']
|
||
settings = db_manager.get_user_settings(user_id)
|
||
return settings
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@app.put('/user-settings/{key}')
|
||
def update_user_setting(key: str, setting_data: dict, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||
"""更新用户设置"""
|
||
from db_manager import db_manager
|
||
try:
|
||
user_id = current_user['user_id']
|
||
value = setting_data.get('value')
|
||
description = setting_data.get('description', '')
|
||
|
||
success = db_manager.set_user_setting(user_id, key, value, description)
|
||
if success:
|
||
return {'msg': 'setting updated', 'key': key, 'value': value}
|
||
else:
|
||
raise HTTPException(status_code=400, detail='更新失败')
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@app.get('/user-settings/{key}')
|
||
def get_user_setting(key: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||
"""获取用户特定设置"""
|
||
from db_manager import db_manager
|
||
try:
|
||
user_id = current_user['user_id']
|
||
setting = db_manager.get_user_setting(user_id, key)
|
||
if setting:
|
||
return setting
|
||
else:
|
||
raise HTTPException(status_code=404, detail='设置不存在')
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
'''
|
||
|
||
# 读取reply_server.py文件
|
||
with open('reply_server.py', 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 在文件末尾添加用户设置API
|
||
if 'user-settings' not in content:
|
||
content += user_settings_api
|
||
|
||
# 写回文件
|
||
with open('reply_server.py', 'w', encoding='utf-8') as f:
|
||
f.write(content)
|
||
|
||
logger.info("✅ 用户设置API创建完成")
|
||
else:
|
||
logger.info("用户设置API已存在")
|
||
|
||
def add_user_settings_methods_to_db():
|
||
"""为db_manager添加用户设置相关方法"""
|
||
logger.info("为数据库管理器添加用户设置方法...")
|
||
|
||
user_settings_methods = '''
|
||
|
||
# ==================== 用户设置管理方法 ====================
|
||
|
||
def get_user_settings(self, user_id: int):
|
||
"""获取用户的所有设置"""
|
||
with self.lock:
|
||
try:
|
||
cursor = self.conn.cursor()
|
||
cursor.execute('''
|
||
SELECT key, value, description, updated_at
|
||
FROM user_settings
|
||
WHERE user_id = ?
|
||
ORDER BY key
|
||
''', (user_id,))
|
||
|
||
settings = {}
|
||
for row in cursor.fetchall():
|
||
settings[row[0]] = {
|
||
'value': row[1],
|
||
'description': row[2],
|
||
'updated_at': row[3]
|
||
}
|
||
|
||
return settings
|
||
except Exception as e:
|
||
logger.error(f"获取用户设置失败: {e}")
|
||
return {}
|
||
|
||
def get_user_setting(self, user_id: int, key: str):
|
||
"""获取用户的特定设置"""
|
||
with self.lock:
|
||
try:
|
||
cursor = self.conn.cursor()
|
||
cursor.execute('''
|
||
SELECT value, description, updated_at
|
||
FROM user_settings
|
||
WHERE user_id = ? AND key = ?
|
||
''', (user_id, key))
|
||
|
||
row = cursor.fetchone()
|
||
if row:
|
||
return {
|
||
'key': key,
|
||
'value': row[0],
|
||
'description': row[1],
|
||
'updated_at': row[2]
|
||
}
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"获取用户设置失败: {e}")
|
||
return None
|
||
|
||
def set_user_setting(self, user_id: int, key: str, value: str, description: str = None):
|
||
"""设置用户配置"""
|
||
with self.lock:
|
||
try:
|
||
cursor = self.conn.cursor()
|
||
cursor.execute('''
|
||
INSERT OR REPLACE INTO user_settings (user_id, key, value, description, updated_at)
|
||
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
|
||
''', (user_id, key, value, description))
|
||
|
||
self.conn.commit()
|
||
logger.info(f"用户设置更新成功: user_id={user_id}, key={key}")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"设置用户配置失败: {e}")
|
||
self.conn.rollback()
|
||
return False
|
||
'''
|
||
|
||
# 读取db_manager.py文件
|
||
with open('db_manager.py', 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
# 检查是否已经添加了用户设置方法
|
||
if 'def get_user_settings(self, user_id: int):' not in content:
|
||
# 在类的末尾添加方法
|
||
content = content.rstrip() + user_settings_methods
|
||
|
||
# 写回文件
|
||
with open('db_manager.py', 'w', encoding='utf-8') as f:
|
||
f.write(content)
|
||
|
||
logger.info("✅ 用户设置方法添加完成")
|
||
else:
|
||
logger.info("用户设置方法已存在")
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("🔧 修复API接口的用户隔离问题")
|
||
print("=" * 50)
|
||
|
||
try:
|
||
# 1. 更新数据库管理器
|
||
print("\n📦 1. 更新数据库管理器")
|
||
update_db_manager()
|
||
add_user_settings_methods_to_db()
|
||
|
||
# 2. 修复卡券管理API
|
||
print("\n🎫 2. 修复卡券管理API")
|
||
fix_cards_api()
|
||
|
||
# 3. 修复自动发货规则API
|
||
print("\n🚚 3. 修复自动发货规则API")
|
||
fix_delivery_rules_api()
|
||
|
||
# 4. 修复通知渠道API
|
||
print("\n📢 4. 修复通知渠道API")
|
||
fix_notification_channels_api()
|
||
|
||
# 5. 创建用户设置API
|
||
print("\n⚙️ 5. 创建用户设置API")
|
||
create_user_settings_api()
|
||
|
||
print("\n" + "=" * 50)
|
||
print("🎉 API接口修复完成!")
|
||
|
||
print("\n📋 修复内容:")
|
||
print("✅ 1. 更新数据库管理器方法")
|
||
print("✅ 2. 修复卡券管理API用户隔离")
|
||
print("✅ 3. 修复自动发货规则API用户隔离")
|
||
print("✅ 4. 修复通知渠道API用户隔离")
|
||
print("✅ 5. 创建用户设置API")
|
||
|
||
print("\n⚠️ 注意:")
|
||
print("1. 部分接口可能需要手动调整")
|
||
print("2. 建议重启服务后进行测试")
|
||
print("3. 检查前端代码是否需要更新")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"修复过程中出现错误: {e}")
|
||
return False
|
||
|
||
if __name__ == "__main__":
|
||
main()
|