Update ai_reply_engine.py

This commit is contained in:
zhinianboke 2025-08-16 12:53:49 +08:00
parent e49f45ba6e
commit dd319014e7

View File

@ -7,6 +7,7 @@ import os
import json import json
import time import time
import sqlite3 import sqlite3
import requests
from typing import List, Dict, Optional from typing import List, Dict, Optional
from loguru import logger from loguru import logger
from openai import OpenAI from openai import OpenAI
@ -60,17 +61,108 @@ class AIReplyEngine:
return None return None
try: try:
logger.info(f"创建OpenAI客户端 {cookie_id}: base_url={settings['base_url']}, api_key={'***' + settings['api_key'][-4:] if settings['api_key'] else 'None'}")
self.clients[cookie_id] = OpenAI( self.clients[cookie_id] = OpenAI(
api_key=settings['api_key'], api_key=settings['api_key'],
base_url=settings['base_url'] base_url=settings['base_url']
) )
logger.info(f"为账号 {cookie_id} 创建OpenAI客户端") logger.info(f"为账号 {cookie_id} 创建OpenAI客户端成功实际base_url: {self.clients[cookie_id].base_url}")
except Exception as e: except Exception as e:
logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}") logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}")
return None return None
return self.clients[cookie_id] return self.clients[cookie_id]
def _is_dashscope_api(self, settings: dict) -> bool:
"""判断是否为DashScope API - 只有选择自定义模型时才使用"""
model_name = settings.get('model_name', '')
base_url = settings.get('base_url', '')
# 只有当模型名称为"custom"或"自定义"时才使用DashScope API格式
# 其他情况都使用OpenAI兼容格式
is_custom_model = model_name.lower() in ['custom', '自定义', 'dashscope', 'qwen-custom']
is_dashscope_url = 'dashscope.aliyuncs.com' in base_url
logger.info(f"API类型判断: model_name={model_name}, is_custom_model={is_custom_model}, is_dashscope_url={is_dashscope_url}")
return is_custom_model and is_dashscope_url
def _call_dashscope_api(self, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str:
"""调用DashScope API"""
# 提取app_id从base_url
base_url = settings['base_url']
if '/apps/' in base_url:
app_id = base_url.split('/apps/')[-1].split('/')[0]
else:
raise ValueError("DashScope API URL中未找到app_id")
# 构建请求URL
url = f"https://dashscope.aliyuncs.com/api/v1/apps/{app_id}/completion"
# 构建提示词将messages合并为单个prompt
system_content = ""
user_content = ""
for msg in messages:
if msg['role'] == 'system':
system_content = msg['content']
elif msg['role'] == 'user':
user_content = msg['content']
# 构建更清晰的prompt格式
if system_content and user_content:
prompt = f"{system_content}\n\n用户问题:{user_content}\n\n请直接回答用户的问题:"
elif user_content:
prompt = user_content
else:
prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# 构建请求数据
data = {
"input": {
"prompt": prompt
},
"parameters": {
"max_tokens": max_tokens,
"temperature": temperature
},
"debug": {}
}
headers = {
"Authorization": f"Bearer {settings['api_key']}",
"Content-Type": "application/json"
}
logger.info(f"DashScope API请求: {url}")
logger.info(f"发送的prompt: {prompt}")
logger.debug(f"请求数据: {json.dumps(data, ensure_ascii=False)}")
response = requests.post(url, headers=headers, json=data, timeout=30)
if response.status_code != 200:
logger.error(f"DashScope API请求失败: {response.status_code} - {response.text}")
raise Exception(f"DashScope API请求失败: {response.status_code} - {response.text}")
result = response.json()
logger.debug(f"DashScope API响应: {json.dumps(result, ensure_ascii=False)}")
# 提取回复内容
if 'output' in result and 'text' in result['output']:
return result['output']['text'].strip()
else:
raise Exception(f"DashScope API响应格式错误: {result}")
def _call_openai_api(self, client: OpenAI, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str:
"""调用OpenAI兼容API"""
response = client.chat.completions.create(
model=settings['model_name'],
messages=messages,
max_tokens=max_tokens,
temperature=temperature
)
return response.choices[0].message.content.strip()
def is_ai_enabled(self, cookie_id: str) -> bool: def is_ai_enabled(self, cookie_id: str) -> bool:
"""检查指定账号是否启用AI回复""" """检查指定账号是否启用AI回复"""
settings = db_manager.get_ai_reply_settings(cookie_id) settings = db_manager.get_ai_reply_settings(cookie_id)
@ -78,26 +170,35 @@ class AIReplyEngine:
def detect_intent(self, message: str, cookie_id: str) -> str: def detect_intent(self, message: str, cookie_id: str) -> str:
"""检测用户消息意图""" """检测用户消息意图"""
client = self.get_client(cookie_id)
if not client:
return 'default'
try: try:
settings = db_manager.get_ai_reply_settings(cookie_id) settings = db_manager.get_ai_reply_settings(cookie_id)
if not settings['ai_enabled'] or not settings['api_key']:
return 'default'
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {} custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
classify_prompt = custom_prompts.get('classify', self.default_prompts['classify']) classify_prompt = custom_prompts.get('classify', self.default_prompts['classify'])
response = client.chat.completions.create( # 打印调试信息
model=settings['model_name'], logger.info(f"AI设置调试 {cookie_id}: base_url={settings['base_url']}, model={settings['model_name']}")
messages=[
{"role": "system", "content": classify_prompt},
{"role": "user", "content": message}
],
max_tokens=10,
temperature=0.1
)
intent = response.choices[0].message.content.strip().lower() messages = [
{"role": "system", "content": classify_prompt},
{"role": "user", "content": message}
]
# 根据API类型选择调用方式
if self._is_dashscope_api(settings):
logger.info(f"使用DashScope API进行意图检测")
response_text = self._call_dashscope_api(settings, messages, max_tokens=10, temperature=0.1)
else:
logger.info(f"使用OpenAI兼容API进行意图检测")
client = self.get_client(cookie_id)
if not client:
return 'default'
logger.info(f"OpenAI客户端base_url: {client.base_url}")
response_text = self._call_openai_api(client, settings, messages, max_tokens=10, temperature=0.1)
intent = response_text.lower()
if intent in ['price', 'tech', 'default']: if intent in ['price', 'tech', 'default']:
return intent return intent
else: else:
@ -105,6 +206,11 @@ class AIReplyEngine:
except Exception as e: except Exception as e:
logger.error(f"意图检测失败 {cookie_id}: {e}") logger.error(f"意图检测失败 {cookie_id}: {e}")
# 打印更详细的错误信息
if hasattr(e, 'response') and hasattr(e.response, 'url'):
logger.error(f"请求URL: {e.response.url}")
if hasattr(e, 'request') and hasattr(e.request, 'url'):
logger.error(f"请求URL: {e.request.url}")
return 'default' return 'default'
def generate_reply(self, message: str, item_info: dict, chat_id: str, def generate_reply(self, message: str, item_info: dict, chat_id: str,
@ -113,10 +219,6 @@ class AIReplyEngine:
if not self.is_ai_enabled(cookie_id): if not self.is_ai_enabled(cookie_id):
return None return None
client = self.get_client(cookie_id)
if not client:
return None
try: try:
# 1. 获取AI回复设置 # 1. 获取AI回复设置
settings = db_manager.get_ai_reply_settings(cookie_id) settings = db_manager.get_ai_reply_settings(cookie_id)
@ -177,17 +279,21 @@ class AIReplyEngine:
请根据以上信息生成回复""" 请根据以上信息生成回复"""
# 10. 调用AI生成回复 # 10. 调用AI生成回复
response = client.chat.completions.create( messages = [
model=settings['model_name'], {"role": "system", "content": system_prompt},
messages=[ {"role": "user", "content": user_prompt}
{"role": "system", "content": system_prompt}, ]
{"role": "user", "content": user_prompt}
],
max_tokens=100,
temperature=0.7
)
reply = response.choices[0].message.content.strip() # 根据API类型选择调用方式
if self._is_dashscope_api(settings):
logger.info(f"使用DashScope API生成回复")
reply = self._call_dashscope_api(settings, messages, max_tokens=100, temperature=0.7)
else:
logger.info(f"使用OpenAI兼容API生成回复")
client = self.get_client(cookie_id)
if not client:
return None
reply = self._call_openai_api(client, settings, messages, max_tokens=100, temperature=0.7)
# 11. 保存对话记录 # 11. 保存对话记录
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent) self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
@ -202,6 +308,11 @@ class AIReplyEngine:
except Exception as e: except Exception as e:
logger.error(f"AI回复生成失败 {cookie_id}: {e}") logger.error(f"AI回复生成失败 {cookie_id}: {e}")
# 打印更详细的错误信息
if hasattr(e, 'response') and hasattr(e.response, 'url'):
logger.error(f"请求URL: {e.response.url}")
if hasattr(e, 'request') and hasattr(e.request, 'url'):
logger.error(f"请求URL: {e.request.url}")
return None return None
def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]: def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]: