xianyu-auto-reply/ai_reply_engine.py
2025-07-24 14:16:18 +08:00

274 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI回复引擎模块
集成XianyuAutoAgent的AI回复功能到现有项目中
"""
import os
import json
import time
import sqlite3
from typing import List, Dict, Optional
from loguru import logger
from openai import OpenAI
from db_manager import db_manager
class AIReplyEngine:
"""AI回复引擎"""
def __init__(self):
self.clients = {} # 存储不同账号的OpenAI客户端
self.agents = {} # 存储不同账号的Agent实例
self._init_default_prompts()
def _init_default_prompts(self):
"""初始化默认提示词"""
self.default_prompts = {
'classify': '''你是一个意图分类专家,需要判断用户消息的意图类型。
请根据用户消息内容,返回以下意图之一:
- price: 价格相关(议价、优惠、降价等)
- tech: 技术相关(产品参数、使用方法、故障等)
- default: 其他一般咨询
只返回意图类型,不要其他内容。''',
'price': '''你是一位经验丰富的销售专家,擅长议价。
语言要求简短直接每句≤10字总字数≤40字。
议价策略:
1. 根据议价次数递减优惠第1次小幅优惠第2次中等优惠第3次最大优惠
2. 接近最大议价轮数时要坚持底线,强调商品价值
3. 优惠不能超过设定的最大百分比和金额
4. 语气要友好但坚定,突出商品优势
注意:结合商品信息、对话历史和议价设置,给出合适的回复。''',
'tech': '''你是一位技术专家,专业解答产品相关问题。
语言要求简短专业每句≤10字总字数≤40字。
回答重点:产品功能、使用方法、注意事项。
注意:基于商品信息回答,避免过度承诺。''',
'default': '''你是一位资深电商卖家,提供优质客服。
语言要求简短友好每句≤10字总字数≤40字。
回答重点:商品介绍、物流、售后等常见问题。
注意:结合商品信息,给出实用建议。'''
}
def get_client(self, cookie_id: str) -> Optional[OpenAI]:
"""获取指定账号的OpenAI客户端"""
if cookie_id not in self.clients:
settings = db_manager.get_ai_reply_settings(cookie_id)
if not settings['ai_enabled'] or not settings['api_key']:
return None
try:
self.clients[cookie_id] = OpenAI(
api_key=settings['api_key'],
base_url=settings['base_url']
)
logger.info(f"为账号 {cookie_id} 创建OpenAI客户端")
except Exception as e:
logger.error(f"创建OpenAI客户端失败 {cookie_id}: {e}")
return None
return self.clients[cookie_id]
def is_ai_enabled(self, cookie_id: str) -> bool:
"""检查指定账号是否启用AI回复"""
settings = db_manager.get_ai_reply_settings(cookie_id)
return settings['ai_enabled']
def detect_intent(self, message: str, cookie_id: str) -> str:
"""检测用户消息意图"""
client = self.get_client(cookie_id)
if not client:
return 'default'
try:
settings = db_manager.get_ai_reply_settings(cookie_id)
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
classify_prompt = custom_prompts.get('classify', self.default_prompts['classify'])
response = client.chat.completions.create(
model=settings['model_name'],
messages=[
{"role": "system", "content": classify_prompt},
{"role": "user", "content": message}
],
max_tokens=10,
temperature=0.1
)
intent = response.choices[0].message.content.strip().lower()
if intent in ['price', 'tech', 'default']:
return intent
else:
return 'default'
except Exception as e:
logger.error(f"意图检测失败 {cookie_id}: {e}")
return 'default'
def generate_reply(self, message: str, item_info: dict, chat_id: str,
cookie_id: str, user_id: str, item_id: str) -> Optional[str]:
"""生成AI回复"""
if not self.is_ai_enabled(cookie_id):
return None
client = self.get_client(cookie_id)
if not client:
return None
try:
# 1. 获取AI回复设置
settings = db_manager.get_ai_reply_settings(cookie_id)
# 2. 检测意图
intent = self.detect_intent(message, cookie_id)
logger.info(f"检测到意图: {intent} (账号: {cookie_id})")
# 3. 获取对话历史
context = self.get_conversation_context(chat_id, cookie_id)
# 4. 获取议价次数
bargain_count = self.get_bargain_count(chat_id, cookie_id)
# 5. 检查议价轮数限制
if intent == "price":
max_bargain_rounds = settings.get('max_bargain_rounds', 3)
if bargain_count >= max_bargain_rounds:
logger.info(f"议价次数已达上限 ({bargain_count}/{max_bargain_rounds}),拒绝继续议价")
# 返回拒绝议价的回复
refuse_reply = f"抱歉,这个价格已经是最优惠的了,不能再便宜了哦!"
# 保存对话记录
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
self.save_conversation(chat_id, cookie_id, user_id, item_id, "assistant", refuse_reply, intent)
return refuse_reply
# 6. 构建提示词
custom_prompts = json.loads(settings['custom_prompts']) if settings['custom_prompts'] else {}
system_prompt = custom_prompts.get(intent, self.default_prompts[intent])
# 7. 构建商品信息
item_desc = f"商品标题: {item_info.get('title', '未知')}\n"
item_desc += f"商品价格: {item_info.get('price', '未知')}\n"
item_desc += f"商品描述: {item_info.get('desc', '')}"
# 8. 构建对话历史
context_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in context[-10:]]) # 最近10条
# 9. 构建用户消息
max_bargain_rounds = settings.get('max_bargain_rounds', 3)
max_discount_percent = settings.get('max_discount_percent', 10)
max_discount_amount = settings.get('max_discount_amount', 100)
user_prompt = f"""商品信息:
{item_desc}
对话历史:
{context_str}
议价设置:
- 当前议价次数:{bargain_count}
- 最大议价轮数:{max_bargain_rounds}
- 最大优惠百分比:{max_discount_percent}%
- 最大优惠金额:{max_discount_amount}
用户消息:{message}
请根据以上信息生成回复:"""
# 10. 调用AI生成回复
response = client.chat.completions.create(
model=settings['model_name'],
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
max_tokens=100,
temperature=0.7
)
reply = response.choices[0].message.content.strip()
# 11. 保存对话记录
self.save_conversation(chat_id, cookie_id, user_id, item_id, "user", message, intent)
self.save_conversation(chat_id, cookie_id, user_id, item_id, "assistant", reply, intent)
# 12. 更新议价次数
if intent == "price":
self.increment_bargain_count(chat_id, cookie_id)
logger.info(f"AI回复生成成功 (账号: {cookie_id}): {reply}")
return reply
except Exception as e:
logger.error(f"AI回复生成失败 {cookie_id}: {e}")
return None
def get_conversation_context(self, chat_id: str, cookie_id: str, limit: int = 20) -> List[Dict]:
"""获取对话上下文"""
try:
with db_manager.lock:
cursor = db_manager.conn.cursor()
cursor.execute('''
SELECT role, content FROM ai_conversations
WHERE chat_id = ? AND cookie_id = ?
ORDER BY created_at DESC LIMIT ?
''', (chat_id, cookie_id, limit))
results = cursor.fetchall()
# 反转顺序,使其按时间正序
context = [{"role": row[0], "content": row[1]} for row in reversed(results)]
return context
except Exception as e:
logger.error(f"获取对话上下文失败: {e}")
return []
def save_conversation(self, chat_id: str, cookie_id: str, user_id: str,
item_id: str, role: str, content: str, intent: str = None):
"""保存对话记录"""
try:
with db_manager.lock:
cursor = db_manager.conn.cursor()
cursor.execute('''
INSERT INTO ai_conversations
(cookie_id, chat_id, user_id, item_id, role, content, intent)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (cookie_id, chat_id, user_id, item_id, role, content, intent))
db_manager.conn.commit()
except Exception as e:
logger.error(f"保存对话记录失败: {e}")
def get_bargain_count(self, chat_id: str, cookie_id: str) -> int:
"""获取议价次数"""
try:
with db_manager.lock:
cursor = db_manager.conn.cursor()
cursor.execute('''
SELECT COUNT(*) FROM ai_conversations
WHERE chat_id = ? AND cookie_id = ? AND intent = 'price' AND role = 'user'
''', (chat_id, cookie_id))
result = cursor.fetchone()
return result[0] if result else 0
except Exception as e:
logger.error(f"获取议价次数失败: {e}")
return 0
def increment_bargain_count(self, chat_id: str, cookie_id: str):
"""增加议价次数(通过保存记录自动增加)"""
# 议价次数通过查询price意图的用户消息数量来计算无需单独操作
pass
def clear_client_cache(self, cookie_id: str = None):
"""清理客户端缓存"""
if cookie_id:
self.clients.pop(cookie_id, None)
logger.info(f"清理账号 {cookie_id} 的客户端缓存")
else:
self.clients.clear()
logger.info("清理所有客户端缓存")
# 全局AI回复引擎实例
ai_reply_engine = AIReplyEngine()