mirror of
https://github.com/zhinianboke/xianyu-auto-reply.git
synced 2025-08-29 17:17:38 +08:00
Update ai_reply_engine.py
This commit is contained in:
parent
e49f45ba6e
commit
dd319014e7
@ -7,6 +7,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import requests
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@ -60,17 +61,108 @@ class AIReplyEngine:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info(f"创建OpenAI客户端 {cookie_id}: base_url={settings['base_url']}, api_key={'***' + settings['api_key'][-4:] if settings['api_key'] else 'None'}")
|
||||||
self.clients[cookie_id] = OpenAI(
|
self.clients[cookie_id] = OpenAI(
|
||||||
api_key=settings['api_key'],
|
api_key=settings['api_key'],
|
||||||
base_url=settings['base_url']
|
base_url=settings['base_url']
|
||||||
)
|
)
|
||||||
logger.info(f"为账号 {cookie_id} 创建OpenAI客户端")
|
logger.info(f"为账号 {cookie_id} 创建OpenAI客户端成功,实际base_url: {self.clients[cookie_id].base_url}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}")
|
logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.clients[cookie_id]
|
return self.clients[cookie_id]
|
||||||
|
|
||||||
|
def _is_dashscope_api(self, settings: dict) -> bool:
|
||||||
|
"""判断是否为DashScope API - 只有选择自定义模型时才使用"""
|
||||||
|
model_name = settings.get('model_name', '')
|
||||||
|
base_url = settings.get('base_url', '')
|
||||||
|
|
||||||
|
# 只有当模型名称为"custom"或"自定义"时,才使用DashScope API格式
|
||||||
|
# 其他情况都使用OpenAI兼容格式
|
||||||
|
is_custom_model = model_name.lower() in ['custom', '自定义', 'dashscope', 'qwen-custom']
|
||||||
|
is_dashscope_url = 'dashscope.aliyuncs.com' in base_url
|
||||||
|
|
||||||
|
logger.info(f"API类型判断: model_name={model_name}, is_custom_model={is_custom_model}, is_dashscope_url={is_dashscope_url}")
|
||||||
|
|
||||||
|
return is_custom_model and is_dashscope_url
|
||||||
|
|
||||||
|
def _call_dashscope_api(self, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str:
|
||||||
|
"""调用DashScope API"""
|
||||||
|
# 提取app_id从base_url
|
||||||
|
base_url = settings['base_url']
|
||||||
|
if '/apps/' in base_url:
|
||||||
|
app_id = base_url.split('/apps/')[-1].split('/')[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("DashScope API URL中未找到app_id")
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
url = f"https://dashscope.aliyuncs.com/api/v1/apps/{app_id}/completion"
|
||||||
|
|
||||||
|
# 构建提示词(将messages合并为单个prompt)
|
||||||
|
system_content = ""
|
||||||
|
user_content = ""
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg['role'] == 'system':
|
||||||
|
system_content = msg['content']
|
||||||
|
elif msg['role'] == 'user':
|
||||||
|
user_content = msg['content']
|
||||||
|
|
||||||
|
# 构建更清晰的prompt格式
|
||||||
|
if system_content and user_content:
|
||||||
|
prompt = f"{system_content}\n\n用户问题:{user_content}\n\n请直接回答用户的问题:"
|
||||||
|
elif user_content:
|
||||||
|
prompt = user_content
|
||||||
|
else:
|
||||||
|
prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
|
|
||||||
|
# 构建请求数据
|
||||||
|
data = {
|
||||||
|
"input": {
|
||||||
|
"prompt": prompt
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": temperature
|
||||||
|
},
|
||||||
|
"debug": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {settings['api_key']}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"DashScope API请求: {url}")
|
||||||
|
logger.info(f"发送的prompt: {prompt}")
|
||||||
|
logger.debug(f"请求数据: {json.dumps(data, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
response = requests.post(url, headers=headers, json=data, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"DashScope API请求失败: {response.status_code} - {response.text}")
|
||||||
|
raise Exception(f"DashScope API请求失败: {response.status_code} - {response.text}")
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
logger.debug(f"DashScope API响应: {json.dumps(result, ensure_ascii=False)}")
|
||||||
|
|
||||||
|
# 提取回复内容
|
||||||
|
if 'output' in result and 'text' in result['output']:
|
||||||
|
return result['output']['text'].strip()
|
||||||
|
else:
|
||||||
|
raise Exception(f"DashScope API响应格式错误: {result}")
|
||||||
|
|
||||||
|
def _call_openai_api(self, client: OpenAI, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str:
|
||||||
|
"""调用OpenAI兼容API"""
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=settings['model_name'],
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content.strip()
|
||||||
|
|
||||||
def is_ai_enabled(self, cookie_id: str) -> bool:
|
def is_ai_enabled(self, cookie_id: str) -> bool:
|
||||||
"""检查指定账号是否启用AI回复"""
|
"""检查指定账号是否启用AI回复"""
|
||||||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||||||
@ -78,45 +170,55 @@ class AIReplyEngine:
|
|||||||
|
|
||||||
def detect_intent(self, message: str, cookie_id: str) -> str:
|
def detect_intent(self, message: str, cookie_id: str) -> str:
|
||||||
"""检测用户消息意图"""
|
"""检测用户消息意图"""
|
||||||
client = self.get_client(cookie_id)
|
|
||||||
if not client:
|
|
||||||
return 'default'
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||||||
|
if not settings['ai_enabled'] or not settings['api_key']:
|
||||||
|
return 'default'
|
||||||
|
|
||||||
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
|
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
|
||||||
classify_prompt = custom_prompts.get('classify', self.default_prompts['classify'])
|
classify_prompt = custom_prompts.get('classify', self.default_prompts['classify'])
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
# 打印调试信息
|
||||||
model=settings['model_name'],
|
logger.info(f"AI设置调试 {cookie_id}: base_url={settings['base_url']}, model={settings['model_name']}")
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": classify_prompt},
|
messages = [
|
||||||
{"role": "user", "content": message}
|
{"role": "system", "content": classify_prompt},
|
||||||
],
|
{"role": "user", "content": message}
|
||||||
max_tokens=10,
|
]
|
||||||
temperature=0.1
|
|
||||||
)
|
# 根据API类型选择调用方式
|
||||||
|
if self._is_dashscope_api(settings):
|
||||||
intent = response.choices[0].message.content.strip().lower()
|
logger.info(f"使用DashScope API进行意图检测")
|
||||||
|
response_text = self._call_dashscope_api(settings, messages, max_tokens=10, temperature=0.1)
|
||||||
|
else:
|
||||||
|
logger.info(f"使用OpenAI兼容API进行意图检测")
|
||||||
|
client = self.get_client(cookie_id)
|
||||||
|
if not client:
|
||||||
|
return 'default'
|
||||||
|
logger.info(f"OpenAI客户端base_url: {client.base_url}")
|
||||||
|
response_text = self._call_openai_api(client, settings, messages, max_tokens=10, temperature=0.1)
|
||||||
|
|
||||||
|
intent = response_text.lower()
|
||||||
if intent in ['price', 'tech', 'default']:
|
if intent in ['price', 'tech', 'default']:
|
||||||
return intent
|
return intent
|
||||||
else:
|
else:
|
||||||
return 'default'
|
return 'default'
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"意图检测失败 {cookie_id}: {e}")
|
logger.error(f"意图检测失败 {cookie_id}: {e}")
|
||||||
|
# 打印更详细的错误信息
|
||||||
|
if hasattr(e, 'response') and hasattr(e.response, 'url'):
|
||||||
|
logger.error(f"请求URL: {e.response.url}")
|
||||||
|
if hasattr(e, 'request') and hasattr(e.request, 'url'):
|
||||||
|
logger.error(f"请求URL: {e.request.url}")
|
||||||
return 'default'
|
return 'default'
|
||||||
|
|
||||||
def generate_reply(self, message: str, item_info: dict, chat_id: str,
|
def generate_reply(self, message: str, item_info: dict, chat_id: str,
|
||||||
cookie_id: str, user_id: str, item_id: str) -> Optional[str]:
|
cookie_id: str, user_id: str, item_id: str) -> Optional[str]:
|
||||||
"""生成AI回复"""
|
"""生成AI回复"""
|
||||||
if not self.is_ai_enabled(cookie_id):
|
if not self.is_ai_enabled(cookie_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
client = self.get_client(cookie_id)
|
|
||||||
if not client:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 获取AI回复设置
|
# 1. 获取AI回复设置
|
||||||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||||||
@ -177,17 +279,21 @@ class AIReplyEngine:
|
|||||||
请根据以上信息生成回复:"""
|
请根据以上信息生成回复:"""
|
||||||
|
|
||||||
# 10. 调用AI生成回复
|
# 10. 调用AI生成回复
|
||||||
response = client.chat.completions.create(
|
messages = [
|
||||||
model=settings['model_name'],
|
{"role": "system", "content": system_prompt},
|
||||||
messages=[
|
{"role": "user", "content": user_prompt}
|
||||||
{"role": "system", "content": system_prompt},
|
]
|
||||||
{"role": "user", "content": user_prompt}
|
|
||||||
],
|
|
||||||
max_tokens=100,
|
|
||||||
temperature=0.7
|
|
||||||
)
|
|
||||||
|
|
||||||
reply = response.choices[0].message.content.strip()
|
# 根据API类型选择调用方式
|
||||||
|
if self._is_dashscope_api(settings):
|
||||||
|
logger.info(f"使用DashScope API生成回复")
|
||||||
|
reply = self._call_dashscope_api(settings, messages, max_tokens=100, temperature=0.7)
|
||||||
|
else:
|
||||||
|
logger.info(f"使用OpenAI兼容API生成回复")
|
||||||
|
client = self.get_client(cookie_id)
|
||||||
|
if not client:
|
||||||
|
return None
|
||||||
|
reply = self._call_openai_api(client, settings, messages, max_tokens=100, temperature=0.7)
|
||||||
|
|
||||||
# 11. 保存对话记录
|
# 11. 保存对话记录
|
||||||
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
|
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
|
||||||
@ -202,6 +308,11 @@ class AIReplyEngine:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"AI回复生成失败 {cookie_id}: {e}")
|
logger.error(f"AI回复生成失败 {cookie_id}: {e}")
|
||||||
|
# 打印更详细的错误信息
|
||||||
|
if hasattr(e, 'response') and hasattr(e.response, 'url'):
|
||||||
|
logger.error(f"请求URL: {e.response.url}")
|
||||||
|
if hasattr(e, 'request') and hasattr(e.request, 'url'):
|
||||||
|
logger.error(f"请求URL: {e.request.url}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]:
|
def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user