From dd319014e7bc706a45be7b25b574d955a16b9ba3 Mon Sep 17 00:00:00 2001 From: zhinianboke <115088296+zhinianboke@users.noreply.github.com> Date: Sat, 16 Aug 2025 12:53:49 +0800 Subject: [PATCH] Update ai_reply_engine.py --- ai_reply_engine.py | 179 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 145 insertions(+), 34 deletions(-) diff --git a/ai_reply_engine.py b/ai_reply_engine.py index bbf5f56..f3a9d2e 100644 --- a/ai_reply_engine.py +++ b/ai_reply_engine.py @@ -7,6 +7,7 @@ import os import json import time import sqlite3 +import requests from typing import List, Dict, Optional from loguru import logger from openai import OpenAI @@ -60,17 +61,108 @@ class AIReplyEngine: return None try: + logger.info(f"创建OpenAI客户端 {cookie_id}: base_url={settings['base_url']}, api_key={'***' + settings['api_key'][-4:] if settings['api_key'] else 'None'}") self.clients[cookie_id] = OpenAI( api_key=settings['api_key'], base_url=settings['base_url'] ) - logger.info(f"为账号 {cookie_id} 创建OpenAI客户端") + logger.info(f"为账号 {cookie_id} 创建OpenAI客户端成功,实际base_url: {self.clients[cookie_id].base_url}") except Exception as e: logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}") return None return self.clients[cookie_id] - + + def _is_dashscope_api(self, settings: dict) -> bool: + """判断是否为DashScope API - 只有选择自定义模型时才使用""" + model_name = settings.get('model_name', '') + base_url = settings.get('base_url', '') + + # 只有当模型名称为"custom"或"自定义"时,才使用DashScope API格式 + # 其他情况都使用OpenAI兼容格式 + is_custom_model = model_name.lower() in ['custom', '自定义', 'dashscope', 'qwen-custom'] + is_dashscope_url = 'dashscope.aliyuncs.com' in base_url + + logger.info(f"API类型判断: model_name={model_name}, is_custom_model={is_custom_model}, is_dashscope_url={is_dashscope_url}") + + return is_custom_model and is_dashscope_url + + def _call_dashscope_api(self, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str: + """调用DashScope API""" + # 提取app_id从base_url + base_url = settings['base_url'] + if '/apps/' in base_url: + app_id = base_url.split('/apps/')[-1].split('/')[0] + else: + raise ValueError("DashScope API URL中未找到app_id") + + # 构建请求URL + url = f"https://dashscope.aliyuncs.com/api/v1/apps/{app_id}/completion" + + # 构建提示词(将messages合并为单个prompt) + system_content = "" + user_content = "" + + for msg in messages: + if msg['role'] == 'system': + system_content = msg['content'] + elif msg['role'] == 'user': + user_content = msg['content'] + + # 构建更清晰的prompt格式 + if system_content and user_content: + prompt = f"{system_content}\n\n用户问题:{user_content}\n\n请直接回答用户的问题:" + elif user_content: + prompt = user_content + else: + prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + + # 构建请求数据 + data = { + "input": { + "prompt": prompt + }, + "parameters": { + "max_tokens": max_tokens, + "temperature": temperature + }, + "debug": {} + } + + headers = { + "Authorization": f"Bearer {settings['api_key']}", + "Content-Type": "application/json" + } + + logger.info(f"DashScope API请求: {url}") + logger.info(f"发送的prompt: {prompt}") + logger.debug(f"请求数据: {json.dumps(data, ensure_ascii=False)}") + + response = requests.post(url, headers=headers, json=data, timeout=30) + + if response.status_code != 200: + logger.error(f"DashScope API请求失败: {response.status_code} - {response.text}") + raise Exception(f"DashScope API请求失败: {response.status_code} - {response.text}") + + result = response.json() + logger.debug(f"DashScope API响应: {json.dumps(result, ensure_ascii=False)}") + + # 提取回复内容 + if 'output' in result and 'text' in result['output']: + return result['output']['text'].strip() + else: + raise Exception(f"DashScope API响应格式错误: {result}") + + def _call_openai_api(self, client: OpenAI, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str: + """调用OpenAI兼容API""" + response = client.chat.completions.create( + model=settings['model_name'], + messages=messages, + max_tokens=max_tokens, + temperature=temperature + ) + return response.choices[0].message.content.strip() + def is_ai_enabled(self, cookie_id: str) -> bool: """检查指定账号是否启用AI回复""" settings = db_manager.get_ai_reply_settings(cookie_id) @@ -78,45 +170,55 @@ class AIReplyEngine: def detect_intent(self, message: str, cookie_id: str) -> str: """检测用户消息意图""" - client = self.get_client(cookie_id) - if not client: - return 'default' - try: settings = db_manager.get_ai_reply_settings(cookie_id) + if not settings['ai_enabled'] or not settings['api_key']: + return 'default' + custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {} classify_prompt = custom_prompts.get('classify', self.default_prompts['classify']) - - response = client.chat.completions.create( - model=settings['model_name'], - messages=[ - {"role": "system", "content": classify_prompt}, - {"role": "user", "content": message} - ], - max_tokens=10, - temperature=0.1 - ) - - intent = response.choices[0].message.content.strip().lower() + + # 打印调试信息 + logger.info(f"AI设置调试 {cookie_id}: base_url={settings['base_url']}, model={settings['model_name']}") + + messages = [ + {"role": "system", "content": classify_prompt}, + {"role": "user", "content": message} + ] + + # 根据API类型选择调用方式 + if self._is_dashscope_api(settings): + logger.info(f"使用DashScope API进行意图检测") + response_text = self._call_dashscope_api(settings, messages, max_tokens=10, temperature=0.1) + else: + logger.info(f"使用OpenAI兼容API进行意图检测") + client = self.get_client(cookie_id) + if not client: + return 'default' + logger.info(f"OpenAI客户端base_url: {client.base_url}") + response_text = self._call_openai_api(client, settings, messages, max_tokens=10, temperature=0.1) + + intent = response_text.lower() if intent in ['price', 'tech', 'default']: return intent else: return 'default' - + except Exception as e: logger.error(f"意图检测失败 {cookie_id}: {e}") + # 打印更详细的错误信息 + if hasattr(e, 'response') and hasattr(e.response, 'url'): + logger.error(f"请求URL: {e.response.url}") + if hasattr(e, 'request') and hasattr(e.request, 'url'): + logger.error(f"请求URL: {e.request.url}") return 'default' - def generate_reply(self, message: str, item_info: dict, chat_id: str, + def generate_reply(self, message: str, item_info: dict, chat_id: str, cookie_id: str, user_id: str, item_id: str) -> Optional[str]: """生成AI回复""" if not self.is_ai_enabled(cookie_id): return None - client = self.get_client(cookie_id) - if not client: - return None - try: # 1. 获取AI回复设置 settings = db_manager.get_ai_reply_settings(cookie_id) @@ -177,17 +279,21 @@ class AIReplyEngine: 请根据以上信息生成回复:""" # 10. 调用AI生成回复 - response = client.chat.completions.create( - model=settings['model_name'], - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ], - max_tokens=100, - temperature=0.7 - ) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt} + ] - reply = response.choices[0].message.content.strip() + # 根据API类型选择调用方式 + if self._is_dashscope_api(settings): + logger.info(f"使用DashScope API生成回复") + reply = self._call_dashscope_api(settings, messages, max_tokens=100, temperature=0.7) + else: + logger.info(f"使用OpenAI兼容API生成回复") + client = self.get_client(cookie_id) + if not client: + return None + reply = self._call_openai_api(client, settings, messages, max_tokens=100, temperature=0.7) # 11. 保存对话记录 self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent) @@ -202,6 +308,11 @@ class AIReplyEngine: except Exception as e: logger.error(f"AI回复生成失败 {cookie_id}: {e}") + # 打印更详细的错误信息 + if hasattr(e, 'response') and hasattr(e.response, 'url'): + logger.error(f"请求URL: {e.response.url}") + if hasattr(e, 'request') and hasattr(e.request, 'url'): + logger.error(f"请求URL: {e.request.url}") return None def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]: