mirror of
https://github.com/zhinianboke/xianyu-auto-reply.git
synced 2025-08-01 12:07:36 +08:00
274 lines
11 KiB
Python
274 lines
11 KiB
Python
"""
|
||
AI回复引擎模块
|
||
集成XianyuAutoAgent的AI回复功能到现有项目中
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import time
|
||
import sqlite3
|
||
from typing import List, Dict, Optional
|
||
from loguru import logger
|
||
from openai import OpenAI
|
||
from db_manager import db_manager
|
||
|
||
|
||
class AIReplyEngine:
|
||
"""AI回复引擎"""
|
||
|
||
def __init__(self):
|
||
self.clients = {} # 存储不同账号的OpenAI客户端
|
||
self.agents = {} # 存储不同账号的Agent实例
|
||
self._init_default_prompts()
|
||
|
||
def _init_default_prompts(self):
|
||
"""初始化默认提示词"""
|
||
self.default_prompts = {
|
||
'classify': '''你是一个意图分类专家,需要判断用户消息的意图类型。
|
||
请根据用户消息内容,返回以下意图之一:
|
||
- price: 价格相关(议价、优惠、降价等)
|
||
- tech: 技术相关(产品参数、使用方法、故障等)
|
||
- default: 其他一般咨询
|
||
|
||
只返回意图类型,不要其他内容。''',
|
||
|
||
'price': '''你是一位经验丰富的销售专家,擅长议价。
|
||
语言要求:简短直接,每句≤10字,总字数≤40字。
|
||
议价策略:
|
||
1. 根据议价次数递减优惠:第1次小幅优惠,第2次中等优惠,第3次最大优惠
|
||
2. 接近最大议价轮数时要坚持底线,强调商品价值
|
||
3. 优惠不能超过设定的最大百分比和金额
|
||
4. 语气要友好但坚定,突出商品优势
|
||
注意:结合商品信息、对话历史和议价设置,给出合适的回复。''',
|
||
|
||
'tech': '''你是一位技术专家,专业解答产品相关问题。
|
||
语言要求:简短专业,每句≤10字,总字数≤40字。
|
||
回答重点:产品功能、使用方法、注意事项。
|
||
注意:基于商品信息回答,避免过度承诺。''',
|
||
|
||
'default': '''你是一位资深电商卖家,提供优质客服。
|
||
语言要求:简短友好,每句≤10字,总字数≤40字。
|
||
回答重点:商品介绍、物流、售后等常见问题。
|
||
注意:结合商品信息,给出实用建议。'''
|
||
}
|
||
|
||
def get_client(self, cookie_id: str) -> Optional[OpenAI]:
|
||
"""获取指定账号的OpenAI客户端"""
|
||
if cookie_id not in self.clients:
|
||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||
if not settings['ai_enabled'] or not settings['api_key']:
|
||
return None
|
||
|
||
try:
|
||
self.clients[cookie_id] = OpenAI(
|
||
api_key=settings['api_key'],
|
||
base_url=settings['base_url']
|
||
)
|
||
logger.info(f"为账号 {cookie_id} 创建OpenAI客户端")
|
||
except Exception as e:
|
||
logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}")
|
||
return None
|
||
|
||
return self.clients[cookie_id]
|
||
|
||
def is_ai_enabled(self, cookie_id: str) -> bool:
|
||
"""检查指定账号是否启用AI回复"""
|
||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||
return settings['ai_enabled']
|
||
|
||
def detect_intent(self, message: str, cookie_id: str) -> str:
|
||
"""检测用户消息意图"""
|
||
client = self.get_client(cookie_id)
|
||
if not client:
|
||
return 'default'
|
||
|
||
try:
|
||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
|
||
classify_prompt = custom_prompts.get('classify', self.default_prompts['classify'])
|
||
|
||
response = client.chat.completions.create(
|
||
model=settings['model_name'],
|
||
messages=[
|
||
{"role": "system", "content": classify_prompt},
|
||
{"role": "user", "content": message}
|
||
],
|
||
max_tokens=10,
|
||
temperature=0.1
|
||
)
|
||
|
||
intent = response.choices[0].message.content.strip().lower()
|
||
if intent in ['price', 'tech', 'default']:
|
||
return intent
|
||
else:
|
||
return 'default'
|
||
|
||
except Exception as e:
|
||
logger.error(f"意图检测失败 {cookie_id}: {e}")
|
||
return 'default'
|
||
|
||
def generate_reply(self, message: str, item_info: dict, chat_id: str,
|
||
cookie_id: str, user_id: str, item_id: str) -> Optional[str]:
|
||
"""生成AI回复"""
|
||
if not self.is_ai_enabled(cookie_id):
|
||
return None
|
||
|
||
client = self.get_client(cookie_id)
|
||
if not client:
|
||
return None
|
||
|
||
try:
|
||
# 1. 获取AI回复设置
|
||
settings = db_manager.get_ai_reply_settings(cookie_id)
|
||
|
||
# 2. 检测意图
|
||
intent = self.detect_intent(message, cookie_id)
|
||
logger.info(f"检测到意图: {intent} (账号: {cookie_id})")
|
||
|
||
# 3. 获取对话历史
|
||
context = self.get_conversation_context(chat_id, cookie_id)
|
||
|
||
# 4. 获取议价次数
|
||
bargain_count = self.get_bargain_count(chat_id, cookie_id)
|
||
|
||
# 5. 检查议价轮数限制
|
||
if intent == "price":
|
||
max_bargain_rounds = settings.get('max_bargain_rounds', 3)
|
||
if bargain_count >= max_bargain_rounds:
|
||
logger.info(f"议价次数已达上限 ({bargain_count}/{max_bargain_rounds}),拒绝继续议价")
|
||
# 返回拒绝议价的回复
|
||
refuse_reply = f"抱歉,这个价格已经是最优惠的了,不能再便宜了哦!"
|
||
# 保存对话记录
|
||
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
|
||
self.save_conversation(chat_id, cookie_id, user_id, item_id, "assistant", refuse_reply, intent)
|
||
return refuse_reply
|
||
|
||
# 6. 构建提示词
|
||
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
|
||
system_prompt = custom_prompts.get(intent, self.default_prompts[intent])
|
||
|
||
# 7. 构建商品信息
|
||
item_desc = f"商品标题: {item_info.get('title', '未知')}\n"
|
||
item_desc += f"商品价格: {item_info.get('price', '未知')}元\n"
|
||
item_desc += f"商品描述: {item_info.get('desc', '无')}"
|
||
|
||
# 8. 构建对话历史
|
||
context_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in context[-10:]]) # 最近10条
|
||
|
||
# 9. 构建用户消息
|
||
max_bargain_rounds = settings.get('max_bargain_rounds', 3)
|
||
max_discount_percent = settings.get('max_discount_percent', 10)
|
||
max_discount_amount = settings.get('max_discount_amount', 100)
|
||
|
||
user_prompt = f"""商品信息:
|
||
{item_desc}
|
||
|
||
对话历史:
|
||
{context_str}
|
||
|
||
议价设置:
|
||
- 当前议价次数:{bargain_count}
|
||
- 最大议价轮数:{max_bargain_rounds}
|
||
- 最大优惠百分比:{max_discount_percent}%
|
||
- 最大优惠金额:{max_discount_amount}元
|
||
|
||
用户消息:{message}
|
||
|
||
请根据以上信息生成回复:"""
|
||
|
||
# 10. 调用AI生成回复
|
||
response = client.chat.completions.create(
|
||
model=settings['model_name'],
|
||
messages=[
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": user_prompt}
|
||
],
|
||
max_tokens=100,
|
||
temperature=0.7
|
||
)
|
||
|
||
reply = response.choices[0].message.content.strip()
|
||
|
||
# 11. 保存对话记录
|
||
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
|
||
self.save_conversation(chat_id, cookie_id, user_id, item_id, "assistant", reply, intent)
|
||
|
||
# 12. 更新议价次数
|
||
if intent == "price":
|
||
self.increment_bargain_count(chat_id, cookie_id)
|
||
|
||
logger.info(f"AI回复生成成功 (账号: {cookie_id}): {reply}")
|
||
return reply
|
||
|
||
except Exception as e:
|
||
logger.error(f"AI回复生成失败 {cookie_id}: {e}")
|
||
return None
|
||
|
||
def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]:
|
||
"""获取对话上下文"""
|
||
try:
|
||
with db_manager.lock:
|
||
cursor = db_manager.conn.cursor()
|
||
cursor.execute('''
|
||
SELECT role, content FROM ai_conversations
|
||
WHERE chat_id = ? AND cookie_id = ?
|
||
ORDER BY created_at DESC LIMIT ?
|
||
''', (chat_id, cookie_id, limit))
|
||
|
||
results = cursor.fetchall()
|
||
# 反转顺序,使其按时间正序
|
||
context = [{"role": row[0], "content": row[1]} for row in reversed(results)]
|
||
return context
|
||
except Exception as e:
|
||
logger.error(f"获取对话上下文失败: {e}")
|
||
return []
|
||
|
||
def save_conversation(self, chat_id: str, cookie_id: str, user_id: str,
|
||
item_id: str, role: str, content: str, intent: str = None):
|
||
"""保存对话记录"""
|
||
try:
|
||
with db_manager.lock:
|
||
cursor = db_manager.conn.cursor()
|
||
cursor.execute('''
|
||
INSERT INTO ai_conversations
|
||
(cookie_id, chat_id, user_id, item_id, role, content, intent)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||
''', (cookie_id, chat_id, user_id, item_id, role, content, intent))
|
||
db_manager.conn.commit()
|
||
except Exception as e:
|
||
logger.error(f"保存对话记录失败: {e}")
|
||
|
||
def get_bargain_count(self, chat_id: str, cookie_id: str) -> int:
|
||
"""获取议价次数"""
|
||
try:
|
||
with db_manager.lock:
|
||
cursor = db_manager.conn.cursor()
|
||
cursor.execute('''
|
||
SELECT COUNT(*) FROM ai_conversations
|
||
WHERE chat_id = ? AND cookie_id = ? AND intent = 'price' AND role = 'user'
|
||
''', (chat_id, cookie_id))
|
||
|
||
result = cursor.fetchone()
|
||
return result[0] if result else 0
|
||
except Exception as e:
|
||
logger.error(f"获取议价次数失败: {e}")
|
||
return 0
|
||
|
||
def increment_bargain_count(self, chat_id: str, cookie_id: str):
|
||
"""增加议价次数(通过保存记录自动增加)"""
|
||
# 议价次数通过查询price意图的用户消息数量来计算,无需单独操作
|
||
pass
|
||
|
||
def clear_client_cache(self, cookie_id: str = None):
|
||
"""清理客户端缓存"""
|
||
if cookie_id:
|
||
self.clients.pop(cookie_id, None)
|
||
logger.info(f"清理账号 {cookie_id} 的客户端缓存")
|
||
else:
|
||
self.clients.clear()
|
||
logger.info("清理所有客户端缓存")
|
||
|
||
|
||
# 全局AI回复引擎实例
|
||
ai_reply_engine = AIReplyEngine()
|