mirror of
https://github.com/zhinianboke/xianyu-auto-reply.git
synced 2025-08-29 17:17:38 +08:00
Update ai_reply_engine.py
This commit is contained in:
parent
e49f45ba6e
commit
dd319014e7
@ -7,6 +7,7 @@ import os
|
||||
import json
|
||||
import time
|
||||
import sqlite3
|
||||
import requests
|
||||
from typing import List, Dict, Optional
|
||||
from loguru import logger
|
||||
from openai import OpenAI
|
||||
@ -60,17 +61,108 @@ class AIReplyEngine:
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.info(f"创建OpenAI客户端 {cookie_id}: base_url={settings['base_url']}, api_key={'***' + settings['api_key'][-4:] if settings['api_key'] else 'None'}")
|
||||
self.clients[cookie_id] = OpenAI(
|
||||
api_key=settings['api_key'],
|
||||
base_url=settings['base_url']
|
||||
)
|
||||
logger.info(f"为账号 {cookie_id} 创建OpenAI客户端")
|
||||
logger.info(f"为账号 {cookie_id} 创建OpenAI客户端成功,实际base_url: {self.clients[cookie_id].base_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}")
|
||||
return None
|
||||
|
||||
return self.clients[cookie_id]
|
||||
|
||||
def _is_dashscope_api(self, settings: dict) -> bool:
|
||||
"""判断是否为DashScope API - 只有选择自定义模型时才使用"""
|
||||
model_name = settings.get('model_name', '')
|
||||
base_url = settings.get('base_url', '')
|
||||
|
||||
# 只有当模型名称为"custom"或"自定义"时,才使用DashScope API格式
|
||||
# 其他情况都使用OpenAI兼容格式
|
||||
is_custom_model = model_name.lower() in ['custom', '自定义', 'dashscope', 'qwen-custom']
|
||||
is_dashscope_url = 'dashscope.aliyuncs.com' in base_url
|
||||
|
||||
logger.info(f"API类型判断: model_name={model_name}, is_custom_model={is_custom_model}, is_dashscope_url={is_dashscope_url}")
|
||||
|
||||
return is_custom_model and is_dashscope_url
|
||||
|
||||
def _call_dashscope_api(self, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str:
|
||||
"""调用DashScope API"""
|
||||
# 提取app_id从base_url
|
||||
base_url = settings['base_url']
|
||||
if '/apps/' in base_url:
|
||||
app_id = base_url.split('/apps/')[-1].split('/')[0]
|
||||
else:
|
||||
raise ValueError("DashScope API URL中未找到app_id")
|
||||
|
||||
# 构建请求URL
|
||||
url = f"https://dashscope.aliyuncs.com/api/v1/apps/{app_id}/completion"
|
||||
|
||||
# 构建提示词(将messages合并为单个prompt)
|
||||
system_content = ""
|
||||
user_content = ""
|
||||
|
||||
for msg in messages:
|
||||
if msg['role'] == 'system':
|
||||
system_content = msg['content']
|
||||
elif msg['role'] == 'user':
|
||||
user_content = msg['content']
|
||||
|
||||
# 构建更清晰的prompt格式
|
||||
if system_content and user_content:
|
||||
prompt = f"{system_content}\n\n用户问题:{user_content}\n\n请直接回答用户的问题:"
|
||||
elif user_content:
|
||||
prompt = user_content
|
||||
else:
|
||||
prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
|
||||
# 构建请求数据
|
||||
data = {
|
||||
"input": {
|
||||
"prompt": prompt
|
||||
},
|
||||
"parameters": {
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature
|
||||
},
|
||||
"debug": {}
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {settings['api_key']}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
logger.info(f"DashScope API请求: {url}")
|
||||
logger.info(f"发送的prompt: {prompt}")
|
||||
logger.debug(f"请求数据: {json.dumps(data, ensure_ascii=False)}")
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, timeout=30)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DashScope API请求失败: {response.status_code} - {response.text}")
|
||||
raise Exception(f"DashScope API请求失败: {response.status_code} - {response.text}")
|
||||
|
||||
result = response.json()
|
||||
logger.debug(f"DashScope API响应: {json.dumps(result, ensure_ascii=False)}")
|
||||
|
||||
# 提取回复内容
|
||||
if 'output' in result and 'text' in result['output']:
|
||||
return result['output']['text'].strip()
|
||||
else:
|
||||
raise Exception(f"DashScope API响应格式错误: {result}")
|
||||
|
||||
def _call_openai_api(self, client: OpenAI, settings: dict, messages: list, max_tokens: int = 100, temperature: float = 0.7) -> str:
|
||||
"""调用OpenAI兼容API"""
|
||||
response = client.chat.completions.create(
|
||||
model=settings['model_name'],
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature
|
||||
)
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
def is_ai_enabled(self, cookie_id: str) -> bool:
|
||||
"""检查指定账号是否启用AI回复"""
|
||||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||||
@ -78,26 +170,35 @@ class AIReplyEngine:
|
||||
|
||||
def detect_intent(self, message: str, cookie_id: str) -> str:
|
||||
"""检测用户消息意图"""
|
||||
client = self.get_client(cookie_id)
|
||||
if not client:
|
||||
return 'default'
|
||||
|
||||
try:
|
||||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||||
if not settings['ai_enabled'] or not settings['api_key']:
|
||||
return 'default'
|
||||
|
||||
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
|
||||
classify_prompt = custom_prompts.get('classify', self.default_prompts['classify'])
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=settings['model_name'],
|
||||
messages=[
|
||||
# 打印调试信息
|
||||
logger.info(f"AI设置调试 {cookie_id}: base_url={settings['base_url']}, model={settings['model_name']}")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": classify_prompt},
|
||||
{"role": "user", "content": message}
|
||||
],
|
||||
max_tokens=10,
|
||||
temperature=0.1
|
||||
)
|
||||
]
|
||||
|
||||
intent = response.choices[0].message.content.strip().lower()
|
||||
# 根据API类型选择调用方式
|
||||
if self._is_dashscope_api(settings):
|
||||
logger.info(f"使用DashScope API进行意图检测")
|
||||
response_text = self._call_dashscope_api(settings, messages, max_tokens=10, temperature=0.1)
|
||||
else:
|
||||
logger.info(f"使用OpenAI兼容API进行意图检测")
|
||||
client = self.get_client(cookie_id)
|
||||
if not client:
|
||||
return 'default'
|
||||
logger.info(f"OpenAI客户端base_url: {client.base_url}")
|
||||
response_text = self._call_openai_api(client, settings, messages, max_tokens=10, temperature=0.1)
|
||||
|
||||
intent = response_text.lower()
|
||||
if intent in ['price', 'tech', 'default']:
|
||||
return intent
|
||||
else:
|
||||
@ -105,6 +206,11 @@ class AIReplyEngine:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"意图检测失败 {cookie_id}: {e}")
|
||||
# 打印更详细的错误信息
|
||||
if hasattr(e, 'response') and hasattr(e.response, 'url'):
|
||||
logger.error(f"请求URL: {e.response.url}")
|
||||
if hasattr(e, 'request') and hasattr(e.request, 'url'):
|
||||
logger.error(f"请求URL: {e.request.url}")
|
||||
return 'default'
|
||||
|
||||
def generate_reply(self, message: str, item_info: dict, chat_id: str,
|
||||
@ -113,10 +219,6 @@ class AIReplyEngine:
|
||||
if not self.is_ai_enabled(cookie_id):
|
||||
return None
|
||||
|
||||
client = self.get_client(cookie_id)
|
||||
if not client:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. 获取AI回复设置
|
||||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||||
@ -177,17 +279,21 @@ class AIReplyEngine:
|
||||
请根据以上信息生成回复:"""
|
||||
|
||||
# 10. 调用AI生成回复
|
||||
response = client.chat.completions.create(
|
||||
model=settings['model_name'],
|
||||
messages=[
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
max_tokens=100,
|
||||
temperature=0.7
|
||||
)
|
||||
]
|
||||
|
||||
reply = response.choices[0].message.content.strip()
|
||||
# 根据API类型选择调用方式
|
||||
if self._is_dashscope_api(settings):
|
||||
logger.info(f"使用DashScope API生成回复")
|
||||
reply = self._call_dashscope_api(settings, messages, max_tokens=100, temperature=0.7)
|
||||
else:
|
||||
logger.info(f"使用OpenAI兼容API生成回复")
|
||||
client = self.get_client(cookie_id)
|
||||
if not client:
|
||||
return None
|
||||
reply = self._call_openai_api(client, settings, messages, max_tokens=100, temperature=0.7)
|
||||
|
||||
# 11. 保存对话记录
|
||||
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
|
||||
@ -202,6 +308,11 @@ class AIReplyEngine:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI回复生成失败 {cookie_id}: {e}")
|
||||
# 打印更详细的错误信息
|
||||
if hasattr(e, 'response') and hasattr(e.response, 'url'):
|
||||
logger.error(f"请求URL: {e.response.url}")
|
||||
if hasattr(e, 'request') and hasattr(e.request, 'url'):
|
||||
logger.error(f"请求URL: {e.request.url}")
|
||||
return None
|
||||
|
||||
def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user