feat: 修复 H5 底部导航覆盖 + 更新项目进度文档
## H5 底部导航修复 (Bug #10) - 精简 App.vue,移除重复 tabbar,仅保留全局样式 - uni-page 设置 height: calc(100% - 50px) + overflow-y: auto - 内容区域精确停在底部导航上方,独立滚动不再叠加 - 恢复 custom-tab-bar 组件 ## 项目进度文档 - PROGRESS.md 更新至 10 个 Bug 修复 - 新增 H5 底部导航修复记录 - 新增历史变更条目
This commit is contained in:
@@ -40,6 +40,10 @@ EXCHANGE_RATE_API_KEY=
|
||||
UPLOAD_DIR=./uploads
|
||||
MAX_UPLOAD_SIZE=10485760
|
||||
|
||||
# 错误监控 (Sentry)
|
||||
SENTRY_DSN=
|
||||
DEBUG=true
|
||||
|
||||
# URL
|
||||
FRONTEND_URL=http://localhost:3000
|
||||
BACKEND_URL=http://localhost:8000
|
||||
|
||||
@@ -6,6 +6,12 @@ RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
postgresql-client \
|
||||
libpq-dev \
|
||||
libcairo2 \
|
||||
libpango-1.0-0 \
|
||||
libpangocairo-1.0-0 \
|
||||
libgdk-pixbuf2.0-0 \
|
||||
libffi-dev \
|
||||
shared-mime-info \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
|
||||
@@ -13,7 +13,7 @@ if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
from app.database import Base
|
||||
from app.models import User, Product, Customer, Conversation, Message, Quotation, QuotationItem, CorpusEntry
|
||||
from app.models import User, Product, Customer, Conversation, Message, Quotation, QuotationItem, CorpusEntry, Team, TeamMember, UsageLog, Notification, Feedback, Subscription, PreferenceAnalysis, MarketingEffect, Device, FollowupStrategy, FollowupLog
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ def upgrade() -> None:
|
||||
sa.Column('username', sa.String(length=100), nullable=True),
|
||||
sa.Column('password_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('tier', sa.String(length=50), nullable=True),
|
||||
sa.Column('role', sa.String(length=20), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
@@ -171,7 +172,7 @@ def upgrade() -> None:
|
||||
sa.Column('user_edited', sa.Boolean(), nullable=True),
|
||||
sa.Column('user_rating', sa.Integer(), nullable=True),
|
||||
sa.Column('usage_count', sa.Integer(), nullable=True),
|
||||
sa.Column('embedding', postgresql.Vector(length=768), nullable=True),
|
||||
sa.Column('embedding', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""add teams and analytics tables
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2026-05-09
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = '002'
|
||||
down_revision: Union[str, None] = '001'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table('teams',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('owner_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('member_count', sa.Integer(), nullable=True),
|
||||
sa.Column('max_members', sa.Integer(), nullable=True),
|
||||
sa.Column('tier', sa.String(length=50), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_teams_owner_id'), 'teams', ['owner_id'], unique=False)
|
||||
|
||||
op.create_table('team_members',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('team_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role', sa.String(length=50), nullable=True),
|
||||
sa.Column('invited_by', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('status', sa.String(length=50), nullable=True),
|
||||
sa.Column('joined_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['team_id'], ['teams.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_team_members_team_id'), 'team_members', ['team_id'], unique=False)
|
||||
op.create_index(op.f('ix_team_members_user_id'), 'team_members', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('usage_logs',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('team_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('action', sa.String(length=100), nullable=False),
|
||||
sa.Column('detail', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=50), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=255), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_usage_logs_user_id'), 'usage_logs', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_usage_logs_team_id'), 'usage_logs', ['team_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('usage_logs')
|
||||
op.drop_table('team_members')
|
||||
op.drop_table('teams')
|
||||
@@ -0,0 +1,107 @@
|
||||
"""add notifications, feedback, subscription, and p3 tables
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2026-05-09
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = '003'
|
||||
down_revision: Union[str, None] = '002'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table('notifications',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('title', sa.String(length=255), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('notification_type', sa.String(length=50), nullable=True),
|
||||
sa.Column('reference_type', sa.String(length=50), nullable=True),
|
||||
sa.Column('reference_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('is_read', sa.Boolean(), nullable=True),
|
||||
sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_notifications_user_id'), 'notifications', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('feedbacks',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('category', sa.String(length=50), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('contact', sa.String(length=100), nullable=True),
|
||||
sa.Column('status', sa.String(length=20), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_feedbacks_user_id'), 'feedbacks', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('subscriptions',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('plan', sa.String(length=50), nullable=False),
|
||||
sa.Column('status', sa.String(length=20), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('auto_renew', sa.Boolean(), nullable=True),
|
||||
sa.Column('payment_provider', sa.String(length=50), nullable=True),
|
||||
sa.Column('payment_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('amount', sa.Float(), nullable=True),
|
||||
sa.Column('currency', sa.String(length=10), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_subscriptions_user_id'), 'subscriptions', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('preference_analyses',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('task_type', sa.String(length=50), nullable=False),
|
||||
sa.Column('preferred_tone', sa.String(length=50), nullable=True),
|
||||
sa.Column('preferred_style', sa.String(length=50), nullable=True),
|
||||
sa.Column('common_replacements', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('avg_formality_score', sa.Float(), nullable=True),
|
||||
sa.Column('greeting_style', sa.String(length=100), nullable=True),
|
||||
sa.Column('sign_off_style', sa.String(length=100), nullable=True),
|
||||
sa.Column('analysis_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('confidence', sa.Float(), nullable=True),
|
||||
sa.Column('interaction_count', sa.Integer(), nullable=True),
|
||||
sa.Column('last_analysis_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_preference_analyses_user_id'), 'preference_analyses', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('marketing_effects',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('content_hash', sa.String(length=64), nullable=False),
|
||||
sa.Column('product_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('product_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('channel', sa.String(length=50), nullable=True),
|
||||
sa.Column('event_type', sa.String(length=50), nullable=False),
|
||||
sa.Column('target_audience', sa.String(length=255), nullable=True),
|
||||
sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_marketing_effects_user_id'), 'marketing_effects', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_marketing_effects_content_hash'), 'marketing_effects', ['content_hash'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('marketing_effects')
|
||||
op.drop_table('preference_analyses')
|
||||
op.drop_table('subscriptions')
|
||||
op.drop_table('feedbacks')
|
||||
op.drop_table('notifications')
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add devices table for push notification registration
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2026-05-10
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = '004'
|
||||
down_revision: Union[str, None] = '003'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table('devices',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('platform', sa.String(length=50), nullable=True),
|
||||
sa.Column('push_token', sa.String(length=500), nullable=True),
|
||||
sa.Column('client_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_devices_user_id'), 'devices', ['user_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('devices')
|
||||
@@ -0,0 +1,66 @@
|
||||
"""add followup_strategies and followup_logs tables
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 004
|
||||
Create Date: 2026-05-10
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = '005'
|
||||
down_revision: Union[str, None] = '004'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table('followup_strategies',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('trigger_condition', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('channel', sa.String(length=50), nullable=True),
|
||||
sa.Column('ai_prompt_template', sa.Text(), nullable=True),
|
||||
sa.Column('priority', sa.Integer(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
op.create_table('followup_logs',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('customer_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('strategy_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('status', sa.String(length=50), nullable=True),
|
||||
sa.Column('channel', sa.String(length=50), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('ai_generated_content', sa.Text(), nullable=True),
|
||||
sa.Column('user_edited_content', sa.Text(), nullable=True),
|
||||
sa.Column('health_score_at_time', sa.Integer(), nullable=True),
|
||||
sa.Column('silence_days_at_time', sa.Integer(), nullable=True),
|
||||
sa.Column('sent_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('replied_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('response_status', sa.String(length=50), nullable=True),
|
||||
sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_followup_logs_user_id'), 'followup_logs', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_followup_logs_customer_id'), 'followup_logs', ['customer_id'], unique=False)
|
||||
op.create_foreign_key('fk_followup_logs_customer', 'followup_logs', 'customers', ['customer_id'], ['id'])
|
||||
op.create_foreign_key('fk_followup_logs_strategy', 'followup_logs', 'followup_strategies', ['strategy_id'], ['id'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint('fk_followup_logs_strategy', 'followup_logs', type_='foreignkey')
|
||||
op.drop_constraint('fk_followup_logs_customer', 'followup_logs', type_='foreignkey')
|
||||
op.drop_index(op.f('ix_followup_logs_customer_id'), table_name='followup_logs')
|
||||
op.drop_index(op.f('ix_followup_logs_user_id'), table_name='followup_logs')
|
||||
op.drop_table('followup_logs')
|
||||
op.drop_table('followup_strategies')
|
||||
@@ -13,7 +13,7 @@ class AIProvider(ABC):
|
||||
@abstractmethod
|
||||
async def reply(
|
||||
self, inquiry: str, context: Optional[Dict[str, Any]] = None,
|
||||
tone: str = "professional",
|
||||
tone: str = "professional", preference_context: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@@ -21,6 +21,7 @@ class AIProvider(ABC):
|
||||
async def generate_marketing(
|
||||
self, product_info: Dict[str, Any], target: str,
|
||||
style: str = "professional", language: str = "en",
|
||||
preference_context: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -2,5 +2,7 @@ from .openai import OpenAIProvider
|
||||
from .claude import ClaudeProvider
|
||||
from .deepl import DeepLProvider
|
||||
from .local import LocalProvider
|
||||
from .spark import SparkProvider
|
||||
from .sensenova import SensenovaProvider
|
||||
|
||||
__all__ = ["OpenAIProvider", "ClaudeProvider", "DeepLProvider", "LocalProvider"]
|
||||
__all__ = ["OpenAIProvider", "ClaudeProvider", "DeepLProvider", "LocalProvider", "SparkProvider", "SensenovaProvider"]
|
||||
|
||||
@@ -32,8 +32,10 @@ class ClaudeProvider(AIProvider):
|
||||
content = await self._call(system, prompt)
|
||||
return {"translated_text": content, "provider": self.name}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["reply"]
|
||||
if preference_context:
|
||||
system += f"\nUser writing preference: {preference_context}"
|
||||
context_str = ""
|
||||
if context:
|
||||
for k, v in context.items():
|
||||
@@ -43,8 +45,10 @@ class ClaudeProvider(AIProvider):
|
||||
content = await self._call(system, prompt)
|
||||
return {"reply": content, "provider": self.name}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["marketing"]
|
||||
if preference_context:
|
||||
system += f"\nUser preference: {preference_context}"
|
||||
info = json.dumps(product_info, ensure_ascii=False, indent=2)
|
||||
prompt = f"Product:\n{info}\n\nTarget: {target}\nStyle: {style}\nLanguage: {language}\n\nWrite marketing copy:"
|
||||
content = await self._call(system, prompt, max_tokens=1500)
|
||||
|
||||
@@ -14,17 +14,22 @@ class LocalProvider(AIProvider):
|
||||
result = await self._generate(prompt)
|
||||
return {"translated_text": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
ctx = ""
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
prompt = ""
|
||||
if preference_context:
|
||||
prompt += f"[User prefers: {preference_context}]\n"
|
||||
if context:
|
||||
ctx = "\n".join(f"{k}: {v}" for k, v in context.items() if v)
|
||||
prompt = f"{ctx}\nCustomer: {inquiry}\n\nWrite a {tone} reply:"
|
||||
prompt += "\n".join(f"{k}: {v}" for k, v in context.items() if v) + "\n"
|
||||
prompt += f"Customer: {inquiry}\n\nWrite a {tone} reply:"
|
||||
result = await self._generate(prompt)
|
||||
return {"reply": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
info = json.dumps(product_info, ensure_ascii=False)
|
||||
prompt = f"Product: {info}\nTarget: {target}\nStyle: {style}\nLanguage: {language}\n\nMarketing copy:"
|
||||
prompt = ""
|
||||
if preference_context:
|
||||
prompt += f"[User prefers: {preference_context}]\n"
|
||||
prompt += f"Product: {info}\nTarget: {target}\nStyle: {style}\nLanguage: {language}\n\nMarketing copy:"
|
||||
result = await self._generate(prompt, max_tokens=800)
|
||||
return {"content": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
|
||||
@@ -19,8 +19,11 @@ SYSTEM_PROMPTS = {
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o"):
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o", base_url: Optional[str] = None):
|
||||
kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
self.client = AsyncOpenAI(**kwargs)
|
||||
self.model = model
|
||||
self._name = f"openai-{model}"
|
||||
self._pricing = {
|
||||
@@ -39,8 +42,10 @@ class OpenAIProvider(AIProvider):
|
||||
content = await self._call(system, f"Translate to {target_lang}:\n\n{text}", model=self._cheap_model)
|
||||
return {"translated_text": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["reply"] + f"\nTone: {tone}"
|
||||
if preference_context:
|
||||
system += f"\nUser preference: {preference_context}"
|
||||
|
||||
context_str = ""
|
||||
if context:
|
||||
@@ -57,8 +62,10 @@ class OpenAIProvider(AIProvider):
|
||||
content = await self._call(system, prompt)
|
||||
return {"reply": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["marketing"] + f"\nStyle: {style}\nTarget audience: {target}\nLanguage: {language}"
|
||||
if preference_context:
|
||||
system += f"\nUser preference: {preference_context}"
|
||||
|
||||
product_str = json.dumps(product_info, ensure_ascii=False, indent=2)
|
||||
prompt = f"Product information:\n{product_str}\n\nGenerate marketing copy:"
|
||||
@@ -76,7 +83,7 @@ class OpenAIProvider(AIProvider):
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name, "error": "parse_failed"}
|
||||
|
||||
async def _call(self, system: str, prompt: str, max_tokens: int = 1000, response_format: Optional[Dict] = None, model: Optional[str] = None) -> str:
|
||||
async def _call(self, system: str, prompt: str, max_tokens: int = 3000, response_format: Optional[Dict] = None, model: Optional[str] = None) -> str:
|
||||
kwargs = {
|
||||
"model": model or self.model,
|
||||
"messages": [
|
||||
@@ -90,7 +97,46 @@ class OpenAIProvider(AIProvider):
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
resp = await self.client.chat.completions.create(**kwargs)
|
||||
return resp.choices[0].message.content
|
||||
content = resp.choices[0].message.content
|
||||
|
||||
if content is None and hasattr(resp.choices[0].message, 'reasoning'):
|
||||
reasoning = resp.choices[0].message.reasoning
|
||||
if reasoning:
|
||||
import re
|
||||
final_output_patterns = [
|
||||
r'Final Output Generation[::]\s*(.+?)(?:\n\n|$)',
|
||||
r'Final Output[::]\s*(.+?)(?:\n\n|$)',
|
||||
r'7\.\s*Final Output Generation[::]\s*(.+?)(?:\n\n|$)',
|
||||
r'翻译结果[::]\s*(.+?)(?:\n\n|$)',
|
||||
r'最终输出[::]\s*(.+?)(?:\n\n|$)',
|
||||
]
|
||||
for pattern in final_output_patterns:
|
||||
match = re.search(pattern, reasoning, re.DOTALL)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
break
|
||||
|
||||
if content is None:
|
||||
paragraphs = re.split(r'\n\n+', reasoning.strip())
|
||||
if paragraphs:
|
||||
for p in reversed(paragraphs):
|
||||
p = p.strip()
|
||||
if p and len(p) > 10:
|
||||
if not p.startswith('步骤') and not p.startswith('Step'):
|
||||
content = p
|
||||
break
|
||||
|
||||
if content is None and hasattr(resp.choices[0].message, 'reasoning'):
|
||||
reasoning = resp.choices[0].message.reasoning
|
||||
if reasoning:
|
||||
import re
|
||||
cleaned = re.sub(r'^步骤\d+[::].*$', '', reasoning, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r'^Step \d+[::].*$', '', cleaned, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r'\n+', '\n', cleaned).strip()
|
||||
if cleaned:
|
||||
content = cleaned
|
||||
|
||||
return content
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from app.ai.providers.openai import OpenAIProvider
|
||||
|
||||
|
||||
class SensenovaProvider(OpenAIProvider):
|
||||
def __init__(self, api_key: str, model: str = "sensenova-6.7-flash-lite", base_url: str = "https://token.sensenova.cn/v1"):
|
||||
super().__init__(api_key=api_key, model=model, base_url=base_url)
|
||||
self._name = f"sensenova-{model}"
|
||||
@@ -0,0 +1,87 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from openai import AsyncOpenAI
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
SYSTEM_PROMPTS = {
|
||||
"translate": "You are a professional translator specialized in foreign trade. "
|
||||
"Translate business terms accurately. Return ONLY the translated text.",
|
||||
"reply": "You are an experienced foreign trade sales expert. Write professional, "
|
||||
"clear business replies. Return ONLY the reply text.",
|
||||
"marketing": "You are a creative copywriter for international trade. "
|
||||
"Return ONLY the marketing copy, no explanations.",
|
||||
"extract": "Extract structured data from text. Return ONLY valid JSON.",
|
||||
}
|
||||
|
||||
|
||||
class SparkProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "astron-code-latest", base_url: str = None):
|
||||
from app.config import settings
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or settings.IFLYTEK_API_BASE,
|
||||
)
|
||||
self.model = model
|
||||
self._name = f"spark-{model}"
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["translate"]
|
||||
if context:
|
||||
system += f"\nContext: {context}"
|
||||
prompt = f"Translate {f'from {source_lang} ' if source_lang and source_lang != 'auto' else ''}to {target_lang}:\n\n{text}"
|
||||
content = await self._call(system, prompt)
|
||||
return {"translated_text": content, "provider": self.name}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["reply"] + f"\nTone: {tone}"
|
||||
if preference_context:
|
||||
system += f"\nUser preference: {preference_context}"
|
||||
ctx = ""
|
||||
if context:
|
||||
ctx = "\n".join(f"{k}: {v}" for k, v in context.items() if v)
|
||||
prompt = f"{ctx}\nCustomer inquiry:\n{inquiry}\n\nWrite a reply:"
|
||||
content = await self._call(system, prompt)
|
||||
return {"reply": content, "provider": self.name}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["marketing"] + f"\nStyle: {style}\nAudience: {target}\nLanguage: {language}"
|
||||
if preference_context:
|
||||
system += f"\nUser preference: {preference_context}"
|
||||
info = json.dumps(product_info, ensure_ascii=False)
|
||||
prompt = f"Product:\n{info}\n\nGenerate marketing copy:"
|
||||
content = await self._call(system, prompt, max_tokens=1500)
|
||||
return {"content": content, "provider": self.name}
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["extract"]
|
||||
prompt = f"Schema:\n{json.dumps(schema, indent=2)}\n\nText:\n{text}\n\nJSON:"
|
||||
content = await self._call(system, prompt, response_format={"type": "json_object"})
|
||||
try:
|
||||
data = json.loads(content)
|
||||
return {"data": data, "confidence": 0.9, "provider": self.name}
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name}
|
||||
|
||||
async def _call(self, system: str, prompt: str, max_tokens: int = 1000, response_format: Optional[Dict] = None) -> str:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
resp = await self.client.chat.completions.create(**kwargs)
|
||||
return resp.choices[0].message.content
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
return 0.0
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.ai.base import AIProvider
|
||||
from app.ai.providers import OpenAIProvider, ClaudeProvider, DeepLProvider, LocalProvider
|
||||
from app.ai.providers import OpenAIProvider, ClaudeProvider, DeepLProvider, LocalProvider, SparkProvider, SensenovaProvider
|
||||
from app.config import settings
|
||||
from app.ai.trade_corpus import TradeCorpus
|
||||
import logging
|
||||
@@ -23,6 +23,17 @@ class AIRouter:
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenAI init failed: {e}")
|
||||
|
||||
if settings.SENSENOVA_API_KEY:
|
||||
try:
|
||||
self.providers["sensenova"] = SensenovaProvider(
|
||||
api_key=settings.SENSENOVA_API_KEY,
|
||||
model=settings.SENSENOVA_MODEL,
|
||||
base_url=settings.SENSENOVA_BASE_URL,
|
||||
)
|
||||
logger.info("Sensenova provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"Sensenova init failed: {e}")
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
try:
|
||||
self.providers["anthropic"] = ClaudeProvider(api_key=settings.ANTHROPIC_API_KEY)
|
||||
@@ -37,6 +48,17 @@ class AIRouter:
|
||||
except Exception as e:
|
||||
logger.warning(f"DeepL init failed: {e}")
|
||||
|
||||
if settings.IFLYTEK_API_KEY:
|
||||
try:
|
||||
self.providers["spark"] = SparkProvider(
|
||||
api_key=settings.IFLYTEK_API_KEY,
|
||||
model=settings.IFLYTEK_MODEL,
|
||||
base_url=settings.IFLYTEK_API_BASE,
|
||||
)
|
||||
logger.info("Spark provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"Spark init failed: {e}")
|
||||
|
||||
if settings.LOCAL_MODEL_ENABLED:
|
||||
try:
|
||||
self.providers["local"] = LocalProvider(model_url=settings.LOCAL_MODEL_URL)
|
||||
@@ -90,11 +112,11 @@ class AIRouter:
|
||||
async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
return await self.execute("translate", "translate", text, source_lang, target_lang, context)
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
return await self.execute("reply", "reply", inquiry, context, tone)
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
return await self.execute("reply", "reply", inquiry, context, tone, preference_context)
|
||||
|
||||
async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
return await self.execute("marketing", "generate_marketing", product_info, target, style, language)
|
||||
async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en", preference_context: Optional[str] = None) -> Dict[str, Any]:
|
||||
return await self.execute("marketing", "generate_marketing", product_info, target, style, language, preference_context)
|
||||
|
||||
async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return await self.execute("extract", "extract_info", text, schema)
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from app.database import get_db
|
||||
from app.services.admin import AdminService
|
||||
from app.api.v1.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def require_admin(current_user: dict = Depends(get_current_user)) -> dict:
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/dashboard")
|
||||
async def get_dashboard(
|
||||
_: dict = Depends(require_admin),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AdminService(db)
|
||||
return await service.get_dashboard()
|
||||
|
||||
|
||||
@router.get("/users")
|
||||
async def list_users(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
_: dict = Depends(require_admin),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AdminService(db)
|
||||
return await service.list_users(page, size)
|
||||
|
||||
|
||||
@router.patch("/users/{target_user_id}/tier")
|
||||
async def update_user_tier(
|
||||
target_user_id: str,
|
||||
data: dict,
|
||||
_: dict = Depends(require_admin),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AdminService(db)
|
||||
tier = data.get("tier")
|
||||
if tier not in ("free", "pro", "enterprise"):
|
||||
raise HTTPException(status_code=400, detail="Invalid tier")
|
||||
success = await service.update_user_tier(target_user_id, tier)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return {"message": f"User tier updated to {tier}"}
|
||||
|
||||
|
||||
@router.post("/users/{target_user_id}/toggle-active")
|
||||
async def toggle_user_active(
|
||||
target_user_id: str,
|
||||
_: dict = Depends(require_admin),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AdminService(db)
|
||||
success = await service.toggle_user_active(target_user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return {"message": "User active status toggled"}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def system_health(
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AdminService(db)
|
||||
return await service.get_system_health()
|
||||
@@ -0,0 +1,73 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from app.database import get_db
|
||||
from app.services.analytics import AnalyticsService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/customers")
|
||||
async def customer_analytics(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AnalyticsService(db)
|
||||
return await service.get_customer_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/translations")
|
||||
async def translation_analytics(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AnalyticsService(db)
|
||||
return await service.get_translation_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/quotations")
|
||||
async def quotation_analytics(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AnalyticsService(db)
|
||||
return await service.get_quotation_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/messages")
|
||||
async def message_analytics(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AnalyticsService(db)
|
||||
return await service.get_message_stats(user_id)
|
||||
|
||||
|
||||
@router.get("/overview")
|
||||
async def overview(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AnalyticsService(db)
|
||||
customers = await service.get_customer_stats(user_id)
|
||||
translations = await service.get_translation_stats(user_id)
|
||||
quotations = await service.get_quotation_stats(user_id)
|
||||
messages = await service.get_message_stats(user_id)
|
||||
marketing = await service.get_marketing_stats(user_id)
|
||||
return {
|
||||
"customers": customers,
|
||||
"translations": translations,
|
||||
"quotations": quotations,
|
||||
"messages": messages,
|
||||
"marketing": marketing,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/marketing")
|
||||
async def marketing_analytics(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = AnalyticsService(db)
|
||||
return await service.get_marketing_stats(user_id)
|
||||
@@ -1,13 +1,14 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Header
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Optional
|
||||
import uuid
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -30,7 +31,7 @@ class RefreshRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register(data: RegisterRequest, db: Annotated[AsyncSession, Depends(get_db)]):
|
||||
async def register(data: RegisterRequest, db: Annotated[AsyncSession, Depends(get_db)] = None):
|
||||
existing = await db.execute(select(User).where(User.phone == data.phone))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Phone already registered")
|
||||
@@ -49,13 +50,14 @@ async def register(data: RegisterRequest, db: Annotated[AsyncSession, Depends(ge
|
||||
"phone": user.phone,
|
||||
"username": user.username,
|
||||
"tier": user.tier,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(
|
||||
form: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
result = await db.execute(select(User).where(User.phone == form.username))
|
||||
user = result.scalar_one_or_none()
|
||||
@@ -67,7 +69,7 @@ async def login(
|
||||
)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=create_access_token({"sub": str(user.id), "tier": user.tier}),
|
||||
access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
user={
|
||||
"id": str(user.id),
|
||||
@@ -78,6 +80,29 @@ async def login(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/guest")
|
||||
async def guest_login():
|
||||
guest_id = f"guest_{uuid.uuid4().hex[:12]}"
|
||||
access_token = create_access_token(
|
||||
{"sub": guest_id, "tier": "guest", "role": "guest", "is_guest": True},
|
||||
expires_delta=timedelta(hours=24)
|
||||
)
|
||||
refresh_token = create_refresh_token({"sub": guest_id, "is_guest": True})
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
user={
|
||||
"id": guest_id,
|
||||
"phone": None,
|
||||
"username": "游客用户",
|
||||
"tier": "guest",
|
||||
"is_guest": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh(data: RefreshRequest):
|
||||
payload = decode_token(data.refresh_token)
|
||||
@@ -92,7 +117,7 @@ async def refresh(data: RefreshRequest):
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(
|
||||
authorization: str = None,
|
||||
authorization: Optional[str] = Header(None, alias="Authorization"),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
@@ -112,6 +137,7 @@ async def get_me(
|
||||
"phone": user.phone,
|
||||
"username": user.username,
|
||||
"tier": user.tier,
|
||||
"role": user.role,
|
||||
"settings": user.settings,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
}
|
||||
@@ -124,10 +150,50 @@ class SettingsUpdate(BaseModel):
|
||||
languages: list = None
|
||||
|
||||
|
||||
class WeChatLoginRequest(BaseModel):
|
||||
code: str
|
||||
encrypted_data: str = ""
|
||||
iv: str = ""
|
||||
|
||||
|
||||
@router.post("/wechat-login")
|
||||
async def wechat_login(data: WeChatLoginRequest, db: Annotated[AsyncSession, Depends(get_db)] = None):
|
||||
from app.services.wechat import wechat_service
|
||||
|
||||
session = await wechat_service.code2session(data.code)
|
||||
if not session:
|
||||
raise HTTPException(status_code=400, detail="WeChat login failed")
|
||||
|
||||
openid = session.get("openid")
|
||||
result = await db.execute(select(User).where(User.wechat_openid == openid))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
user = User(
|
||||
wechat_openid=openid,
|
||||
username=f"wx_{openid[-8:]}",
|
||||
tier="free",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
return LoginResponse(
|
||||
access_token=create_access_token({"sub": str(user.id), "tier": user.tier, "role": user.role}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
user={
|
||||
"id": str(user.id),
|
||||
"phone": user.phone,
|
||||
"username": user.username,
|
||||
"tier": user.tier,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/settings")
|
||||
async def update_settings(
|
||||
data: SettingsUpdate,
|
||||
authorization: str = None,
|
||||
authorization: Optional[str] = Header(None, alias="Authorization"),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, Optional, List
|
||||
from app.database import get_db
|
||||
from app.services.customer import CustomerService
|
||||
from app.services.customer_health import CustomerHealthService
|
||||
from app.services.import_service import import_service
|
||||
from app.services import export
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
@@ -87,6 +90,95 @@ async def delete_customer(
|
||||
return {"message": "Customer deleted"}
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
async def import_customers(
|
||||
file: UploadFile = File(...),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
from app.workers.tasks import process_customer_import
|
||||
|
||||
content = await file.read()
|
||||
filename = file.filename or ""
|
||||
|
||||
if filename.endswith(".xlsx"):
|
||||
records, parse_errors = import_service.parse_xlsx(content)
|
||||
elif filename.endswith(".csv"):
|
||||
records, parse_errors = import_service.parse_csv(content)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Unsupported file format. Use .xlsx or .csv")
|
||||
|
||||
if parse_errors and not records:
|
||||
raise HTTPException(status_code=400, detail=f"Parse failed: {'; '.join(parse_errors)}")
|
||||
|
||||
valid, validation_errors = import_service.validate_records(records)
|
||||
all_errors = parse_errors + validation_errors
|
||||
imported_count = 0
|
||||
|
||||
for record in valid:
|
||||
try:
|
||||
svc = CustomerService(db)
|
||||
await svc.create_customer(user_id, record)
|
||||
imported_count += 1
|
||||
except Exception as e:
|
||||
all_errors.append(f"Import failed for {record.get('name', 'unknown')}: {str(e)}")
|
||||
|
||||
return {
|
||||
"imported": imported_count,
|
||||
"total": len(records),
|
||||
"errors": all_errors,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/export/csv")
|
||||
async def export_customers(
|
||||
status: Optional[str] = None,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
result = await service.list_customers(user_id, status, 1, 9999)
|
||||
items = result.get("items", [])
|
||||
csv_bytes = export.export_customers_csv(items)
|
||||
return Response(
|
||||
content=csv_bytes,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=customers.csv"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health-overview")
|
||||
async def get_health_overview(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerHealthService(db)
|
||||
return await service.get_health_overview(user_id)
|
||||
|
||||
|
||||
@router.get("/health-scores")
|
||||
async def get_all_health_scores(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerHealthService(db)
|
||||
return {"items": await service.get_all_health_scores(user_id)}
|
||||
|
||||
|
||||
@router.get("/{customer_id}/health")
|
||||
async def get_customer_health(
|
||||
customer_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerHealthService(db)
|
||||
health = await service.get_customer_health(user_id, customer_id)
|
||||
if not health:
|
||||
raise HTTPException(status_code=404, detail="Customer not found")
|
||||
return health
|
||||
|
||||
|
||||
@router.get("/{customer_id}/conversation")
|
||||
async def get_conversation(
|
||||
customer_id: str,
|
||||
|
||||
@@ -1,13 +1,43 @@
|
||||
from fastapi import HTTPException, Depends
|
||||
from fastapi import HTTPException, Depends, Header
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from app.core.security import decode_token
|
||||
from typing import Optional
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user_id(authorization: str = None) -> str:
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
async def get_current_user_id(
|
||||
authorization: Optional[str] = Header(None, alias="Authorization"),
|
||||
cred: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
||||
) -> str:
|
||||
token = None
|
||||
if cred:
|
||||
token = cred.credentials
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization[7:]
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Missing or invalid token")
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
payload = decode_token(token)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
return payload.get("sub")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
cred: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> dict:
|
||||
if not cred:
|
||||
raise HTTPException(status_code=401, detail="Missing or invalid token")
|
||||
|
||||
payload = decode_token(cred.credentials)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
return {
|
||||
"id": payload.get("sub"),
|
||||
"tier": payload.get("tier", "free"),
|
||||
"role": payload.get("role", "user"),
|
||||
}
|
||||
|
||||
@@ -1,26 +1,9 @@
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from app.services.exchange import ExchangeRateService
|
||||
from datetime import datetime
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ExchangeRateResponse(BaseModel):
|
||||
from_currency: str
|
||||
to_currency: str
|
||||
rate: float
|
||||
updated_at: str
|
||||
|
||||
|
||||
EXCHANGE_RATES = {
|
||||
("USD", "CNY"): 7.24,
|
||||
("EUR", "CNY"): 7.85,
|
||||
("GBP", "CNY"): 9.15,
|
||||
("CNY", "USD"): 0.138,
|
||||
("USD", "EUR"): 0.92,
|
||||
("EUR", "USD"): 1.09,
|
||||
("GBP", "USD"): 1.27,
|
||||
("USD", "GBP"): 0.79,
|
||||
}
|
||||
service = ExchangeRateService()
|
||||
|
||||
|
||||
@router.get("/convert")
|
||||
@@ -29,26 +12,25 @@ async def convert_currency(
|
||||
to_currency: str = "CNY",
|
||||
amount: float = 1.0,
|
||||
):
|
||||
rate = EXCHANGE_RATES.get((from_currency, to_currency), 1.0)
|
||||
rate = await service.get_rate(from_currency, to_currency)
|
||||
if rate is None:
|
||||
return {"error": f"No rate available for {from_currency} -> {to_currency}"}
|
||||
|
||||
return {
|
||||
"from_currency": from_currency,
|
||||
"to_currency": to_currency,
|
||||
"from_currency": from_currency.upper(),
|
||||
"to_currency": to_currency.upper(),
|
||||
"amount": amount,
|
||||
"converted": round(amount * rate, 2),
|
||||
"rate": rate,
|
||||
"updated_at": "2026-05-08T00:00:00Z",
|
||||
"updated_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/rates")
|
||||
async def get_rates(base: str = "USD"):
|
||||
rates = {}
|
||||
for (from_curr, to_curr), rate in EXCHANGE_RATES.items():
|
||||
if from_curr == base:
|
||||
rates[to_curr] = rate
|
||||
|
||||
rates = await service.get_all_rates(base)
|
||||
return {
|
||||
"base": base,
|
||||
"base": base.upper(),
|
||||
"rates": rates,
|
||||
"updated_at": "2026-05-08T00:00:00Z",
|
||||
}
|
||||
"updated_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.models.feedback import Feedback
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class FeedbackRequest(BaseModel):
|
||||
category: str = "general"
|
||||
content: str
|
||||
contact: str = ""
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def submit_feedback(
|
||||
data: FeedbackRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
if not data.content.strip():
|
||||
raise HTTPException(status_code=400, detail="Content is required")
|
||||
|
||||
fb = Feedback(
|
||||
user_id=user_id,
|
||||
category=data.category,
|
||||
content=data.content.strip(),
|
||||
contact=data.contact.strip(),
|
||||
)
|
||||
db.add(fb)
|
||||
await db.flush()
|
||||
return {"status": "ok", "id": str(fb.id)}
|
||||
@@ -0,0 +1,89 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from app.database import get_db
|
||||
from app.services.followup_engine import FollowupEngine
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/strategies")
|
||||
async def list_strategies(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
engine = FollowupEngine(db)
|
||||
await engine.ensure_default_strategies()
|
||||
return {"strategies": await engine.get_strategies()}
|
||||
|
||||
|
||||
@router.get("/pending")
|
||||
async def get_pending_followups(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
engine = FollowupEngine(db)
|
||||
return await engine.get_pending_followups(user_id, page, size)
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
async def get_followup_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
engine = FollowupEngine(db)
|
||||
return await engine.get_followup_logs(user_id, page, size)
|
||||
|
||||
|
||||
@router.post("/{log_id}/send")
|
||||
async def mark_followup_sent(
|
||||
log_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
engine = FollowupEngine(db)
|
||||
success = await engine.mark_sent(user_id, log_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Followup log not found")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/{log_id}/edit")
|
||||
async def edit_and_send_followup(
|
||||
log_id: str,
|
||||
body: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
edited_text = body.get("edited_text", "")
|
||||
if not edited_text:
|
||||
raise HTTPException(status_code=400, detail="edited_text is required")
|
||||
engine = FollowupEngine(db)
|
||||
success = await engine.mark_edited(user_id, log_id, edited_text)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Followup log not found")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_followup_stats(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
engine = FollowupEngine(db)
|
||||
return await engine.get_stats(user_id)
|
||||
|
||||
|
||||
@router.post("/scan")
|
||||
async def trigger_followup_scan(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
engine = FollowupEngine(db)
|
||||
result = await engine.scan_and_followup()
|
||||
return result
|
||||
@@ -0,0 +1,102 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from app.database import get_db
|
||||
from app.services.preference import UserPreferenceService
|
||||
from app.services.marketing_effect import MarketingEffectService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/select")
|
||||
async def record_selection(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
message_id = data.get("message_id")
|
||||
selected_index = data.get("selected_index")
|
||||
if message_id is None or selected_index is None:
|
||||
raise HTTPException(status_code=400, detail="message_id and selected_index required")
|
||||
service = UserPreferenceService(db)
|
||||
success = await service.record_selection(user_id, message_id, selected_index)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/edit")
|
||||
async def record_edit(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
message_id = data.get("message_id")
|
||||
edited_text = data.get("edited_text")
|
||||
if not message_id or edited_text is None:
|
||||
raise HTTPException(status_code=400, detail="message_id and edited_text required")
|
||||
service = UserPreferenceService(db)
|
||||
success = await service.record_edit(user_id, message_id, edited_text)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/analyze")
|
||||
async def analyze_preferences(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = UserPreferenceService(db)
|
||||
preferences = await service.analyze_preferences(user_id)
|
||||
return preferences
|
||||
|
||||
|
||||
@router.get("/preferences")
|
||||
async def get_preferences(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = UserPreferenceService(db)
|
||||
return await service.get_analysis(user_id)
|
||||
|
||||
|
||||
@router.post("/marketing-effect")
|
||||
async def track_marketing_effect(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = MarketingEffectService(db)
|
||||
result = await service.track_event(
|
||||
user_id=user_id,
|
||||
content=data.get("content", ""),
|
||||
product_id=data.get("product_id"),
|
||||
product_name=data.get("product_name"),
|
||||
channel=data.get("channel", "copy"),
|
||||
event_type=data.get("event_type", "copy"),
|
||||
target_audience=data.get("target_audience", ""),
|
||||
metadata=data.get("metadata"),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/marketing-effects")
|
||||
async def get_marketing_effects(
|
||||
page: int = 1,
|
||||
size: int = 20,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = MarketingEffectService(db)
|
||||
return await service.get_effects(user_id, page, size)
|
||||
|
||||
|
||||
@router.get("/marketing-effects/stats")
|
||||
async def get_marketing_effect_stats(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = MarketingEffectService(db)
|
||||
return await service.get_stats(user_id)
|
||||
@@ -1,8 +1,12 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Optional, Annotated
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.services.marketing import MarketingService
|
||||
from app.services.preference import UserPreferenceService
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
@@ -36,11 +40,15 @@ class CompetitorRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate_marketing(data: MarketingRequest, authorization: str = None):
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
async def generate_marketing(
|
||||
data: MarketingRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = MarketingService()
|
||||
pref_service = UserPreferenceService(db)
|
||||
pref_context = await pref_service.get_preference_context(user_id, "marketing")
|
||||
|
||||
product_info = {
|
||||
"name": data.product_name,
|
||||
"description": data.description,
|
||||
@@ -48,7 +56,7 @@ async def generate_marketing(data: MarketingRequest, authorization: str = None):
|
||||
"price": data.price,
|
||||
"keywords": data.keywords,
|
||||
}
|
||||
results = await service.generate(product_info, data.target, data.style, data.language, data.count)
|
||||
results = await service.generate(product_info, data.target, data.style, data.language, data.count, pref_context)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from app.database import get_db
|
||||
from app.services.notification import NotificationService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_notifications(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
unread_only: bool = Query(False),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = NotificationService(db)
|
||||
return await service.list_notifications(user_id, page, size, unread_only)
|
||||
|
||||
|
||||
@router.get("/unread-count")
|
||||
async def unread_count(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = NotificationService(db)
|
||||
count = await service.get_unread_count(user_id)
|
||||
return {"count": count}
|
||||
|
||||
|
||||
@router.patch("/{notification_id}/read")
|
||||
async def mark_read(
|
||||
notification_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = NotificationService(db)
|
||||
success = await service.mark_read(user_id, notification_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Notification not found")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/read-all")
|
||||
async def mark_all_read(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = NotificationService(db)
|
||||
count = await service.mark_all_read(user_id)
|
||||
return {"status": "ok", "count": count}
|
||||
|
||||
|
||||
@router.delete("/{notification_id}")
|
||||
async def delete_notification(
|
||||
notification_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = NotificationService(db)
|
||||
success = await service.delete_notification(user_id, notification_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Notification not found")
|
||||
return {"status": "ok"}
|
||||
@@ -0,0 +1,43 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.services.onboarding import OnboardingService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class OnboardingRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
category: str = ""
|
||||
target: str = "US importers"
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_status(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = OnboardingService(db)
|
||||
return await service.check_status(user_id)
|
||||
|
||||
|
||||
@router.post("/product")
|
||||
async def create_first_product(
|
||||
data: OnboardingRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
if not data.name.strip():
|
||||
raise HTTPException(status_code=400, detail="Product name is required")
|
||||
service = OnboardingService(db)
|
||||
return await service.generate_first_product(
|
||||
user_id=user_id,
|
||||
name=data.name.strip(),
|
||||
description=data.description.strip(),
|
||||
category=data.category,
|
||||
target=data.target,
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.services.payment import PaymentService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateOrderRequest(BaseModel):
|
||||
plan: str
|
||||
|
||||
|
||||
class PaymentCallbackRequest(BaseModel):
|
||||
payment_id: str
|
||||
success: bool
|
||||
|
||||
|
||||
@router.get("/plans")
|
||||
async def get_plans():
|
||||
svc = PaymentService(None)
|
||||
return await svc.get_plans()
|
||||
|
||||
|
||||
@router.get("/subscription")
|
||||
async def get_subscription(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
return await svc.get_current_subscription(user_id)
|
||||
|
||||
|
||||
@router.post("/create-order")
|
||||
async def create_order(
|
||||
data: CreateOrderRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
try:
|
||||
return await svc.create_order(user_id, data.plan)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def payment_callback(
|
||||
data: PaymentCallbackRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
svc = PaymentService(db)
|
||||
success = await svc.handle_payment_callback(data.payment_id, data.success)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Order not found")
|
||||
return {"status": "ok"}
|
||||
+41
-125
@@ -1,147 +1,63 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Optional, List
|
||||
from typing import Annotated, Optional
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import decode_token
|
||||
from app.services.push import PushService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class DeviceRegister(BaseModel):
|
||||
class DeviceRegisterRequest(BaseModel):
|
||||
client_id: str
|
||||
platform: Optional[str] = None
|
||||
platform: str = "weapp"
|
||||
push_token: Optional[str] = None
|
||||
device_info: Optional[dict] = None
|
||||
|
||||
|
||||
class PushMessage(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
payload: Optional[dict] = None
|
||||
target_type: str = "all"
|
||||
target_value: Optional[str] = None
|
||||
|
||||
|
||||
class PushResponse(BaseModel):
|
||||
success: bool
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# 模拟存储的设备信息(实际应存数据库)
|
||||
devices_db = {}
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register_device(
|
||||
data: DeviceRegister,
|
||||
authorization: str = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
data: DeviceRegisterRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if user_id not in devices_db:
|
||||
devices_db[user_id] = []
|
||||
|
||||
existing = [d for d in devices_db[user_id] if d.get("client_id") == data.client_id]
|
||||
if not existing:
|
||||
devices_db[user_id].append({
|
||||
"client_id": data.client_id,
|
||||
"platform": data.platform,
|
||||
"device_info": data.device_info,
|
||||
})
|
||||
|
||||
return {"success": True, "message": "Device registered"}
|
||||
|
||||
|
||||
@router.post("/send")
|
||||
async def send_push(
|
||||
message: PushMessage,
|
||||
authorization: str = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
user_devices = devices_db.get(user_id, [])
|
||||
if not user_devices:
|
||||
return PushResponse(success=False, error="No devices registered")
|
||||
|
||||
# 实际项目中这里调用 uni-push/极光等API
|
||||
# 模拟返回成功
|
||||
message_id = f"msg_{user_id}_{int(payload.get('iat', 0))}"
|
||||
|
||||
print(f"Push message to user {user_id}: {message.title} - {message.content}")
|
||||
|
||||
return PushResponse(success=True, message_id=message_id)
|
||||
|
||||
|
||||
@router.post("/send-to-customer")
|
||||
async def send_to_customer(
|
||||
customer_id: str,
|
||||
title: str,
|
||||
content: str,
|
||||
payload: Optional[dict] = None,
|
||||
authorization: str = None,
|
||||
):
|
||||
"""
|
||||
针对特定客户的推送通知
|
||||
例如:客户沉默提醒、报价提醒等
|
||||
"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload_data = decode_token(authorization[7:])
|
||||
if not payload_data:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload_data.get("sub")
|
||||
|
||||
# 这里可以添加针对客户的特定逻辑
|
||||
notification = {
|
||||
"type": "customer_alert",
|
||||
"customer_id": customer_id,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"payload": payload or {}
|
||||
service = PushService(db)
|
||||
device = await service.register_device(
|
||||
user_id=user_id,
|
||||
client_id=data.client_id,
|
||||
platform=data.platform,
|
||||
push_token=data.push_token,
|
||||
device_info=data.device_info,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"device_id": str(device.id),
|
||||
"message": "Device registered",
|
||||
}
|
||||
|
||||
print(f"Customer notification for user {user_id}, customer {customer_id}: {title}")
|
||||
|
||||
return PushResponse(success=True, message_id=f"alert_{customer_id}")
|
||||
@router.post("/unregister")
|
||||
async def unregister_device(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
client_id = data.get("client_id")
|
||||
if not client_id:
|
||||
raise HTTPException(status_code=400, detail="client_id required")
|
||||
service = PushService(db)
|
||||
success = await service.unregister_device(user_id, client_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Device not found")
|
||||
return {"success": True, "message": "Device unregistered"}
|
||||
|
||||
|
||||
@router.get("/devices")
|
||||
async def list_devices(
|
||||
authorization: str = None,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
"""列出用户已注册的设备"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload.get("sub")
|
||||
user_devices = devices_db.get(user_id, [])
|
||||
|
||||
return {
|
||||
"devices": user_devices,
|
||||
"count": len(user_devices)
|
||||
}
|
||||
service = PushService(db)
|
||||
devices = await service.get_user_devices(user_id)
|
||||
return {"devices": devices, "count": len(devices)}
|
||||
|
||||
@@ -1,13 +1,39 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.services.quotation import QuotationService
|
||||
from app.services.pdf_generator import pdf_generator
|
||||
from app.services import export
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.models.quotation import Quotation
|
||||
from app.models.customer import Customer
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class InquiryRequest(BaseModel):
|
||||
inquiry_text: str
|
||||
customer_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/generate-from-inquiry")
|
||||
async def generate_from_inquiry(
|
||||
data: InquiryRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = QuotationService(db)
|
||||
result = await service.generate_from_inquiry(
|
||||
user_id=user_id,
|
||||
inquiry_text=data.inquiry_text,
|
||||
customer_id=data.customer_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_quotation(
|
||||
data: dict,
|
||||
@@ -58,3 +84,78 @@ async def update_quotation_status(
|
||||
if not quotation:
|
||||
raise HTTPException(status_code=404, detail="Quotation not found")
|
||||
return quotation
|
||||
|
||||
|
||||
@router.get("/export/csv")
|
||||
async def export_quotations(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = QuotationService(db)
|
||||
result = await service.list_quotations(user_id, 1, 9999)
|
||||
items = result.get("items", [])
|
||||
csv_bytes = export.export_quotations_csv(items)
|
||||
return Response(
|
||||
content=csv_bytes,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment; filename=quotations.csv"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{quotation_id}/pdf")
|
||||
async def export_quotation_pdf(
|
||||
quotation_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = QuotationService(db)
|
||||
quotation = await service.get_quotation(user_id, quotation_id)
|
||||
if not quotation:
|
||||
raise HTTPException(status_code=404, detail="Quotation not found")
|
||||
|
||||
result = await db.execute(
|
||||
select(Customer).where(Customer.id == quotation["customer_id"])
|
||||
)
|
||||
customer = result.scalar_one_or_none()
|
||||
|
||||
pdf_data = pdf_generator.generate_quotation({
|
||||
"quotation_number": f"{quotation_id[:8].upper()}",
|
||||
"customer_name": customer.name if customer else "",
|
||||
"customer_company": customer.company if customer else "",
|
||||
"customer_country": customer.country if customer else "",
|
||||
"date": quotation["created_at"][:10] if quotation.get("created_at") else "",
|
||||
"valid_until": quotation.get("valid_until", ""),
|
||||
"currency": quotation.get("currency", "USD"),
|
||||
"items": quotation.get("items", []),
|
||||
"subtotal": quotation.get("subtotal", 0),
|
||||
"discount": quotation.get("discount", 0),
|
||||
"shipping": quotation.get("shipping", 0),
|
||||
"total": quotation.get("total", 0),
|
||||
"payment_terms": quotation.get("payment_terms", ""),
|
||||
"delivery_terms": quotation.get("delivery_terms", ""),
|
||||
"lead_time": quotation.get("lead_time", ""),
|
||||
"notes": quotation.get("notes", ""),
|
||||
})
|
||||
|
||||
if not pdf_data:
|
||||
raise HTTPException(status_code=501, detail="PDF generation not available (weasyprint not installed)")
|
||||
|
||||
service = QuotationService(db)
|
||||
result = await db.execute(
|
||||
select(Quotation).where(
|
||||
and_(Quotation.id == quotation_id, Quotation.user_id == user_id)
|
||||
)
|
||||
)
|
||||
q = result.scalar_one_or_none()
|
||||
if q:
|
||||
pdf_url = f"/quotations/{quotation_id}/pdf"
|
||||
q.pdf_url = pdf_url
|
||||
await db.flush()
|
||||
|
||||
return Response(
|
||||
content=pdf_data,
|
||||
media_type="application/pdf",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="quotation-{quotation_id[:8]}.pdf"',
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from app.database import get_db
|
||||
from app.services.silent_pattern import SilentPatternService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/risk-analysis")
|
||||
async def get_silent_risk_analysis(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = SilentPatternService(db)
|
||||
risks = await service.analyze_silent_risk(user_id)
|
||||
return {
|
||||
"items": risks,
|
||||
"total": len(risks),
|
||||
"high_risk": len([r for r in risks if r["risk_level"] == "high"]),
|
||||
"medium_risk": len([r for r in risks if r["risk_level"] == "medium"]),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{customer_id}/suggestions")
|
||||
async def get_followup_suggestions(
|
||||
customer_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = SilentPatternService(db)
|
||||
suggestions = await service.get_suggestions(user_id, customer_id)
|
||||
return {"customer_id": customer_id, "suggestions": suggestions}
|
||||
@@ -0,0 +1,117 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.services.team import TeamService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateTeamRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class InviteRequest(BaseModel):
|
||||
user_id: str
|
||||
|
||||
|
||||
class UpdateRoleRequest(BaseModel):
|
||||
role: str
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_team(
|
||||
data: CreateTeamRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = TeamService(db)
|
||||
try:
|
||||
team = await service.create_team(user_id, data.name, data.description)
|
||||
return team
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_teams(
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = TeamService(db)
|
||||
return {"teams": await service.list_user_teams(user_id)}
|
||||
|
||||
|
||||
@router.get("/{team_id}")
|
||||
async def get_team(
|
||||
team_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = TeamService(db)
|
||||
team = await service.get_team(team_id, user_id)
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
return team
|
||||
|
||||
|
||||
@router.post("/{team_id}/invite")
|
||||
async def invite_member(
|
||||
team_id: str,
|
||||
data: InviteRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = TeamService(db)
|
||||
try:
|
||||
result = await service.invite_member(team_id, user_id, data.user_id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{team_id}/members/{member_id}")
|
||||
async def remove_member(
|
||||
team_id: str,
|
||||
member_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = TeamService(db)
|
||||
success = await service.remove_member(team_id, user_id, member_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Member not found or not removable")
|
||||
return {"message": "Member removed"}
|
||||
|
||||
|
||||
@router.post("/{team_id}/leave")
|
||||
async def leave_team(
|
||||
team_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = TeamService(db)
|
||||
success = await service.leave_team(team_id, user_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Cannot leave as owner or not a member")
|
||||
return {"message": "Left team"}
|
||||
|
||||
|
||||
@router.patch("/{team_id}/members/{member_id}/role")
|
||||
async def update_member_role(
|
||||
team_id: str,
|
||||
member_id: str,
|
||||
data: UpdateRoleRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = TeamService(db)
|
||||
if data.role not in ("admin", "member", "viewer"):
|
||||
raise HTTPException(status_code=400, detail="Invalid role")
|
||||
success = await service.update_role(team_id, user_id, member_id, data.role)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Member not found or not updatable")
|
||||
return {"message": "Role updated"}
|
||||
@@ -0,0 +1,44 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from app.database import get_db
|
||||
from app.services.corpus_trainer import CorpusTrainer
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/corpus/run")
|
||||
async def run_corpus_training(
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
trainer = CorpusTrainer(db)
|
||||
result = await trainer.run_pipeline()
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/corpus/embeddings")
|
||||
async def compute_embeddings(
|
||||
batch_size: int = 50,
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
trainer = CorpusTrainer(db)
|
||||
result = await trainer.compute_embeddings(batch_size)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/corpus/stats")
|
||||
async def corpus_stats(
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
trainer = CorpusTrainer(db)
|
||||
return await trainer.get_stats()
|
||||
|
||||
|
||||
@router.post("/corpus/deduplicate")
|
||||
async def deduplicate_corpus(
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
trainer = CorpusTrainer(db)
|
||||
result = await trainer.deduplicate()
|
||||
return result
|
||||
@@ -1,8 +1,13 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Response, Depends
|
||||
from typing import Optional, Dict, Any, Annotated
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.database import get_db
|
||||
from app.services.translation import TranslationService
|
||||
from app.services.tts import tts_service
|
||||
from app.services.preference import UserPreferenceService
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -27,13 +32,10 @@ class ExtractRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def translate_text(data: TranslateRequest, authorization: str = None):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
user_id = payload.get("sub") if payload else None
|
||||
|
||||
async def translate_text(
|
||||
data: TranslateRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
):
|
||||
service = TranslationService()
|
||||
result = await service.translate(
|
||||
text=data.text,
|
||||
@@ -46,9 +48,13 @@ async def translate_text(data: TranslateRequest, authorization: str = None):
|
||||
|
||||
|
||||
@router.post("/reply")
|
||||
async def generate_reply(data: ReplyRequest, authorization: str = None):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
async def generate_reply(
|
||||
data: ReplyRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
pref_service = UserPreferenceService(db)
|
||||
pref_context = await pref_service.get_preference_context(user_id, "reply")
|
||||
|
||||
service = TranslationService()
|
||||
results = await service.generate_reply(
|
||||
@@ -56,25 +62,65 @@ async def generate_reply(data: ReplyRequest, authorization: str = None):
|
||||
context=data.context,
|
||||
tone=data.tone,
|
||||
count=data.count,
|
||||
preference_context=pref_context,
|
||||
)
|
||||
return {"suggestions": results, "inquiry": data.inquiry, "count": len(results)}
|
||||
|
||||
|
||||
@router.post("/extract")
|
||||
async def extract_info(data: ExtractRequest, authorization: str = None):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
async def extract_info(
|
||||
data: ExtractRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
):
|
||||
service = TranslationService()
|
||||
result = await service.extract_info(data.text, data.extract_type)
|
||||
return {"extracted": result, "type": data.extract_type}
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
async def feedback(data: dict, authorization: str = None):
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
lang: str = "en"
|
||||
rate: str = "0%"
|
||||
pitch: str = "0Hz"
|
||||
|
||||
|
||||
@router.post("/tts")
|
||||
async def text_to_speech(
|
||||
data: TTSRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
):
|
||||
audio = await tts_service.synthesize(data.text, data.lang, data.rate, data.pitch)
|
||||
if not audio:
|
||||
raise HTTPException(status_code=501, detail="TTS not available (edge-tts not installed or synthesis failed)")
|
||||
|
||||
return Response(content=audio, media_type="audio/mpeg", headers={
|
||||
"Content-Disposition": f'attachment; filename="tts-{data.lang}.mp3"',
|
||||
})
|
||||
|
||||
|
||||
@router.get("/tts")
|
||||
async def text_to_speech_get(
|
||||
text: str = "",
|
||||
lang: str = "en",
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
):
|
||||
if not text.strip():
|
||||
raise HTTPException(status_code=400, detail="Text is required")
|
||||
|
||||
audio = await tts_service.synthesize(text, lang)
|
||||
if not audio:
|
||||
raise HTTPException(status_code=501, detail="TTS not available")
|
||||
|
||||
return Response(content=audio, media_type="audio/mpeg", headers={
|
||||
"Content-Disposition": f'attachment; filename="tts-{lang}.mp3"',
|
||||
})
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
async def feedback(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
):
|
||||
from app.ai.trade_corpus import TradeCorpus
|
||||
corpus = TradeCorpus()
|
||||
|
||||
@@ -84,3 +130,26 @@ async def feedback(data: dict, authorization: str = None):
|
||||
await corpus.rate_entry(entry_id, rating)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
public_router = APIRouter(tags=["translate-public"])
|
||||
|
||||
|
||||
@public_router.post("/translate")
|
||||
async def public_translate(data: TranslateRequest):
|
||||
service = TranslationService()
|
||||
result = await service.translate(
|
||||
text=data.text,
|
||||
target_lang=data.target_lang,
|
||||
source_lang=data.source_lang,
|
||||
context=data.context,
|
||||
user_id=None,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@public_router.post("/extract")
|
||||
async def public_extract(data: ExtractRequest):
|
||||
service = TranslationService()
|
||||
result = await service.extract_info(data.text, data.extract_type)
|
||||
return {"extracted": result, "type": data.extract_type}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from fastapi import APIRouter, Request, HTTPException, Depends
|
||||
from fastapi import APIRouter, Request, HTTPException, Depends, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from sqlalchemy import select, and_
|
||||
from typing import Annotated, Optional
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.services.whatsapp import WhatsAppService
|
||||
from app.services.customer import CustomerService
|
||||
from app.services.translation import TranslationService
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.config import settings
|
||||
from app.models.customer import Customer
|
||||
from app.models.user import User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -26,35 +27,92 @@ async def verify_webhook(
|
||||
|
||||
|
||||
@router.post("/webhook")
|
||||
async def handle_webhook(request: Request, db: Annotated[AsyncSession, Depends(get_db)] = None):
|
||||
async def handle_webhook(
|
||||
request: Request,
|
||||
x_hub_signature_256: Optional[str] = Header(None),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
svc = WhatsAppService()
|
||||
body = await request.json()
|
||||
body = await request.body()
|
||||
|
||||
msg_data = svc.parse_webhook(body)
|
||||
if x_hub_signature_256:
|
||||
if not svc.verify_signature(body, x_hub_signature_256):
|
||||
raise HTTPException(status_code=403, detail="Invalid signature")
|
||||
|
||||
import json
|
||||
body_json = json.loads(body)
|
||||
msg_data = svc.parse_webhook(body_json)
|
||||
if not msg_data:
|
||||
return {"status": "ok"}
|
||||
|
||||
# TODO: Route to correct user based on WhatsApp number
|
||||
# For MVP, handle as generic incoming message
|
||||
from_number = msg_data.get("from")
|
||||
text = msg_data.get("text", "")
|
||||
|
||||
if from_number:
|
||||
result = await db.execute(
|
||||
select(Customer).where(Customer.whatsapp_id == from_number)
|
||||
)
|
||||
customer = result.scalar_one_or_none()
|
||||
|
||||
if customer:
|
||||
user_id = str(customer.user_id)
|
||||
cust_svc = CustomerService(db)
|
||||
await cust_svc.save_message(
|
||||
user_id=user_id,
|
||||
customer_id=str(customer.id),
|
||||
direction="inbound",
|
||||
content=text,
|
||||
)
|
||||
|
||||
return {"status": "ok", "message": "received"}
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
to: str
|
||||
text: str = ""
|
||||
template_name: Optional[str] = None
|
||||
template_params: Optional[dict] = None
|
||||
media_url: Optional[str] = None
|
||||
media_type: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/send")
|
||||
async def send_message(
|
||||
data: dict,
|
||||
data: SendMessageRequest,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
text = data.get("text")
|
||||
to = data.get("to")
|
||||
if not text or not to:
|
||||
raise HTTPException(status_code=400, detail="text and to are required")
|
||||
|
||||
svc = WhatsAppService()
|
||||
sent = await svc.send_text(to, text)
|
||||
|
||||
sent = False
|
||||
if data.template_name and data.template_params:
|
||||
sent = await svc.send_template(data.to, data.template_name, data.template_params)
|
||||
elif data.media_url and data.media_type:
|
||||
sent = await svc.send_media(data.to, data.media_url, data.media_type, caption=data.text)
|
||||
elif data.text:
|
||||
sent = await svc.send_text(data.to, data.text)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="text, template, or media required")
|
||||
|
||||
if not sent:
|
||||
raise HTTPException(status_code=500, detail="Failed to send WhatsApp message")
|
||||
|
||||
return {"status": "sent", "to": to}
|
||||
cust_svc = CustomerService(db)
|
||||
result = await db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.whatsapp_id == data.to, Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
customer = result.scalar_one_or_none()
|
||||
if customer:
|
||||
await cust_svc.save_message(
|
||||
user_id=user_id,
|
||||
customer_id=str(customer.id),
|
||||
direction="outbound",
|
||||
content=data.text or f"[{data.media_type or 'template'}]",
|
||||
)
|
||||
|
||||
return {"status": "sent", "to": data.to}
|
||||
|
||||
|
||||
@router.get("/qr")
|
||||
|
||||
@@ -20,4 +20,26 @@ celery_app.conf.update(
|
||||
task_time_limit=300,
|
||||
worker_prefetch_multiplier=4,
|
||||
worker_max_tasks_per_child=1000,
|
||||
beat_schedule={
|
||||
"check-silent-customers": {
|
||||
"task": "app.workers.tasks.check_silent_customers",
|
||||
"schedule": 3600.0,
|
||||
},
|
||||
"update-customer-health-cache": {
|
||||
"task": "app.workers.tasks.update_customer_health_cache",
|
||||
"schedule": 3600.0,
|
||||
},
|
||||
"cleanup-old-sessions": {
|
||||
"task": "app.workers.tasks.cleanup_old_sessions",
|
||||
"schedule": 86400.0,
|
||||
},
|
||||
"daily-corpus-training": {
|
||||
"task": "app.workers.tasks.run_daily_corpus_training",
|
||||
"schedule": 86400.0,
|
||||
},
|
||||
"check-followup-engine": {
|
||||
"task": "app.workers.tasks.check_followup_engine",
|
||||
"schedule": 21600.0,
|
||||
},
|
||||
},
|
||||
)
|
||||
+22
-7
@@ -1,4 +1,4 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseSettings
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
@@ -8,7 +8,10 @@ ENV_FILE = PROJECT_ROOT / ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = {"env_file": str(ENV_FILE), "extra": "ignore"}
|
||||
class Config:
|
||||
env_file = str(ENV_FILE)
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
|
||||
APP_NAME: str = "TradeMate"
|
||||
|
||||
@@ -29,6 +32,14 @@ class Settings(BaseSettings):
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
DEEPL_API_KEY: Optional[str] = None
|
||||
|
||||
SENSENOVA_API_KEY: Optional[str] = None
|
||||
SENSENOVA_BASE_URL: str = "https://token.sensenova.cn/v1"
|
||||
SENSENOVA_MODEL: str = "sensenova-6.7-flash-lite"
|
||||
|
||||
IFLYTEK_API_KEY: Optional[str] = None
|
||||
IFLYTEK_API_BASE: str = "https://maas-api.cn-huabei-1.xf-yun.com/v2"
|
||||
IFLYTEK_MODEL: str = "astron-code-latest"
|
||||
|
||||
LOCAL_MODEL_ENABLED: bool = False
|
||||
LOCAL_MODEL_URL: str = "http://localhost:8001"
|
||||
|
||||
@@ -38,6 +49,7 @@ class Settings(BaseSettings):
|
||||
|
||||
WECHAT_APP_ID: Optional[str] = None
|
||||
WECHAT_APP_SECRET: Optional[str] = None
|
||||
WECHAT_PUSH_TEMPLATE_ID: Optional[str] = None
|
||||
|
||||
EXCHANGE_RATE_API_KEY: Optional[str] = None
|
||||
|
||||
@@ -47,12 +59,15 @@ class Settings(BaseSettings):
|
||||
FRONTEND_URL: str = "http://localhost:3000"
|
||||
BACKEND_URL: str = "http://localhost:8000"
|
||||
|
||||
SENTRY_DSN: Optional[str] = None
|
||||
DEBUG: bool = True
|
||||
|
||||
AI_ROUTING: dict = {
|
||||
"translate": {"primary": "deepl", "fallback": ["openai", "local"]},
|
||||
"reply": {"primary": "openai", "fallback": ["anthropic", "local"]},
|
||||
"marketing": {"primary": "anthropic", "fallback": ["openai", "local"]},
|
||||
"extract": {"primary": "openai", "fallback": ["anthropic"]},
|
||||
"quotation": {"primary": "openai", "fallback": ["anthropic"]},
|
||||
"translate": {"primary": "sensenova", "fallback": ["openai", "local"]},
|
||||
"reply": {"primary": "sensenova", "fallback": ["anthropic", "local"]},
|
||||
"marketing": {"primary": "sensenova", "fallback": ["openai", "local"]},
|
||||
"extract": {"primary": "sensenova", "fallback": ["openai"]},
|
||||
"quotation": {"primary": "sensenova", "fallback": ["openai"]},
|
||||
}
|
||||
|
||||
FREE_DAILY_TRANSLATE_CHARS: int = 5000
|
||||
|
||||
@@ -1,26 +1,63 @@
|
||||
from fastapi import Request
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.config import settings
|
||||
from app.core.security import decode_token
|
||||
import redis.asyncio as aioredis
|
||||
from redis.asyncio import ConnectionPool
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_redis_pool = None
|
||||
|
||||
|
||||
async def get_redis():
|
||||
global _redis_pool
|
||||
if _redis_pool is None:
|
||||
_redis_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
|
||||
return aioredis.Redis(connection_pool=_redis_pool)
|
||||
|
||||
|
||||
def get_user_tier_from_token(request: Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
request.state.user_id = None
|
||||
request.state.user_tier = "anonymous"
|
||||
return "anonymous"
|
||||
payload = decode_token(auth[7:])
|
||||
if not payload:
|
||||
request.state.user_id = None
|
||||
request.state.user_tier = "anonymous"
|
||||
return "anonymous"
|
||||
request.state.user_id = payload.get("sub")
|
||||
request.state.user_tier = payload.get("tier", "free")
|
||||
return request.state.user_tier
|
||||
|
||||
|
||||
RATE_LIMITS = {
|
||||
"free": 100,
|
||||
"pro": 500,
|
||||
"enterprise": 2000,
|
||||
}
|
||||
|
||||
|
||||
async def check_rate_limit(user_id: str, tier: str) -> int:
|
||||
r = await get_redis()
|
||||
now = time.time()
|
||||
window = 60
|
||||
key = f"ratelimit:{user_id}:{int(now // window)}"
|
||||
|
||||
count = await r.incr(key)
|
||||
if count == 1:
|
||||
await r.expire(key, window + 5)
|
||||
|
||||
limit = RATE_LIMITS.get(tier, 100)
|
||||
remaining = max(0, limit - count)
|
||||
return remaining
|
||||
|
||||
|
||||
class TierMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.url.path.startswith("/api/v1"):
|
||||
@@ -49,16 +86,51 @@ class TierMiddleware(BaseHTTPMiddleware):
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.url.path.startswith("/api/v1"):
|
||||
return await call_next(request)
|
||||
|
||||
user_tier = getattr(request.state, "user_tier", None)
|
||||
if user_tier in ("anonymous", None):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
remaining = await check_rate_limit(
|
||||
user_id, user_tier
|
||||
)
|
||||
if remaining == 0:
|
||||
return Response(
|
||||
status_code=429,
|
||||
content='{"error":"RATE_LIMITED","detail":"Too many requests, try again later"}',
|
||||
media_type="application/json",
|
||||
headers={"Retry-After": "60"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"Rate limit check failed: {e}")
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class QuotaMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.url.path.startswith("/api/v1"):
|
||||
return await call_next(request)
|
||||
|
||||
if request.state.user_tier in ("anonymous",):
|
||||
user_tier = getattr(request.state, "user_tier", None)
|
||||
if user_tier in ("anonymous", None):
|
||||
return await call_next(request)
|
||||
|
||||
user_id = request.state.user_id
|
||||
tier = request.state.user_tier
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
|
||||
tier = user_tier
|
||||
|
||||
if tier == "enterprise":
|
||||
return await call_next(request)
|
||||
@@ -102,7 +174,7 @@ class QuotaMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
r = aioredis.from_url(settings.REDIS_URL)
|
||||
r = await get_redis()
|
||||
key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}"
|
||||
current = await r.incr(key)
|
||||
await r.expire(key, 86400)
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from app.config import settings
|
||||
import redis.asyncio as aioredis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
_pool = None
|
||||
|
||||
|
||||
async def get_redis():
|
||||
global _pool
|
||||
if _pool is None:
|
||||
_pool = ConnectionPool.from_url(settings.REDIS_URL, max_connections=20)
|
||||
return aioredis.Redis(connection_pool=_pool)
|
||||
|
||||
|
||||
async def close_redis():
|
||||
global _pool
|
||||
if _pool:
|
||||
await _pool.disconnect()
|
||||
_pool = None
|
||||
@@ -1,18 +1,24 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
import bcrypt
|
||||
from app.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
return pwd_context.verify(plain, hashed)
|
||||
try:
|
||||
password_bytes = plain.encode("utf-8")
|
||||
if isinstance(hashed, str):
|
||||
hashed = hashed.encode("utf-8")
|
||||
return bcrypt.checkpw(password_bytes, hashed)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
password_bytes = password[:72].encode("utf-8")
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password_bytes, salt).decode("utf-8")
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
|
||||
+34
-3
@@ -2,12 +2,30 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.config import settings
|
||||
from app.core.exceptions import register_exception_handlers
|
||||
from app.core.middleware import TierMiddleware, QuotaMiddleware
|
||||
from app.core.middleware import TierMiddleware, QuotaMiddleware, RateLimitMiddleware
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn=settings.SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
environment="production" if not settings.DEBUG else "development",
|
||||
integrations=[
|
||||
FastApiIntegration(),
|
||||
SqlalchemyIntegration(),
|
||||
],
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
except (ImportError, Exception) as e:
|
||||
logger.info(f"Sentry not configured: {e}")
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
version="1.0.0",
|
||||
@@ -23,8 +41,9 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.add_middleware(TierMiddleware)
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
app.add_middleware(QuotaMiddleware)
|
||||
app.add_middleware(TierMiddleware)
|
||||
|
||||
register_exception_handlers(app)
|
||||
|
||||
@@ -34,17 +53,29 @@ async def health():
|
||||
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
|
||||
|
||||
|
||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push
|
||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push, admin, analytics, teams, onboarding, notification, feedback, payment, interaction, silent_pattern, training, followup
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
|
||||
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
|
||||
app.include_router(translate.router, prefix="/api/v1/translate", tags=["translate"])
|
||||
app.include_router(translate.public_router, prefix="/api/v1/translate/public", tags=["translate-public"])
|
||||
app.include_router(customer.router, prefix="/api/v1/customers", tags=["customers"])
|
||||
app.include_router(quotation.router, prefix="/api/v1/quotations", tags=["quotations"])
|
||||
app.include_router(whatsapp.router, prefix="/api/v1/whatsapp", tags=["whatsapp"])
|
||||
app.include_router(product.router, prefix="/api/v1/products", tags=["products"])
|
||||
app.include_router(exchange.router, prefix="/api/v1/exchange", tags=["exchange"])
|
||||
app.include_router(push.router, prefix="/api/v1/push", tags=["push"])
|
||||
app.include_router(admin.router, prefix="/api/v1/admin", tags=["admin"])
|
||||
app.include_router(analytics.router, prefix="/api/v1/analytics", tags=["analytics"])
|
||||
app.include_router(teams.router, prefix="/api/v1/teams", tags=["teams"])
|
||||
app.include_router(onboarding.router, prefix="/api/v1/onboarding", tags=["onboarding"])
|
||||
app.include_router(notification.router, prefix="/api/v1/notifications", tags=["notifications"])
|
||||
app.include_router(feedback.router, prefix="/api/v1/feedback", tags=["feedback"])
|
||||
app.include_router(payment.router, prefix="/api/v1/payment", tags=["payment"])
|
||||
app.include_router(interaction.router, prefix="/api/v1/interaction", tags=["interaction"])
|
||||
app.include_router(silent_pattern.router, prefix="/api/v1/silent-pattern", tags=["silent-pattern"])
|
||||
app.include_router(training.router, prefix="/api/v1/training", tags=["training"])
|
||||
app.include_router(followup.router, prefix="/api/v1/followup", tags=["followup"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,10 +2,26 @@ from .user import User, Product
|
||||
from .customer import Customer, Conversation, Message
|
||||
from .quotation import Quotation, QuotationItem
|
||||
from .corpus import CorpusEntry
|
||||
from .team import Team, TeamMember
|
||||
from .analytics import UsageLog
|
||||
from .notification import Notification
|
||||
from .feedback import Feedback
|
||||
from .subscription import Subscription
|
||||
from .preference import PreferenceAnalysis, MarketingEffect
|
||||
from .device import Device
|
||||
from .followup import FollowupStrategy, FollowupLog
|
||||
|
||||
__all__ = [
|
||||
"User", "Product",
|
||||
"Customer", "Conversation", "Message",
|
||||
"Quotation", "QuotationItem",
|
||||
"CorpusEntry",
|
||||
"Team", "TeamMember",
|
||||
"UsageLog",
|
||||
"Notification",
|
||||
"Feedback",
|
||||
"Subscription",
|
||||
"PreferenceAnalysis", "MarketingEffect",
|
||||
"Device",
|
||||
"FollowupStrategy", "FollowupLog",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class UsageLog(Base):
|
||||
__tablename__ = "usage_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
team_id = Column(UUID(as_uuid=True), nullable=True, index=True)
|
||||
action = Column(String(100), nullable=False)
|
||||
detail = Column(JSONB, default={})
|
||||
ip_address = Column(String(50))
|
||||
user_agent = Column(String(255))
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -1,6 +1,5 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, Float
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, Float, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
@@ -21,6 +20,6 @@ class CorpusEntry(Base):
|
||||
user_edited = Column(Boolean, default=False)
|
||||
user_rating = Column(Integer)
|
||||
usage_count = Column(Integer, default=0)
|
||||
embedding = Column(Vector(768))
|
||||
metadata = Column(JSONB, default={})
|
||||
embedding = Column(JSONB)
|
||||
entry_metadata = Column("metadata", JSONB, default={})
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@@ -10,7 +10,7 @@ class Customer(Base):
|
||||
__tablename__ = "customers"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
company = Column(String(255))
|
||||
country = Column(String(100))
|
||||
@@ -38,7 +38,7 @@ class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True)
|
||||
customer_id = Column(UUID(as_uuid=True), ForeignKey("customers.id"), nullable=False, index=True)
|
||||
channel = Column(String(50), default="whatsapp")
|
||||
topic = Column(String(255))
|
||||
@@ -66,7 +66,7 @@ class Message(Base):
|
||||
selected_suggestion = Column(Integer)
|
||||
user_edited = Column(Text)
|
||||
status = Column(String(50), default="sent")
|
||||
metadata = Column(JSONB, default={})
|
||||
msg_metadata = Column("metadata", JSONB, default={})
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
conversation = relationship("Conversation", back_populates="messages")
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Device(Base):
|
||||
__tablename__ = "devices"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
platform = Column(String(50), default="weapp")
|
||||
push_token = Column(String(500))
|
||||
client_id = Column(String(255), nullable=False)
|
||||
device_info = Column(JSONB, default={})
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
@@ -0,0 +1,17 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Feedback(Base):
|
||||
__tablename__ = "feedbacks"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
category = Column(String(50), default="general")
|
||||
content = Column(Text, nullable=False)
|
||||
contact = Column(String(100))
|
||||
status = Column(String(20), default="pending")
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,51 @@
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class FollowupStrategy(Base):
|
||||
__tablename__ = "followup_strategies"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text)
|
||||
trigger_condition = Column(JSONB, default={
|
||||
"min_silence_days": 3,
|
||||
"max_silence_days": 999,
|
||||
"min_health_score": 0,
|
||||
"max_health_score": 100,
|
||||
"status_filter": ["lead", "negotiating"],
|
||||
})
|
||||
channel = Column(String(50), default="whatsapp")
|
||||
ai_prompt_template = Column(Text)
|
||||
priority = Column(Integer, default=0)
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class FollowupLog(Base):
|
||||
__tablename__ = "followup_logs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
customer_id = Column(UUID(as_uuid=True), ForeignKey("customers.id"), nullable=False, index=True)
|
||||
strategy_id = Column(UUID(as_uuid=True), ForeignKey("followup_strategies.id"), nullable=True)
|
||||
status = Column(String(50), default="pending")
|
||||
channel = Column(String(50), default="whatsapp")
|
||||
content = Column(Text)
|
||||
ai_generated_content = Column(Text)
|
||||
user_edited_content = Column(Text)
|
||||
health_score_at_time = Column(Integer)
|
||||
silence_days_at_time = Column(Integer)
|
||||
sent_at = Column(DateTime)
|
||||
replied_at = Column(DateTime)
|
||||
response_status = Column(String(50))
|
||||
log_metadata = Column("metadata", JSONB, default={})
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
customer = relationship("Customer", backref="followup_logs")
|
||||
@@ -0,0 +1,20 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notifications"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
notification_type = Column(String(50), default="system")
|
||||
reference_type = Column(String(50))
|
||||
reference_id = Column(String(255))
|
||||
is_read = Column(Boolean, default=False)
|
||||
notify_metadata = Column("metadata", JSONB, default={})
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,40 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, Integer, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class PreferenceAnalysis(Base):
|
||||
__tablename__ = "preference_analyses"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True, unique=True)
|
||||
task_type = Column(String(50), nullable=False, default="reply")
|
||||
preferred_tone = Column(String(50))
|
||||
preferred_style = Column(String(50))
|
||||
common_replacements = Column(JSONB, default=[])
|
||||
avg_formality_score = Column(Float, default=0.5)
|
||||
greeting_style = Column(String(100))
|
||||
sign_off_style = Column(String(100))
|
||||
analysis_data = Column(JSONB, default={})
|
||||
confidence = Column(Float, default=0.0)
|
||||
interaction_count = Column(Integer, default=0)
|
||||
last_analysis_at = Column(DateTime)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class MarketingEffect(Base):
|
||||
__tablename__ = "marketing_effects"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
content_hash = Column(String(64), nullable=False, index=True)
|
||||
product_id = Column(UUID(as_uuid=True))
|
||||
product_name = Column(String(255))
|
||||
channel = Column(String(50), default="copy")
|
||||
event_type = Column(String(50), nullable=False)
|
||||
target_audience = Column(String(255))
|
||||
effect_metadata = Column("metadata", JSONB, default={})
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -10,7 +10,7 @@ class Quotation(Base):
|
||||
__tablename__ = "quotations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True)
|
||||
customer_id = Column(UUID(as_uuid=True), ForeignKey("customers.id"), nullable=False)
|
||||
title = Column(String(255))
|
||||
status = Column(String(50), default="draft")
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Float, Boolean, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
plan = Column(String(50), nullable=False)
|
||||
status = Column(String(20), default="active")
|
||||
started_at = Column(DateTime, default=datetime.utcnow)
|
||||
expires_at = Column(DateTime)
|
||||
auto_renew = Column(Boolean, default=True)
|
||||
payment_provider = Column(String(50), default="wechat")
|
||||
payment_id = Column(String(255))
|
||||
amount = Column(Float)
|
||||
currency = Column(String(10), default="CNY")
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
@@ -0,0 +1,38 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Text, Integer
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Team(Base):
|
||||
__tablename__ = "teams"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(255), nullable=False)
|
||||
owner_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
description = Column(Text)
|
||||
member_count = Column(Integer, default=0)
|
||||
max_members = Column(Integer, default=5)
|
||||
tier = Column(String(50), default="free")
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
members = relationship("TeamMember", back_populates="team", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class TeamMember(Base):
|
||||
__tablename__ = "team_members"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
team_id = Column(UUID(as_uuid=True), ForeignKey("teams.id"), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
role = Column(String(50), default="member")
|
||||
invited_by = Column(UUID(as_uuid=True))
|
||||
status = Column(String(50), default="active")
|
||||
joined_at = Column(DateTime, default=datetime.utcnow)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
team = relationship("Team", back_populates="members")
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
@@ -15,6 +15,7 @@ class User(Base):
|
||||
username = Column(String(100))
|
||||
password_hash = Column(String(255))
|
||||
tier = Column(String(50), default="free")
|
||||
role = Column(String(20), default="user")
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
@@ -35,7 +36,7 @@ class Product(Base):
|
||||
__tablename__ = "products"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
name_en = Column(String(255))
|
||||
description = Column(Text)
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
from typing import Dict, Any, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from app.models.user import User
|
||||
from app.models.team import Team, TeamMember
|
||||
from app.models.analytics import UsageLog
|
||||
from app.models.customer import Customer
|
||||
from app.models.quotation import Quotation
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_dashboard(self) -> Dict[str, Any]:
|
||||
now = datetime.utcnow()
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
user_count = await self.db.execute(select(func.count(User.id)))
|
||||
team_count = await self.db.execute(select(func.count(Team.id)))
|
||||
customer_count = await self.db.execute(select(func.count(Customer.id)))
|
||||
quotation_count = await self.db.execute(select(func.count(Quotation.id)))
|
||||
|
||||
today_logs = await self.db.execute(
|
||||
select(func.count(UsageLog.id)).where(UsageLog.created_at >= today_start)
|
||||
)
|
||||
total_logs = await self.db.execute(select(func.count(UsageLog.id)))
|
||||
|
||||
recent_users_result = await self.db.execute(
|
||||
select(User).order_by(User.created_at.desc()).limit(5)
|
||||
)
|
||||
recent_users = recent_users_result.scalars().all()
|
||||
|
||||
return {
|
||||
"users": {
|
||||
"total": user_count.scalar() or 0,
|
||||
},
|
||||
"teams": {
|
||||
"total": team_count.scalar() or 0,
|
||||
},
|
||||
"customers": {
|
||||
"total": customer_count.scalar() or 0,
|
||||
},
|
||||
"quotations": {
|
||||
"total": quotation_count.scalar() or 0,
|
||||
},
|
||||
"usage": {
|
||||
"today": today_logs.scalar() or 0,
|
||||
"total": total_logs.scalar() or 0,
|
||||
},
|
||||
"recent_users": [
|
||||
{
|
||||
"id": str(u.id),
|
||||
"username": u.username,
|
||||
"tier": u.tier,
|
||||
"is_active": u.is_active,
|
||||
"created_at": u.created_at.isoformat() if u.created_at else None,
|
||||
}
|
||||
for u in recent_users
|
||||
],
|
||||
}
|
||||
|
||||
async def list_users(self, page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
query = select(User).order_by(User.created_at.desc()).offset((page - 1) * size).limit(size)
|
||||
count_query = select(func.count(User.id))
|
||||
|
||||
total = await self.db.execute(count_query)
|
||||
result = await self.db.execute(query)
|
||||
users = result.scalars().all()
|
||||
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": str(u.id),
|
||||
"username": u.username,
|
||||
"phone": u.phone,
|
||||
"tier": u.tier,
|
||||
"is_active": u.is_active,
|
||||
"created_at": u.created_at.isoformat() if u.created_at else None,
|
||||
}
|
||||
for u in users
|
||||
],
|
||||
"total": total.scalar(),
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
async def update_user_tier(self, user_id: str, tier: str) -> bool:
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
return False
|
||||
user.tier = tier
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def toggle_user_active(self, user_id: str) -> bool:
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
return False
|
||||
user.is_active = not user.is_active
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def get_system_health(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"status": "healthy",
|
||||
"version": "1.0.0",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
async def log_usage(self, user_id: str, action: str, detail: Dict = None, ip: str = None, ua: str = None):
|
||||
try:
|
||||
log = UsageLog(
|
||||
user_id=user_id,
|
||||
action=action,
|
||||
detail=detail or {},
|
||||
ip_address=ip,
|
||||
user_agent=ua,
|
||||
)
|
||||
self.db.add(log)
|
||||
await self.db.flush()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to log usage: {e}")
|
||||
@@ -0,0 +1,191 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, extract
|
||||
from app.models.customer import Customer, Conversation, Message
|
||||
from app.models.quotation import Quotation
|
||||
from app.models.analytics import UsageLog
|
||||
from app.models.user import User
|
||||
from app.models.preference import MarketingEffect
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalyticsService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_customer_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
total = await self.db.execute(
|
||||
select(func.count(Customer.id)).where(Customer.user_id == user_id)
|
||||
)
|
||||
by_status = await self.db.execute(
|
||||
select(Customer.status, func.count(Customer.id))
|
||||
.where(Customer.user_id == user_id)
|
||||
.group_by(Customer.status)
|
||||
)
|
||||
by_country = await self.db.execute(
|
||||
select(Customer.country, func.count(Customer.id))
|
||||
.where(Customer.user_id == user_id)
|
||||
.where(Customer.country.isnot(None))
|
||||
.group_by(Customer.country)
|
||||
.order_by(func.count(Customer.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
silent_3 = await self.db.execute(
|
||||
select(func.count(Customer.id)).where(
|
||||
and_(
|
||||
Customer.user_id == user_id,
|
||||
Customer.last_contact_at.isnot(None),
|
||||
Customer.last_contact_at < now - timedelta(days=3),
|
||||
Customer.status.in_(["lead", "negotiating"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total.scalar() or 0,
|
||||
"by_status": {row[0] or "unknown": row[1] for row in by_status.all()},
|
||||
"by_country": {row[0] or "unknown": row[1] for row in by_country.all()},
|
||||
"silent_customers": silent_3.scalar() or 0,
|
||||
}
|
||||
|
||||
async def get_translation_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
now = datetime.utcnow()
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
today_count = await self.db.execute(
|
||||
select(func.count(UsageLog.id)).where(
|
||||
and_(
|
||||
UsageLog.user_id == user_id,
|
||||
UsageLog.action == "translate",
|
||||
UsageLog.created_at >= today_start,
|
||||
)
|
||||
)
|
||||
)
|
||||
total_count = await self.db.execute(
|
||||
select(func.count(UsageLog.id)).where(
|
||||
and_(UsageLog.user_id == user_id, UsageLog.action == "translate")
|
||||
)
|
||||
)
|
||||
|
||||
daily_result = await self.db.execute(
|
||||
select(
|
||||
extract("year", UsageLog.created_at),
|
||||
extract("month", UsageLog.created_at),
|
||||
extract("day", UsageLog.created_at),
|
||||
func.count(UsageLog.id),
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
UsageLog.user_id == user_id,
|
||||
UsageLog.action == "translate",
|
||||
UsageLog.created_at >= now - timedelta(days=30),
|
||||
)
|
||||
)
|
||||
.group_by(
|
||||
extract("year", UsageLog.created_at),
|
||||
extract("month", UsageLog.created_at),
|
||||
extract("day", UsageLog.created_at),
|
||||
)
|
||||
.order_by(
|
||||
extract("year", UsageLog.created_at),
|
||||
extract("month", UsageLog.created_at),
|
||||
extract("day", UsageLog.created_at),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"today": today_count.scalar() or 0,
|
||||
"total": total_count.scalar() or 0,
|
||||
"daily": [
|
||||
{
|
||||
"date": f"{int(r[0])}-{int(r[1]):02d}-{int(r[2]):02d}",
|
||||
"count": r[3],
|
||||
}
|
||||
for r in daily_result.all()
|
||||
],
|
||||
}
|
||||
|
||||
async def get_quotation_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
total = await self.db.execute(
|
||||
select(func.count(Quotation.id)).where(Quotation.user_id == user_id)
|
||||
)
|
||||
by_status = await self.db.execute(
|
||||
select(Quotation.status, func.count(Quotation.id))
|
||||
.where(Quotation.user_id == user_id)
|
||||
.group_by(Quotation.status)
|
||||
)
|
||||
total_value = await self.db.execute(
|
||||
select(func.sum(Quotation.total)).where(
|
||||
and_(Quotation.user_id == user_id, Quotation.status == "accepted")
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total.scalar() or 0,
|
||||
"by_status": {row[0] or "draft": row[1] for row in by_status.all()},
|
||||
"total_accepted_value": float(total_value.scalar() or 0),
|
||||
}
|
||||
|
||||
async def get_message_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
now = datetime.utcnow()
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
total_msgs = await self.db.execute(
|
||||
select(func.count(Message.id))
|
||||
.join(Conversation, Message.conversation_id == Conversation.id)
|
||||
.where(Conversation.user_id == user_id)
|
||||
)
|
||||
today_msgs = await self.db.execute(
|
||||
select(func.count(Message.id))
|
||||
.join(Conversation, Message.conversation_id == Conversation.id)
|
||||
.where(
|
||||
and_(
|
||||
Conversation.user_id == user_id,
|
||||
Message.created_at >= today_start,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total_msgs.scalar() or 0,
|
||||
"today": today_msgs.scalar() or 0,
|
||||
}
|
||||
|
||||
async def get_marketing_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
total = await self.db.execute(
|
||||
select(func.count(MarketingEffect.id)).where(MarketingEffect.user_id == user_id)
|
||||
)
|
||||
copy_count = await self.db.execute(
|
||||
select(func.count(MarketingEffect.id)).where(
|
||||
and_(MarketingEffect.user_id == user_id, MarketingEffect.event_type == "copy")
|
||||
)
|
||||
)
|
||||
send_count = await self.db.execute(
|
||||
select(func.count(MarketingEffect.id)).where(
|
||||
and_(MarketingEffect.user_id == user_id, MarketingEffect.event_type == "send")
|
||||
)
|
||||
)
|
||||
top_products = await self.db.execute(
|
||||
select(MarketingEffect.product_name, func.count(MarketingEffect.id))
|
||||
.where(
|
||||
and_(
|
||||
MarketingEffect.user_id == user_id,
|
||||
MarketingEffect.product_name.isnot(None),
|
||||
)
|
||||
)
|
||||
.group_by(MarketingEffect.product_name)
|
||||
.order_by(func.count(MarketingEffect.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
|
||||
return {
|
||||
"total_events": total.scalar() or 0,
|
||||
"copy_count": copy_count.scalar() or 0,
|
||||
"send_count": send_count.scalar() or 0,
|
||||
"top_products": [{"name": r[0], "count": r[1]} for r in top_products.all()],
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CorpusTrainer:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def compute_embeddings(self, batch_size: int = 50) -> Dict[str, Any]:
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
result = await self.db.execute(
|
||||
select(CorpusEntry).where(CorpusEntry.embedding.is_(None)).limit(batch_size)
|
||||
)
|
||||
entries = result.scalars().all()
|
||||
|
||||
updated = 0
|
||||
for entry in entries:
|
||||
try:
|
||||
embedding = await self._generate_embedding(entry.source_text)
|
||||
if embedding:
|
||||
entry.embedding = embedding
|
||||
updated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Embedding failed for entry {entry.id}: {e}")
|
||||
|
||||
await self.db.flush()
|
||||
return {"processed": len(entries), "updated": updated}
|
||||
|
||||
async def score_entries(self, batch_size: int = 100) -> Dict[str, Any]:
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
result = await self.db.execute(
|
||||
select(CorpusEntry)
|
||||
.where(CorpusEntry.quality_score.is_(None))
|
||||
.limit(batch_size)
|
||||
)
|
||||
entries = result.scalars().all()
|
||||
|
||||
updated = 0
|
||||
for entry in entries:
|
||||
score = self._calculate_quality_score(entry)
|
||||
entry.quality_score = score
|
||||
updated += 1
|
||||
|
||||
await self.db.flush()
|
||||
return {"processed": len(entries), "updated": updated}
|
||||
|
||||
async def deduplicate(self) -> Dict[str, Any]:
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
subquery = (
|
||||
select(
|
||||
CorpusEntry.source_text,
|
||||
CorpusEntry.task_type,
|
||||
func.min(CorpusEntry.id).label("keep_id"),
|
||||
)
|
||||
.group_by(CorpusEntry.source_text, CorpusEntry.task_type)
|
||||
.having(func.count(CorpusEntry.id) > 1)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
result = await self.db.execute(
|
||||
select(CorpusEntry).where(
|
||||
and_(
|
||||
CorpusEntry.source_text == subquery.c.source_text,
|
||||
CorpusEntry.task_type == subquery.c.task_type,
|
||||
CorpusEntry.id != subquery.c.keep_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
duplicates = result.scalars().all()
|
||||
|
||||
for dup in duplicates:
|
||||
await self.db.delete(dup)
|
||||
|
||||
await self.db.flush()
|
||||
return {"duplicates_removed": len(duplicates)}
|
||||
|
||||
async def prune_low_quality(self, min_score: float = 0.2, max_age_days: int = 90) -> Dict[str, Any]:
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
cutoff = datetime.utcnow() - timedelta(days=max_age_days)
|
||||
result = await self.db.execute(
|
||||
select(CorpusEntry).where(
|
||||
and_(
|
||||
CorpusEntry.quality_score < min_score,
|
||||
CorpusEntry.created_at < cutoff,
|
||||
CorpusEntry.usage_count.is_(None) | (CorpusEntry.usage_count < 2),
|
||||
)
|
||||
)
|
||||
)
|
||||
entries = result.scalars().all()
|
||||
|
||||
for e in entries:
|
||||
await self.db.delete(e)
|
||||
|
||||
await self.db.flush()
|
||||
return {"pruned": len(entries)}
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
total = await self.db.execute(select(func.count(CorpusEntry.id)))
|
||||
by_type = await self.db.execute(
|
||||
select(CorpusEntry.task_type, func.count(CorpusEntry.id))
|
||||
.group_by(CorpusEntry.task_type)
|
||||
)
|
||||
with_embeddings = await self.db.execute(
|
||||
select(func.count(CorpusEntry.id)).where(CorpusEntry.embedding.isnot(None))
|
||||
)
|
||||
high_quality = await self.db.execute(
|
||||
select(func.count(CorpusEntry.id)).where(CorpusEntry.quality_score >= 0.7)
|
||||
)
|
||||
low_quality = await self.db.execute(
|
||||
select(func.count(CorpusEntry.id)).where(CorpusEntry.quality_score < 0.3)
|
||||
)
|
||||
|
||||
return {
|
||||
"total_entries": total.scalar() or 0,
|
||||
"by_task_type": {row[0]: row[1] for row in by_type.all()},
|
||||
"with_embeddings": with_embeddings.scalar() or 0,
|
||||
"high_quality": high_quality.scalar() or 0,
|
||||
"low_quality": low_quality.scalar() or 0,
|
||||
}
|
||||
|
||||
async def run_pipeline(self) -> Dict[str, Any]:
|
||||
dedup_result = await self.deduplicate()
|
||||
score_result = await self.score_entries()
|
||||
embed_result = await self.compute_embeddings()
|
||||
prune_result = await self.prune_low_quality()
|
||||
stats = await self.get_stats()
|
||||
|
||||
return {
|
||||
"deduplication": dedup_result,
|
||||
"scoring": score_result,
|
||||
"embeddings": embed_result,
|
||||
"pruning": prune_result,
|
||||
"stats": stats,
|
||||
}
|
||||
|
||||
def _calculate_quality_score(self, entry) -> float:
|
||||
score = 0.5
|
||||
|
||||
if entry.user_rating:
|
||||
score = entry.user_rating / 5.0
|
||||
|
||||
if entry.user_edited:
|
||||
score = max(score - 0.1, 0)
|
||||
|
||||
if entry.usage_count and entry.usage_count > 5:
|
||||
score = min(score + 0.15, 1.0)
|
||||
|
||||
src_len = len(entry.source_text) if entry.source_text else 0
|
||||
tgt_len = len(entry.target_text) if entry.target_text else 0
|
||||
if src_len > 10 and tgt_len > 10:
|
||||
score = min(score + 0.1, 1.0)
|
||||
if src_len < 3 or tgt_len < 3:
|
||||
score = max(score - 0.3, 0)
|
||||
|
||||
return round(score, 2)
|
||||
|
||||
async def _generate_embedding(self, text: str) -> Optional[List[float]]:
|
||||
try:
|
||||
from app.config import settings
|
||||
import httpx
|
||||
|
||||
if settings.OPENAI_API_KEY:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
"https://api.openai.com/v1/embeddings",
|
||||
headers={"Authorization": f"Bearer {settings.OPENAI_API_KEY}"},
|
||||
json={"model": "text-embedding-3-small", "input": text[:8000]},
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
return data["data"][0]["embedding"]
|
||||
except Exception as e:
|
||||
logger.warning(f"Embedding generation failed: {e}")
|
||||
return None
|
||||
@@ -0,0 +1,333 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, desc
|
||||
from app.models.customer import Customer, Message, Conversation
|
||||
from app.models.quotation import Quotation
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEAL_SIGNAL_KEYWORDS = [
|
||||
"moq", "minimum order", "sample", "certification", "certificate",
|
||||
"fob", "cif", "lead time", "delivery time", "shipping",
|
||||
"payment term", "tt", "lc", "deposit", "price", "quotation",
|
||||
"order", "purchase", "buy", "interested", "inquiry", "rfq",
|
||||
]
|
||||
|
||||
POSITIVE_WORDS = ["yes", "interested", "good", "great", "perfect", "thanks", "thank you", "proceed", "confirm", "agree"]
|
||||
NEGATIVE_WORDS = ["no", "not interested", "too expensive", "high price", "over budget", "not now", "later", "maybe later"]
|
||||
|
||||
|
||||
class CustomerHealthService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_health_overview(self, user_id: str) -> Dict[str, Any]:
|
||||
customers_result = await self.db.execute(
|
||||
select(Customer.id, Customer.status, Customer.last_contact_at).where(
|
||||
Customer.user_id == user_id
|
||||
)
|
||||
)
|
||||
rows = customers_result.all()
|
||||
total = len(rows)
|
||||
active = 0
|
||||
watch = 0
|
||||
critical = 0
|
||||
for row in rows:
|
||||
score = self._calculate_silence_score(row.last_contact_at)
|
||||
status_weight = self._status_weight(row.status)
|
||||
combined = score * 0.7 + status_weight * 0.3
|
||||
if combined >= 70:
|
||||
active += 1
|
||||
elif combined >= 40:
|
||||
watch += 1
|
||||
else:
|
||||
critical += 1
|
||||
return {
|
||||
"total": total,
|
||||
"active": active,
|
||||
"watch": watch,
|
||||
"critical": critical,
|
||||
}
|
||||
|
||||
async def get_customer_health(self, user_id: str, customer_id: str) -> Optional[Dict[str, Any]]:
|
||||
result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.id == customer_id, Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
customer = result.scalar_one_or_none()
|
||||
if not customer:
|
||||
return None
|
||||
return await self._compute_full_health(user_id, customer)
|
||||
|
||||
async def get_all_health_scores(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
customers_result = await self.db.execute(
|
||||
select(Customer).where(Customer.user_id == user_id).order_by(Customer.updated_at.desc())
|
||||
)
|
||||
customers = customers_result.scalars().all()
|
||||
results = []
|
||||
for c in customers:
|
||||
health = await self._compute_full_health(user_id, c)
|
||||
results.append(health)
|
||||
return results
|
||||
|
||||
async def _compute_full_health(self, user_id: str, customer: Customer) -> Dict[str, Any]:
|
||||
response_trend = await self._calc_response_trend(customer.id)
|
||||
sentiment = await self._calc_sentiment(customer.id)
|
||||
inquiry_depth = await self._calc_inquiry_depth(customer.id)
|
||||
silence_score = self._calculate_silence_score(customer.last_contact_at)
|
||||
business_value = await self._calc_business_value(customer.id)
|
||||
|
||||
silence_days = self._silence_days(customer.last_contact_at)
|
||||
dimensions = {
|
||||
"response_trend": response_trend,
|
||||
"sentiment": sentiment,
|
||||
"inquiry_depth": inquiry_depth,
|
||||
"silence": {"score": silence_score, "days": silence_days},
|
||||
"business_value": business_value,
|
||||
}
|
||||
result = self.calc_total_score(dimensions)
|
||||
|
||||
return {
|
||||
"customer_id": str(customer.id),
|
||||
"customer_name": customer.name,
|
||||
"status": customer.status,
|
||||
"total_score": result["total_score"],
|
||||
"grade": result["grade"],
|
||||
"dimensions": dimensions,
|
||||
"suggestion": self._suggestion(result["grade"], customer),
|
||||
}
|
||||
|
||||
async def _calc_response_trend(self, customer_id: str) -> Dict[str, Any]:
|
||||
now_7d_ago = datetime.utcnow() - timedelta(days=7)
|
||||
prev_7d_ago = datetime.utcnow() - timedelta(days=14)
|
||||
|
||||
recent_result = await self.db.execute(
|
||||
select(func.avg(
|
||||
func.extract("epoch", Message.created_at) -
|
||||
func.extract("epoch", func.lag(Message.created_at).over(order_by=Message.created_at))
|
||||
)).where(
|
||||
and_(
|
||||
Message.conversation_id == select(Conversation.id).where(
|
||||
Conversation.customer_id == customer_id
|
||||
).limit(1).scalar_subquery(),
|
||||
Message.direction == "inbound",
|
||||
Message.created_at >= now_7d_ago,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
previous_result = await self.db.execute(
|
||||
select(func.avg(
|
||||
func.extract("epoch", Message.created_at) -
|
||||
func.extract("epoch", func.lag(Message.created_at).over(order_by=Message.created_at))
|
||||
)).where(
|
||||
and_(
|
||||
Message.conversation_id == select(Conversation.id).where(
|
||||
Conversation.customer_id == customer_id
|
||||
).limit(1).scalar_subquery(),
|
||||
Message.direction == "inbound",
|
||||
Message.created_at >= prev_7d_ago,
|
||||
Message.created_at < now_7d_ago,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
recent_avg = recent_result.scalar()
|
||||
prev_avg = previous_result.scalar()
|
||||
|
||||
recent_hours = (recent_avg / 3600) if recent_avg else None
|
||||
prev_hours = (prev_avg / 3600) if prev_avg else None
|
||||
return self.calc_response_score(recent_hours, prev_hours)
|
||||
|
||||
async def _calc_sentiment(self, customer_id: str) -> Dict[str, Any]:
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation.id).where(
|
||||
Conversation.customer_id == customer_id
|
||||
).order_by(Conversation.created_at.desc()).limit(1)
|
||||
)
|
||||
conv_id = conv_result.scalar_one_or_none()
|
||||
if not conv_id:
|
||||
return {"score": 50, "label": "neutral", "last_messages": []}
|
||||
|
||||
msg_result = await self.db.execute(
|
||||
select(Message.content).where(
|
||||
and_(
|
||||
Message.conversation_id == conv_id,
|
||||
Message.direction == "inbound",
|
||||
)
|
||||
).order_by(desc(Message.created_at)).limit(3)
|
||||
)
|
||||
messages = list(msg_result.scalars().all())
|
||||
return self.calc_sentiment_score(messages)
|
||||
|
||||
async def _calc_inquiry_depth(self, customer_id: str) -> Dict[str, Any]:
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation.id).where(
|
||||
Conversation.customer_id == customer_id
|
||||
).order_by(Conversation.created_at.desc()).limit(1)
|
||||
)
|
||||
conv_id = conv_result.scalar_one_or_none()
|
||||
if not conv_id:
|
||||
return {"score": 0, "signals_found": [], "signal_count": 0}
|
||||
|
||||
msg_result = await self.db.execute(
|
||||
select(Message.content).where(
|
||||
and_(
|
||||
Message.conversation_id == conv_id,
|
||||
Message.direction == "inbound",
|
||||
)
|
||||
).order_by(desc(Message.created_at)).limit(20)
|
||||
)
|
||||
messages = list(msg_result.scalars().all())
|
||||
return self.calc_inquiry_depth_score(messages)
|
||||
|
||||
@staticmethod
|
||||
def calculate_silence_score(last_contact_at: Optional[datetime]) -> float:
|
||||
days = CustomerHealthService.silence_days(last_contact_at)
|
||||
return max(0, min(100, 100 - (days / 14) * 100))
|
||||
|
||||
@staticmethod
|
||||
def silence_days(last_contact_at: Optional[datetime]) -> int:
|
||||
if not last_contact_at:
|
||||
return 999
|
||||
return (datetime.utcnow() - last_contact_at).days
|
||||
|
||||
@staticmethod
|
||||
def status_weight(status: Optional[str]) -> float:
|
||||
mapping = {"customer": 100, "negotiating": 70, "lead": 40, "lost": 10}
|
||||
return mapping.get(status, 40)
|
||||
|
||||
@staticmethod
|
||||
def grade(score: float) -> str:
|
||||
if score >= 80:
|
||||
return "active"
|
||||
elif score >= 50:
|
||||
return "watch"
|
||||
else:
|
||||
return "critical"
|
||||
|
||||
@staticmethod
|
||||
def calc_response_score(recent_hours: Optional[float], prev_hours: Optional[float]) -> Dict[str, Any]:
|
||||
if recent_hours is None and prev_hours is None:
|
||||
return {"score": 50, "recent_avg_hours": None, "trend": "stable"}
|
||||
if recent_hours is None:
|
||||
return {"score": 30, "recent_avg_hours": None, "trend": "declining"}
|
||||
if prev_hours is None or prev_hours == 0:
|
||||
score = max(0, min(100, 100 - recent_hours * 5))
|
||||
return {"score": round(score), "recent_avg_hours": round(recent_hours, 1), "trend": "stable"}
|
||||
if recent_hours < prev_hours:
|
||||
score = max(0, min(100, 100 - recent_hours * 5))
|
||||
return {"score": round(score), "recent_avg_hours": round(recent_hours, 1), "trend": "improving"}
|
||||
else:
|
||||
score = max(0, min(100, 80 - recent_hours * 3))
|
||||
return {"score": round(score), "recent_avg_hours": round(recent_hours, 1), "trend": "declining"}
|
||||
|
||||
@staticmethod
|
||||
def calc_sentiment_score(messages: List[str]) -> Dict[str, Any]:
|
||||
if not messages:
|
||||
return {"score": 50, "label": "neutral", "last_messages": []}
|
||||
positive = 0
|
||||
negative = 0
|
||||
for msg in messages:
|
||||
lower = msg.lower()
|
||||
if any(w in lower for w in POSITIVE_WORDS):
|
||||
positive += 1
|
||||
if any(w in lower for w in NEGATIVE_WORDS):
|
||||
negative += 1
|
||||
if positive > negative:
|
||||
return {"score": 80, "label": "positive", "last_messages": messages}
|
||||
elif negative > positive:
|
||||
return {"score": 20, "label": "negative", "last_messages": messages}
|
||||
else:
|
||||
return {"score": 50, "label": "neutral", "last_messages": messages}
|
||||
|
||||
@staticmethod
|
||||
def calc_inquiry_depth_score(messages: List[str]) -> Dict[str, Any]:
|
||||
found_signals = []
|
||||
for msg in messages:
|
||||
lower = msg.lower()
|
||||
for kw in DEAL_SIGNAL_KEYWORDS:
|
||||
if kw in lower and kw not in found_signals:
|
||||
found_signals.append(kw)
|
||||
count = len(found_signals)
|
||||
if count >= 5:
|
||||
score = 100
|
||||
elif count >= 3:
|
||||
score = 75
|
||||
elif count >= 1:
|
||||
score = 50
|
||||
else:
|
||||
score = 0
|
||||
return {"score": score, "signals_found": found_signals, "signal_count": count}
|
||||
|
||||
@staticmethod
|
||||
def calc_business_value_score(total_value: float) -> Dict[str, Any]:
|
||||
if total_value >= 100000:
|
||||
score = 100
|
||||
elif total_value >= 50000:
|
||||
score = 80
|
||||
elif total_value >= 10000:
|
||||
score = 60
|
||||
elif total_value >= 1000:
|
||||
score = 40
|
||||
elif total_value > 0:
|
||||
score = 20
|
||||
else:
|
||||
score = 0
|
||||
return {"score": score, "total_value": round(total_value, 2)}
|
||||
|
||||
@staticmethod
|
||||
def calc_total_score(dimensions: Dict[str, Any]) -> Dict[str, Any]:
|
||||
total = (
|
||||
dimensions.get("response_trend", {}).get("score", 0) * 0.25
|
||||
+ dimensions.get("sentiment", {}).get("score", 0) * 0.20
|
||||
+ dimensions.get("inquiry_depth", {}).get("score", 0) * 0.20
|
||||
+ dimensions.get("silence", {}).get("score", 0) * 0.20
|
||||
+ dimensions.get("business_value", {}).get("score", 0) * 0.15
|
||||
)
|
||||
return {"total_score": round(total, 1), "grade": CustomerHealthService.grade(total)}
|
||||
|
||||
@staticmethod
|
||||
def suggestion(grade: str, silence_days: int, status: Optional[str]) -> str:
|
||||
if grade == "active":
|
||||
return "保持正常跟进,客户状态良好"
|
||||
elif grade == "watch":
|
||||
if silence_days >= 3:
|
||||
return f"客户已沉默{silence_days}天,建议3天内安排跟进"
|
||||
return "客户活跃度下降,建议关注"
|
||||
else:
|
||||
if status in ("lead", "negotiating"):
|
||||
return f"客户已沉默{silence_days}天,建议立即跟进,提供优惠或新产品信息"
|
||||
return f"客户已沉默{silence_days}天,建议重新激活"
|
||||
|
||||
def _calculate_silence_score(self, last_contact_at: Optional[datetime]) -> float:
|
||||
return self.calculate_silence_score(last_contact_at)
|
||||
|
||||
def _silence_days(self, last_contact_at: Optional[datetime]) -> int:
|
||||
return self.silence_days(last_contact_at)
|
||||
|
||||
def _status_weight(self, status: Optional[str]) -> float:
|
||||
return self.status_weight(status)
|
||||
|
||||
def _grade(self, score: float) -> str:
|
||||
return self.grade(score)
|
||||
|
||||
def _suggestion(self, grade: str, customer: Customer) -> str:
|
||||
return self.suggestion(grade, self._silence_days(customer.last_contact_at), customer.status)
|
||||
|
||||
async def _calc_business_value(self, customer_id: str) -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(func.sum(Quotation.total)).where(
|
||||
and_(
|
||||
Quotation.customer_id == customer_id,
|
||||
Quotation.status.in_(["sent", "accepted"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
total_value = result.scalar() or 0
|
||||
return self.calc_business_value_score(total_value)
|
||||
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
from typing import Dict, Optional
|
||||
from datetime import datetime
|
||||
from app.config import settings
|
||||
from app.core.redis import get_redis
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FALLBACK_RATES: Dict[str, Dict[str, float]] = {
|
||||
"USD": {"CNY": 7.24, "EUR": 0.92, "GBP": 0.79, "JPY": 151.50, "KRW": 1320.00, "AUD": 1.52, "CAD": 1.37, "INR": 83.50, "BRL": 5.10, "RUB": 92.00},
|
||||
"CNY": {"USD": 0.138, "EUR": 0.127, "GBP": 0.109, "JPY": 20.93, "KRW": 182.32, "AUD": 0.21, "CAD": 0.19},
|
||||
"EUR": {"USD": 1.09, "CNY": 7.85, "GBP": 0.86, "JPY": 164.50, "KRW": 1435.00},
|
||||
"GBP": {"USD": 1.27, "CNY": 9.15, "EUR": 1.16, "JPY": 192.00},
|
||||
}
|
||||
|
||||
CACHE_TTL = 21600
|
||||
|
||||
|
||||
class ExchangeRateService:
|
||||
def __init__(self):
|
||||
self._rates_cache: Optional[Dict] = None
|
||||
self._cache_time: Optional[datetime] = None
|
||||
|
||||
async def get_rate(self, from_currency: str, to_currency: str) -> Optional[float]:
|
||||
from_currency = from_currency.upper()
|
||||
to_currency = to_currency.upper()
|
||||
if from_currency == to_currency:
|
||||
return 1.0
|
||||
|
||||
rates = await self._get_all_rates(from_currency)
|
||||
if rates and to_currency in rates:
|
||||
return rates[to_currency]
|
||||
|
||||
base_rates = FALLBACK_RATES.get(from_currency, {})
|
||||
return base_rates.get(to_currency)
|
||||
|
||||
async def convert(self, from_currency: str, to_currency: str, amount: float = 1.0) -> Optional[float]:
|
||||
rate = await self.get_rate(from_currency, to_currency)
|
||||
if rate is None:
|
||||
return None
|
||||
return round(amount * rate, 2)
|
||||
|
||||
async def get_all_rates(self, base: str = "USD") -> Dict[str, float]:
|
||||
base = base.upper()
|
||||
rates = await self._get_all_rates(base)
|
||||
if rates:
|
||||
return rates
|
||||
return FALLBACK_RATES.get(base, {})
|
||||
|
||||
async def _get_all_rates(self, base: str) -> Optional[Dict[str, float]]:
|
||||
cached = await self._get_from_cache(base)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
rates = None
|
||||
for fetcher in [self._fetch_from_frankfurter, self._fetch_from_exchangerate_api]:
|
||||
try:
|
||||
rates = await fetcher(base)
|
||||
if rates:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Exchange rate fetcher failed: {e}")
|
||||
|
||||
if rates:
|
||||
await self._set_cache(base, rates)
|
||||
|
||||
return rates
|
||||
|
||||
async def _fetch_from_frankfurter(self, base: str) -> Optional[Dict[str, float]]:
|
||||
supported = ["USD", "EUR", "GBP", "CNY", "JPY", "KRW", "AUD", "CAD", "INR", "BRL"]
|
||||
if base not in supported:
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"https://api.frankfurter.app/latest",
|
||||
params={"from": base, "to": ",".join(supported)},
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
return data.get("rates")
|
||||
except Exception as e:
|
||||
logger.warning(f"Frankfurter API failed: {e}")
|
||||
return None
|
||||
|
||||
async def _fetch_from_exchangerate_api(self, base: str) -> Optional[Dict[str, float]]:
|
||||
if not settings.EXCHANGE_RATE_API_KEY:
|
||||
return None
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"https://v6.exchangerate-api.com/v6/{settings.EXCHANGE_RATE_API_KEY}/latest/{base}",
|
||||
timeout=10,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if data.get("result") == "success":
|
||||
return data.get("conversion_rates")
|
||||
except Exception as e:
|
||||
logger.warning(f"ExchangeRate-API failed: {e}")
|
||||
return None
|
||||
|
||||
async def _get_from_cache(self, base: str) -> Optional[Dict[str, float]]:
|
||||
try:
|
||||
r = await get_redis()
|
||||
data = await r.get(f"exchange_rate:{base}")
|
||||
if data:
|
||||
return json.loads(data)
|
||||
except Exception as e:
|
||||
logger.debug(f"Redis cache miss for {base}: {e}")
|
||||
return None
|
||||
|
||||
async def _set_cache(self, base: str, rates: Dict[str, float]):
|
||||
try:
|
||||
r = await get_redis()
|
||||
await r.setex(f"exchange_rate:{base}", CACHE_TTL, json.dumps(rates))
|
||||
except Exception as e:
|
||||
logger.debug(f"Redis cache set failed for {base}: {e}")
|
||||
@@ -0,0 +1,37 @@
|
||||
from typing import List, Dict, Any
|
||||
import csv
|
||||
import io
|
||||
|
||||
|
||||
def export_customers_csv(customers: List[Dict[str, Any]]) -> bytes:
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow(["Name", "Company", "Country", "Phone", "Email", "Status", "Last Contact"])
|
||||
for c in customers:
|
||||
writer.writerow([
|
||||
c.get("name", ""),
|
||||
c.get("company", ""),
|
||||
c.get("country", ""),
|
||||
c.get("phone", ""),
|
||||
c.get("email", ""),
|
||||
c.get("status", ""),
|
||||
c.get("last_contact_at", ""),
|
||||
])
|
||||
return output.getvalue().encode("utf-8-sig")
|
||||
|
||||
|
||||
def export_quotations_csv(quotations: List[Dict[str, Any]]) -> bytes:
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow(["Title", "Customer", "Currency", "Subtotal", "Total", "Status", "Date"])
|
||||
for q in quotations:
|
||||
writer.writerow([
|
||||
q.get("title", ""),
|
||||
q.get("customer_name", ""),
|
||||
q.get("currency", "USD"),
|
||||
q.get("subtotal", 0),
|
||||
q.get("total", 0),
|
||||
q.get("status", ""),
|
||||
q.get("created_at", ""),
|
||||
])
|
||||
return output.getvalue().encode("utf-8-sig")
|
||||
@@ -0,0 +1,396 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, desc
|
||||
from app.models.followup import FollowupStrategy, FollowupLog
|
||||
from app.models.customer import Customer
|
||||
from app.models.notification import Notification
|
||||
from app.ai.router import get_ai_router
|
||||
from app.services.customer_health import CustomerHealthService
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_STRATEGIES = [
|
||||
{
|
||||
"name": "温和提醒",
|
||||
"description": "沉默3-5天,健康分50-79 — 温和提醒",
|
||||
"trigger_condition": {
|
||||
"min_silence_days": 3,
|
||||
"max_silence_days": 5,
|
||||
"min_health_score": 50,
|
||||
"max_health_score": 79,
|
||||
"status_filter": ["lead", "negotiating"],
|
||||
},
|
||||
"channel": "whatsapp",
|
||||
"ai_prompt_template": "You are a professional export sales assistant. Write a gentle follow-up message to a customer who hasn't responded in {silence_days} days. Customer name: {customer_name}. Tone: warm but professional. Keep under 100 words. Suggest checking if they need any further information about the product.",
|
||||
"priority": 1,
|
||||
},
|
||||
{
|
||||
"name": "价值提供",
|
||||
"description": "沉默6-10天,健康分30-49 — 推送价值信息",
|
||||
"trigger_condition": {
|
||||
"min_silence_days": 6,
|
||||
"max_silence_days": 10,
|
||||
"min_health_score": 30,
|
||||
"max_health_score": 49,
|
||||
"status_filter": ["lead", "negotiating"],
|
||||
},
|
||||
"channel": "email",
|
||||
"ai_prompt_template": "You are a professional export sales assistant. Write a follow-up email to a customer who hasn't responded in {silence_days} days. Customer: {customer_name}. Share some valuable industry news, new product catalog highlights, or certification updates to rekindle interest. Keep under 150 words.",
|
||||
"priority": 2,
|
||||
},
|
||||
{
|
||||
"name": "重新激活",
|
||||
"description": "沉默11+天,健康分<30 — 紧急重新激活",
|
||||
"trigger_condition": {
|
||||
"min_silence_days": 11,
|
||||
"max_silence_days": 999,
|
||||
"min_health_score": 0,
|
||||
"max_health_score": 29,
|
||||
"status_filter": ["lead", "negotiating"],
|
||||
},
|
||||
"channel": "email",
|
||||
"ai_prompt_template": "You are a professional export sales assistant. Write a re-engagement email to a customer who has been silent for {silence_days} days. Customer: {customer_name}. Offer a limited-time discount, new product launch info, or a holiday greeting. Create a sense of urgency without being pushy. Keep under 150 words.",
|
||||
"priority": 3,
|
||||
},
|
||||
{
|
||||
"name": "促进决策",
|
||||
"description": "客户有回复但未成交,健康分60+ — 促进成交",
|
||||
"trigger_condition": {
|
||||
"min_silence_days": 2,
|
||||
"max_silence_days": 7,
|
||||
"min_health_score": 60,
|
||||
"max_health_score": 100,
|
||||
"status_filter": ["negotiating"],
|
||||
},
|
||||
"channel": "whatsapp",
|
||||
"ai_prompt_template": "You are a professional export sales assistant. The customer {customer_name} has shown interest but hasn't placed an order yet. Write a message sharing a success story, a limited-time offer, or highlighting what makes your product different from competitors. Keep under 120 words. Tone: confident and helpful.",
|
||||
"priority": 0,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class FollowupEngine:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.ai = get_ai_router()
|
||||
self.health_service = CustomerHealthService(db)
|
||||
|
||||
async def ensure_default_strategies(self):
|
||||
result = await self.db.execute(
|
||||
select(FollowupStrategy).limit(1)
|
||||
)
|
||||
if result.scalar_one_or_none():
|
||||
return
|
||||
for s in DEFAULT_STRATEGIES:
|
||||
strategy = FollowupStrategy(
|
||||
name=s["name"],
|
||||
description=s["description"],
|
||||
trigger_condition=s["trigger_condition"],
|
||||
channel=s["channel"],
|
||||
ai_prompt_template=s["ai_prompt_template"],
|
||||
priority=s["priority"],
|
||||
)
|
||||
self.db.add(strategy)
|
||||
await self.db.flush()
|
||||
logger.info(f"Created {len(DEFAULT_STRATEGIES)} default followup strategies")
|
||||
|
||||
async def get_strategies(self) -> List[Dict[str, Any]]:
|
||||
result = await self.db.execute(
|
||||
select(FollowupStrategy).order_by(FollowupStrategy.priority)
|
||||
)
|
||||
strategies = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": str(s.id),
|
||||
"name": s.name,
|
||||
"description": s.description,
|
||||
"trigger_condition": s.trigger_condition,
|
||||
"channel": s.channel,
|
||||
"priority": s.priority,
|
||||
"is_active": s.is_active,
|
||||
}
|
||||
for s in strategies
|
||||
]
|
||||
|
||||
async def evaluate_customer(self, user_id: str, customer: Customer) -> Optional[Dict[str, Any]]:
|
||||
health = await self.health_service.get_customer_health(user_id, str(customer.id))
|
||||
if not health:
|
||||
return None
|
||||
|
||||
silence_days = health["dimensions"]["silence"]["days"]
|
||||
health_score = health["total_score"]
|
||||
|
||||
strategies_result = await self.db.execute(
|
||||
select(FollowupStrategy).where(
|
||||
and_(
|
||||
FollowupStrategy.is_active == True,
|
||||
)
|
||||
).order_by(FollowupStrategy.priority)
|
||||
)
|
||||
strategies = strategies_result.scalars().all()
|
||||
|
||||
for strategy in strategies:
|
||||
cond = strategy.trigger_condition
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
if silence_days < cond.get("min_silence_days", 0):
|
||||
continue
|
||||
if silence_days > cond.get("max_silence_days", 999):
|
||||
continue
|
||||
if health_score < cond.get("min_health_score", 0):
|
||||
continue
|
||||
if health_score > cond.get("max_health_score", 100):
|
||||
continue
|
||||
if cond.get("status_filter") and customer.status not in cond["status_filter"]:
|
||||
continue
|
||||
|
||||
existing = await self.db.execute(
|
||||
select(FollowupLog).where(
|
||||
and_(
|
||||
FollowupLog.customer_id == customer.id,
|
||||
FollowupLog.strategy_id == strategy.id,
|
||||
FollowupLog.status.in_(["pending", "sent"]),
|
||||
FollowupLog.created_at > datetime.utcnow() - timedelta(days=7),
|
||||
)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
continue
|
||||
|
||||
return {
|
||||
"strategy": strategy,
|
||||
"silence_days": silence_days,
|
||||
"health_score": health_score,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
async def generate_followup_content(self, strategy: FollowupStrategy, customer: Customer, silence_days: int) -> str:
|
||||
try:
|
||||
prompt = strategy.ai_prompt_template.format(
|
||||
customer_name=customer.name,
|
||||
silence_days=silence_days,
|
||||
company=customer.company or "",
|
||||
)
|
||||
result = await self.ai.execute("marketing", "generate_marketing",
|
||||
{"name": customer.name, "description": prompt},
|
||||
customer.country or "US",
|
||||
"professional",
|
||||
"en"
|
||||
)
|
||||
return result.get("content", "")
|
||||
except Exception as e:
|
||||
logger.warning(f"AI content generation failed: {e}")
|
||||
return f"Hi {customer.name}, just checking in to see if you need any further information about our products. Looking forward to hearing from you!"
|
||||
|
||||
async def create_followup_log(self, user_id: str, customer: Customer,
|
||||
strategy: FollowupStrategy, silence_days: int,
|
||||
health_score: int, content: str) -> FollowupLog:
|
||||
log = FollowupLog(
|
||||
user_id=user_id,
|
||||
customer_id=customer.id,
|
||||
strategy_id=strategy.id,
|
||||
status="pending",
|
||||
channel=strategy.channel,
|
||||
ai_generated_content=content,
|
||||
content=content,
|
||||
health_score_at_time=health_score,
|
||||
silence_days_at_time=silence_days,
|
||||
)
|
||||
self.db.add(log)
|
||||
await self.db.flush()
|
||||
return log
|
||||
|
||||
async def scan_and_followup(self) -> Dict[str, Any]:
|
||||
await self.ensure_default_strategies()
|
||||
|
||||
customers_result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
Customer.status.in_(["lead", "negotiating"])
|
||||
)
|
||||
)
|
||||
customers = customers_result.scalars().all()
|
||||
|
||||
processed = 0
|
||||
notifications_sent = 0
|
||||
logs_created = 0
|
||||
|
||||
for customer in customers:
|
||||
try:
|
||||
result = await self.evaluate_customer(str(customer.user_id), customer)
|
||||
if not result:
|
||||
continue
|
||||
|
||||
content = await self.generate_followup_content(
|
||||
result["strategy"], customer, result["silence_days"]
|
||||
)
|
||||
log = await self.create_followup_log(
|
||||
str(customer.user_id), customer,
|
||||
result["strategy"], result["silence_days"],
|
||||
result["health_score"], content,
|
||||
)
|
||||
|
||||
title = f"跟进提醒: {customer.name}"
|
||||
notify_content = f"{result['strategy'].name} — {content[:80]}..."
|
||||
n = Notification(
|
||||
user_id=customer.user_id,
|
||||
title=title,
|
||||
content=notify_content,
|
||||
notification_type="followup",
|
||||
reference_type="customer",
|
||||
reference_id=str(customer.id),
|
||||
)
|
||||
self.db.add(n)
|
||||
|
||||
processed += 1
|
||||
logs_created += 1
|
||||
notifications_sent += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Followup scan failed for customer {customer.id}: {e}")
|
||||
continue
|
||||
|
||||
if processed > 0:
|
||||
await self.db.flush()
|
||||
logger.info(f"Followup scan: {processed} customers matched, {logs_created} logs, {notifications_sent} notifications")
|
||||
|
||||
return {
|
||||
"customers_scanned": len(customers),
|
||||
"followups_created": logs_created,
|
||||
"notifications_sent": notifications_sent,
|
||||
}
|
||||
|
||||
async def get_pending_followups(self, user_id: str, page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
query = select(FollowupLog).where(
|
||||
and_(
|
||||
FollowupLog.user_id == user_id,
|
||||
FollowupLog.status == "pending",
|
||||
)
|
||||
).order_by(FollowupLog.created_at.desc()).offset(
|
||||
(page - 1) * size
|
||||
).limit(size)
|
||||
|
||||
count_q = select(FollowupLog).where(
|
||||
and_(
|
||||
FollowupLog.user_id == user_id,
|
||||
FollowupLog.status == "pending",
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
logs = result.scalars().all()
|
||||
|
||||
count_result = await self.db.execute(count_q)
|
||||
total = len(count_result.scalars().all())
|
||||
|
||||
items = []
|
||||
for log in logs:
|
||||
customer_result = await self.db.execute(
|
||||
select(Customer).where(Customer.id == log.customer_id)
|
||||
)
|
||||
customer = customer_result.scalar_one_or_none()
|
||||
items.append({
|
||||
"id": str(log.id),
|
||||
"customer_id": str(log.customer_id),
|
||||
"customer_name": customer.name if customer else "Unknown",
|
||||
"strategy": "跟进",
|
||||
"channel": log.channel,
|
||||
"content": log.content,
|
||||
"ai_generated_content": log.ai_generated_content,
|
||||
"health_score": log.health_score_at_time,
|
||||
"silence_days": log.silence_days_at_time,
|
||||
"status": log.status,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
})
|
||||
|
||||
return {"items": items, "total": total, "page": page, "size": size}
|
||||
|
||||
async def get_followup_logs(self, user_id: str, page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
query = select(FollowupLog).where(
|
||||
FollowupLog.user_id == user_id
|
||||
).order_by(FollowupLog.created_at.desc()).offset(
|
||||
(page - 1) * size
|
||||
).limit(size)
|
||||
|
||||
count_q = select(FollowupLog).where(FollowupLog.user_id == user_id)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
logs = result.scalars().all()
|
||||
|
||||
count_result = await self.db.execute(count_q)
|
||||
total = len(count_result.scalars().all())
|
||||
|
||||
items = []
|
||||
for log in logs:
|
||||
customer_result = await self.db.execute(
|
||||
select(Customer).where(Customer.id == log.customer_id)
|
||||
)
|
||||
customer = customer_result.scalar_one_or_none()
|
||||
items.append({
|
||||
"id": str(log.id),
|
||||
"customer_id": str(log.customer_id),
|
||||
"customer_name": customer.name if customer else "Unknown",
|
||||
"channel": log.channel,
|
||||
"content": log.content,
|
||||
"ai_generated_content": log.ai_generated_content,
|
||||
"user_edited_content": log.user_edited_content,
|
||||
"status": log.status,
|
||||
"health_score": log.health_score_at_time,
|
||||
"silence_days": log.silence_days_at_time,
|
||||
"sent_at": log.sent_at.isoformat() if log.sent_at else None,
|
||||
"replied_at": log.replied_at.isoformat() if log.replied_at else None,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
})
|
||||
|
||||
return {"items": items, "total": total, "page": page, "size": size}
|
||||
|
||||
async def mark_sent(self, user_id: str, log_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(FollowupLog).where(
|
||||
and_(FollowupLog.id == log_id, FollowupLog.user_id == user_id)
|
||||
)
|
||||
)
|
||||
log = result.scalar_one_or_none()
|
||||
if not log:
|
||||
return False
|
||||
log.status = "sent"
|
||||
log.sent_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def mark_edited(self, user_id: str, log_id: str, edited_text: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(FollowupLog).where(
|
||||
and_(FollowupLog.id == log_id, FollowupLog.user_id == user_id)
|
||||
)
|
||||
)
|
||||
log = result.scalar_one_or_none()
|
||||
if not log:
|
||||
return False
|
||||
log.user_edited_content = edited_text
|
||||
log.content = edited_text
|
||||
log.status = "sent"
|
||||
log.sent_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def get_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
logs_result = await self.db.execute(
|
||||
select(FollowupLog).where(FollowupLog.user_id == user_id)
|
||||
)
|
||||
all_logs = logs_result.scalars().all()
|
||||
total = len(all_logs)
|
||||
pending = sum(1 for l in all_logs if l.status == "pending")
|
||||
sent = sum(1 for l in all_logs if l.status == "sent")
|
||||
replied = sum(1 for l in all_logs if l.status == "replied")
|
||||
|
||||
return {
|
||||
"total_followups": total,
|
||||
"pending": pending,
|
||||
"sent": sent,
|
||||
"replied": replied,
|
||||
"completion_rate": round(sent / total * 100, 1) if total > 0 else 0,
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import openpyxl
|
||||
HAS_OPENPYXL = True
|
||||
except ImportError:
|
||||
HAS_OPENPYXL = False
|
||||
logger.warning("openpyxl not installed, XLSX import disabled")
|
||||
|
||||
|
||||
REQUIRED_COLUMNS = {"name"}
|
||||
OPTIONAL_COLUMNS = {
|
||||
"company", "country", "phone", "email", "whatsapp_id",
|
||||
"source", "tags", "notes", "status", "estimated_value",
|
||||
}
|
||||
|
||||
|
||||
class ImportService:
|
||||
@staticmethod
|
||||
def parse_xlsx(file_bytes: bytes) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
if not HAS_OPENPYXL:
|
||||
return [], ["openpyxl not installed"]
|
||||
|
||||
try:
|
||||
wb = openpyxl.load_workbook(io.BytesIO(file_bytes), read_only=True)
|
||||
ws = wb.active
|
||||
rows = list(ws.iter_rows(values_only=True))
|
||||
if not rows:
|
||||
return [], ["Empty file"]
|
||||
|
||||
headers = [str(h).strip().lower() if h else "" for h in rows[0]]
|
||||
missing = REQUIRED_COLUMNS - set(headers)
|
||||
if missing:
|
||||
return [], [f"Missing required columns: {', '.join(missing)}"]
|
||||
|
||||
records = []
|
||||
errors = []
|
||||
for i, row in enumerate(rows[1:], 2):
|
||||
if all(v is None or str(v).strip() == "" for v in row):
|
||||
continue
|
||||
record = {}
|
||||
for j, val in enumerate(row):
|
||||
if j < len(headers) and headers[j]:
|
||||
record[headers[j]] = str(val).strip() if val is not None else ""
|
||||
if not record.get("name"):
|
||||
errors.append(f"Row {i}: missing name")
|
||||
continue
|
||||
records.append(record)
|
||||
|
||||
return records, errors
|
||||
|
||||
except Exception as e:
|
||||
return [], [f"Parse error: {str(e)}"]
|
||||
|
||||
@staticmethod
|
||||
def parse_csv(file_bytes: bytes) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
try:
|
||||
text = file_bytes.decode("utf-8-sig")
|
||||
reader = csv.DictReader(io.StringIO(text))
|
||||
if not reader.fieldnames:
|
||||
return [], ["Empty or invalid CSV"]
|
||||
|
||||
headers = [h.strip().lower() for h in reader.fieldnames]
|
||||
missing = REQUIRED_COLUMNS - set(headers)
|
||||
if missing:
|
||||
return [], [f"Missing required columns: {', '.join(missing)}"]
|
||||
|
||||
records = []
|
||||
errors = []
|
||||
for i, row in enumerate(reader, 2):
|
||||
cleaned = {}
|
||||
for k, v in row.items():
|
||||
key = k.strip().lower()
|
||||
if key:
|
||||
cleaned[key] = v.strip() if v else ""
|
||||
if not cleaned.get("name"):
|
||||
errors.append(f"Row {i}: missing name")
|
||||
continue
|
||||
cleaned = {k: v for k, v in cleaned.items() if k in REQUIRED_COLUMNS | OPTIONAL_COLUMNS}
|
||||
records.append(cleaned)
|
||||
|
||||
return records, errors
|
||||
|
||||
except Exception as e:
|
||||
return [], [f"Parse error: {str(e)}"]
|
||||
|
||||
@staticmethod
|
||||
def validate_records(records: List[Dict]) -> Tuple[List[Dict], List[str]]:
|
||||
valid = []
|
||||
errors = []
|
||||
for i, r in enumerate(records, 1):
|
||||
if r.get("status") and r["status"] not in ("lead", "negotiating", "customer", "lost", "archived"):
|
||||
errors.append(f"Row {i}: invalid status '{r['status']}'")
|
||||
continue
|
||||
if r.get("phone") and not r["phone"].strip():
|
||||
r.pop("phone", None)
|
||||
r.setdefault("status", "lead")
|
||||
r.setdefault("source", "import")
|
||||
r.setdefault("tags", [])
|
||||
if isinstance(r.get("tags"), str):
|
||||
r["tags"] = [t.strip() for t in r["tags"].split(",") if t.strip()]
|
||||
valid.append(r)
|
||||
return valid, errors
|
||||
|
||||
|
||||
import_service = ImportService()
|
||||
@@ -16,13 +16,14 @@ class MarketingService:
|
||||
style: str = "professional",
|
||||
language: str = "en",
|
||||
count: int = 3,
|
||||
preference_context: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
results = []
|
||||
styles = self._get_style_variants(style, count)
|
||||
|
||||
for s in styles:
|
||||
try:
|
||||
result = await self.ai.marketing(product_info, target, s, language)
|
||||
result = await self.ai.marketing(product_info, target, s, language, preference_context)
|
||||
results.append({
|
||||
"content": result.get("content", ""),
|
||||
"style": s,
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from app.models.preference import MarketingEffect
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketingEffectService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def track_event(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
product_id: Optional[str] = None,
|
||||
product_name: Optional[str] = None,
|
||||
channel: str = "copy",
|
||||
event_type: str = "copy",
|
||||
target_audience: str = "",
|
||||
metadata: Optional[Dict] = None,
|
||||
) -> Dict[str, Any]:
|
||||
content_hash = hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
event = MarketingEffect(
|
||||
user_id=user_id,
|
||||
content_hash=content_hash,
|
||||
product_id=product_id,
|
||||
product_name=product_name,
|
||||
channel=channel,
|
||||
event_type=event_type,
|
||||
target_audience=target_audience,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.db.add(event)
|
||||
await self.db.flush()
|
||||
|
||||
return {
|
||||
"id": str(event.id),
|
||||
"event_type": event_type,
|
||||
"content_hash": content_hash,
|
||||
}
|
||||
|
||||
async def get_effects(
|
||||
self, user_id: str, page: int = 1, size: int = 20
|
||||
) -> Dict[str, Any]:
|
||||
query = (
|
||||
select(MarketingEffect)
|
||||
.where(MarketingEffect.user_id == user_id)
|
||||
.order_by(MarketingEffect.created_at.desc())
|
||||
.offset((page - 1) * size)
|
||||
.limit(size)
|
||||
)
|
||||
count_query = select(func.count(MarketingEffect.id)).where(
|
||||
MarketingEffect.user_id == user_id
|
||||
)
|
||||
|
||||
total = await self.db.execute(count_query)
|
||||
result = await self.db.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": str(e.id),
|
||||
"product_name": e.product_name,
|
||||
"channel": e.channel,
|
||||
"event_type": e.event_type,
|
||||
"target_audience": e.target_audience,
|
||||
"created_at": e.created_at.isoformat() if e.created_at else None,
|
||||
}
|
||||
for e in events
|
||||
],
|
||||
"total": total.scalar() or 0,
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
async def get_stats(self, user_id: str) -> Dict[str, Any]:
|
||||
today = datetime.utcnow().date()
|
||||
week_ago = today - timedelta(days=7)
|
||||
|
||||
total_query = select(func.count(MarketingEffect.id)).where(
|
||||
MarketingEffect.user_id == user_id
|
||||
)
|
||||
today_query = select(func.count(MarketingEffect.id)).where(
|
||||
and_(
|
||||
MarketingEffect.user_id == user_id,
|
||||
func.date(MarketingEffect.created_at) == today,
|
||||
)
|
||||
)
|
||||
week_query = select(func.count(MarketingEffect.id)).where(
|
||||
and_(
|
||||
MarketingEffect.user_id == user_id,
|
||||
func.date(MarketingEffect.created_at) >= week_ago,
|
||||
)
|
||||
)
|
||||
copy_query = select(func.count(MarketingEffect.id)).where(
|
||||
and_(
|
||||
MarketingEffect.user_id == user_id,
|
||||
MarketingEffect.event_type == "copy",
|
||||
)
|
||||
)
|
||||
send_query = select(func.count(MarketingEffect.id)).where(
|
||||
and_(
|
||||
MarketingEffect.user_id == user_id,
|
||||
MarketingEffect.event_type == "send",
|
||||
)
|
||||
)
|
||||
|
||||
totals = await self.db.execute(total_query)
|
||||
todays = await self.db.execute(today_query)
|
||||
weeks = await self.db.execute(week_query)
|
||||
copies = await self.db.execute(copy_query)
|
||||
sends = await self.db.execute(send_query)
|
||||
|
||||
return {
|
||||
"total_events": totals.scalar() or 0,
|
||||
"today": todays.scalar() or 0,
|
||||
"this_week": weeks.scalar() or 0,
|
||||
"copy_count": copies.scalar() or 0,
|
||||
"send_count": sends.scalar() or 0,
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from app.models.notification import Notification
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class NotificationService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def list_notifications(
|
||||
self, user_id: str, page: int = 1, size: int = 20, unread_only: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
query = select(Notification).where(Notification.user_id == user_id)
|
||||
if unread_only:
|
||||
query = query.where(Notification.is_read == False)
|
||||
query = query.order_by(Notification.created_at.desc()).offset(
|
||||
(page - 1) * size
|
||||
).limit(size)
|
||||
|
||||
count_query = select(func.count(Notification.id)).where(
|
||||
Notification.user_id == user_id
|
||||
)
|
||||
if unread_only:
|
||||
count_query = count_query.where(Notification.is_read == False)
|
||||
|
||||
total = await self.db.execute(count_query)
|
||||
result = await self.db.execute(query)
|
||||
notifications = result.scalars().all()
|
||||
|
||||
return {
|
||||
"items": [
|
||||
{
|
||||
"id": str(n.id),
|
||||
"title": n.title,
|
||||
"content": n.content,
|
||||
"type": n.notification_type,
|
||||
"reference_type": n.reference_type,
|
||||
"reference_id": n.reference_id,
|
||||
"is_read": n.is_read,
|
||||
"created_at": n.created_at.isoformat() if n.created_at else None,
|
||||
}
|
||||
for n in notifications
|
||||
],
|
||||
"total": total.scalar() or 0,
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
async def get_unread_count(self, user_id: str) -> int:
|
||||
result = await self.db.execute(
|
||||
select(func.count(Notification.id)).where(
|
||||
and_(Notification.user_id == user_id, Notification.is_read == False)
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def mark_read(self, user_id: str, notification_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Notification).where(
|
||||
and_(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
n = result.scalar_one_or_none()
|
||||
if not n:
|
||||
return False
|
||||
n.is_read = True
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def mark_all_read(self, user_id: str) -> int:
|
||||
result = await self.db.execute(
|
||||
select(Notification).where(
|
||||
and_(Notification.user_id == user_id, Notification.is_read == False)
|
||||
)
|
||||
)
|
||||
notifications = result.scalars().all()
|
||||
for n in notifications:
|
||||
n.is_read = True
|
||||
await self.db.flush()
|
||||
return len(notifications)
|
||||
|
||||
async def delete_notification(self, user_id: str, notification_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Notification).where(
|
||||
and_(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
n = result.scalar_one_or_none()
|
||||
if not n:
|
||||
return False
|
||||
await self.db.delete(n)
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def create_notification(
|
||||
db: AsyncSession, user_id: str, title: str, content: str,
|
||||
notification_type: str = "system",
|
||||
reference_type: str = None, reference_id: str = None,
|
||||
):
|
||||
n = Notification(
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
content=content,
|
||||
notification_type=notification_type,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
)
|
||||
db.add(n)
|
||||
await db.flush()
|
||||
return n
|
||||
@@ -0,0 +1,74 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from app.models.user import User, Product
|
||||
from app.services.marketing import MarketingService
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OnboardingService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def check_status(self, user_id: str) -> Dict[str, Any]:
|
||||
product_count = await self.db.execute(
|
||||
select(func.count(Product.id)).where(
|
||||
Product.user_id == user_id, Product.is_active == True
|
||||
)
|
||||
)
|
||||
has_products = (product_count.scalar() or 0) > 0
|
||||
return {"onboarded": has_products}
|
||||
|
||||
async def generate_first_product(
|
||||
self, user_id: str, name: str, description: str, category: str = "", target: str = "US importers"
|
||||
) -> Dict[str, Any]:
|
||||
product = Product(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
category=category or "general",
|
||||
is_active=True,
|
||||
)
|
||||
self.db.add(product)
|
||||
await self.db.flush()
|
||||
|
||||
mkt = MarketingService()
|
||||
try:
|
||||
content = await mkt.generate(
|
||||
product_name=name,
|
||||
description=description,
|
||||
category=category or "general",
|
||||
target=target,
|
||||
style="professional",
|
||||
count=3,
|
||||
language="en",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Onboarding content generation failed: {e}")
|
||||
content = [f"Check out our {name} - {description[:100]}..."]
|
||||
|
||||
try:
|
||||
keywords_result = await mkt.generate_keywords(
|
||||
product_name=name, description=description, category=category or "general"
|
||||
)
|
||||
keywords = keywords_result if isinstance(keywords_result, list) else []
|
||||
except Exception as e:
|
||||
logger.warning(f"Keyword generation failed: {e}")
|
||||
keywords = []
|
||||
|
||||
product.keywords = keywords[:10]
|
||||
await self.db.flush()
|
||||
|
||||
return {
|
||||
"product": {
|
||||
"id": str(product.id),
|
||||
"name": product.name,
|
||||
"description": product.description,
|
||||
"category": product.category,
|
||||
"keywords": keywords[:10],
|
||||
},
|
||||
"generated_content": content,
|
||||
"keywords": keywords[:10],
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
import hmac
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.subscription import Subscription
|
||||
from app.models.user import User
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PLANS = {
|
||||
"free": {"price": 0, "duration_days": None},
|
||||
"pro": {"price": 99, "duration_days": 30},
|
||||
"enterprise": {"price": 399, "duration_days": 30},
|
||||
}
|
||||
|
||||
|
||||
class PaymentService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_plans(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"plans": [
|
||||
{
|
||||
"id": "free",
|
||||
"name": "免费版",
|
||||
"price": 0,
|
||||
"features": [
|
||||
"1 个产品",
|
||||
"20 次翻译/天",
|
||||
"5 个客户",
|
||||
"基础回复建议",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pro",
|
||||
"name": "Pro 版",
|
||||
"price": 99,
|
||||
"features": [
|
||||
"10 个产品",
|
||||
"无限翻译",
|
||||
"50 个客户",
|
||||
"跟进提醒",
|
||||
"报价单生成",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "enterprise",
|
||||
"name": "企业版",
|
||||
"price": 399,
|
||||
"features": [
|
||||
"无限产品",
|
||||
"多人协作",
|
||||
"品牌报价单",
|
||||
"专属语料训练",
|
||||
"API 接入",
|
||||
],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
async def get_current_subscription(self, user_id: str) -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == user_id,
|
||||
Subscription.status == "active",
|
||||
).order_by(Subscription.created_at.desc()).limit(1)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.id == user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"plan": user.tier if user else "free",
|
||||
"status": sub.status if sub else "active",
|
||||
"expires_at": sub.expires_at.isoformat() if sub and sub.expires_at else None,
|
||||
"auto_renew": sub.auto_renew if sub else False,
|
||||
}
|
||||
|
||||
async def create_order(self, user_id: str, plan: str) -> Dict[str, Any]:
|
||||
if plan not in PLANS:
|
||||
raise ValueError(f"Invalid plan: {plan}")
|
||||
|
||||
plan_info = PLANS[plan]
|
||||
if plan_info["price"] == 0:
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user:
|
||||
user.tier = plan
|
||||
await self.db.flush()
|
||||
return {"status": "ok", "plan": plan, "amount": 0}
|
||||
|
||||
from app.config import settings
|
||||
order_id = f"ORD{datetime.utcnow().strftime('%Y%m%d%H%M%S')}{user_id[-6:]}"
|
||||
|
||||
sub = Subscription(
|
||||
user_id=user_id,
|
||||
plan=plan,
|
||||
status="pending",
|
||||
amount=plan_info["price"],
|
||||
payment_id=order_id,
|
||||
)
|
||||
self.db.add(sub)
|
||||
await self.db.flush()
|
||||
|
||||
pay_params = {
|
||||
"appId": settings.WECHAT_APP_ID or "",
|
||||
"timeStamp": str(int(datetime.utcnow().timestamp())),
|
||||
"nonceStr": hashlib.md5(order_id.encode()).hexdigest()[:16],
|
||||
"package": f"prepay_id={order_id}",
|
||||
"signType": "MD5",
|
||||
}
|
||||
sign_str = "&".join(f"{k}={v}" for k, v in sorted(pay_params.items()))
|
||||
sign_str += f"&key={settings.SECRET_KEY}"
|
||||
pay_params["paySign"] = hashlib.md5(sign_str.encode()).hexdigest().upper()
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"order_id": order_id,
|
||||
"plan": plan,
|
||||
"amount": plan_info["price"],
|
||||
"currency": "CNY",
|
||||
"pay_params": pay_params,
|
||||
}
|
||||
|
||||
async def handle_payment_callback(self, payment_id: str, success: bool) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Subscription).where(Subscription.payment_id == payment_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if not sub:
|
||||
return False
|
||||
|
||||
if success:
|
||||
sub.status = "active"
|
||||
sub.started_at = datetime.utcnow()
|
||||
sub.expires_at = datetime.utcnow() + timedelta(days=PLANS[sub.plan]["duration_days"])
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == sub.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user:
|
||||
user.tier = sub.plan
|
||||
else:
|
||||
sub.status = "failed"
|
||||
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
|
||||
payment_service = PaymentService
|
||||
@@ -0,0 +1,229 @@
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from weasyprint import HTML
|
||||
HAS_WEASYPRINT = True
|
||||
except ImportError:
|
||||
HAS_WEASYPRINT = False
|
||||
logger.warning("weasyprint not installed, PDF generation disabled")
|
||||
|
||||
|
||||
QUOTATION_TEMPLATE = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<style>
|
||||
@page {{
|
||||
size: A4;
|
||||
margin: 2cm;
|
||||
}}
|
||||
body {{
|
||||
font-family: 'Helvetica Neue', Arial, sans-serif;
|
||||
font-size: 12pt;
|
||||
color: #333;
|
||||
line-height: 1.6;
|
||||
}}
|
||||
.header {{
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
border-bottom: 2px solid #1890ff;
|
||||
padding-bottom: 20px;
|
||||
}}
|
||||
.header h1 {{
|
||||
font-size: 24pt;
|
||||
color: #1890ff;
|
||||
margin: 0;
|
||||
}}
|
||||
.header .number {{
|
||||
font-size: 14pt;
|
||||
color: #666;
|
||||
}}
|
||||
.info-grid {{
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 30px;
|
||||
}}
|
||||
.info-block {{
|
||||
width: 48%;
|
||||
}}
|
||||
.info-block h3 {{
|
||||
font-size: 11pt;
|
||||
color: #1890ff;
|
||||
margin-bottom: 8px;
|
||||
border-bottom: 1px solid #e8e8e8;
|
||||
padding-bottom: 4px;
|
||||
}}
|
||||
.info-block p {{
|
||||
margin: 4px 0;
|
||||
font-size: 10pt;
|
||||
}}
|
||||
table {{
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
margin-bottom: 30px;
|
||||
}}
|
||||
th {{
|
||||
background: #1890ff;
|
||||
color: white;
|
||||
padding: 10px 8px;
|
||||
text-align: left;
|
||||
font-size: 10pt;
|
||||
}}
|
||||
td {{
|
||||
padding: 8px;
|
||||
border-bottom: 1px solid #e8e8e8;
|
||||
font-size: 10pt;
|
||||
}}
|
||||
.amount-row td {{
|
||||
text-align: right;
|
||||
padding: 4px 8px;
|
||||
border: none;
|
||||
}}
|
||||
.total-row td {{
|
||||
font-weight: bold;
|
||||
font-size: 12pt;
|
||||
border-top: 2px solid #333;
|
||||
}}
|
||||
.terms {{
|
||||
margin-top: 30px;
|
||||
padding-top: 15px;
|
||||
border-top: 1px solid #e8e8e8;
|
||||
}}
|
||||
.terms h3 {{
|
||||
font-size: 11pt;
|
||||
color: #1890ff;
|
||||
}}
|
||||
.terms p {{
|
||||
font-size: 9pt;
|
||||
color: #666;
|
||||
margin: 4px 0;
|
||||
}}
|
||||
.footer {{
|
||||
text-align: center;
|
||||
margin-top: 40px;
|
||||
font-size: 9pt;
|
||||
color: #999;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>QUOTATION</h1>
|
||||
<p class="number">#{quotation_number}</p>
|
||||
</div>
|
||||
|
||||
<div class="info-grid">
|
||||
<div class="info-block">
|
||||
<h3>Bill To</h3>
|
||||
<p>{customer_name}</p>
|
||||
<p>{customer_company}</p>
|
||||
<p>{customer_country}</p>
|
||||
</div>
|
||||
<div class="info-block">
|
||||
<h3>Quote Details</h3>
|
||||
<p>Date: {date}</p>
|
||||
<p>Valid Until: {valid_until}</p>
|
||||
<p>Currency: {currency}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Item</th>
|
||||
<th>Description</th>
|
||||
<th>Qty</th>
|
||||
<th>Unit</th>
|
||||
<th>Unit Price</th>
|
||||
<th>Total</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{items_rows}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<table>
|
||||
<tr class="amount-row"><td colspan="4"></td><td>Subtotal:</td><td>{subtotal}</td></tr>
|
||||
<tr class="amount-row"><td colspan="4"></td><td>Discount:</td><td>-{discount}</td></tr>
|
||||
<tr class="amount-row"><td colspan="4"></td><td>Shipping:</td><td>{shipping}</td></tr>
|
||||
<tr class="total-row"><td colspan="4"></td><td>TOTAL:</td><td>{total}</td></tr>
|
||||
</table>
|
||||
|
||||
<div class="terms">
|
||||
<h3>Terms & Conditions</h3>
|
||||
<p>Payment Terms: {payment_terms}</p>
|
||||
<p>Delivery Terms: {delivery_terms}</p>
|
||||
<p>Lead Time: {lead_time}</p>
|
||||
{notes_html}
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
<p>Generated by TradeMate - {generated_at}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class PDFGenerator:
|
||||
@staticmethod
|
||||
def generate_quotation(data: Dict[str, Any]) -> Optional[bytes]:
|
||||
if not HAS_WEASYPRINT:
|
||||
return None
|
||||
|
||||
items = data.get("items", [])
|
||||
items_rows = ""
|
||||
for i, item in enumerate(items, 1):
|
||||
items_rows += (
|
||||
f"<tr>"
|
||||
f"<td>{item.get('product_name', '')}</td>"
|
||||
f"<td>{item.get('description', '') or ''}</td>"
|
||||
f"<td>{item.get('quantity', 0)}</td>"
|
||||
f"<td>{item.get('unit', 'pcs')}</td>"
|
||||
f"<td>{item.get('unit_price', 0):.2f}</td>"
|
||||
f"<td>{item.get('total_price', 0):.2f}</td>"
|
||||
f"</tr>"
|
||||
)
|
||||
|
||||
cur = data.get("currency", "USD")
|
||||
subtotal = f"{cur} {data.get('subtotal', 0):.2f}"
|
||||
discount = f"{cur} {data.get('discount', 0):.2f}" if data.get("discount") else f"{cur} 0.00"
|
||||
shipping = f"{cur} {data.get('shipping', 0):.2f}" if data.get("shipping") else f"{cur} 0.00"
|
||||
total = f"{cur} {data.get('total', 0):.2f}"
|
||||
|
||||
notes_html = ""
|
||||
if data.get("notes"):
|
||||
notes_html = f"<p>Notes: {data['notes']}</p>"
|
||||
|
||||
html = QUOTATION_TEMPLATE.format(
|
||||
quotation_number=data.get("quotation_number", "N/A"),
|
||||
customer_name=data.get("customer_name", ""),
|
||||
customer_company=data.get("customer_company", "") or "",
|
||||
customer_country=data.get("customer_country", "") or "",
|
||||
date=data.get("date", datetime.utcnow().strftime("%Y-%m-%d")),
|
||||
valid_until=data.get("valid_until", "N/A"),
|
||||
currency=cur,
|
||||
items_rows=items_rows,
|
||||
subtotal=subtotal,
|
||||
discount=discount,
|
||||
shipping=shipping,
|
||||
total=total,
|
||||
payment_terms=data.get("payment_terms", "N/A"),
|
||||
delivery_terms=data.get("delivery_terms", "N/A"),
|
||||
lead_time=data.get("lead_time", "N/A"),
|
||||
notes_html=notes_html,
|
||||
generated_at=datetime.utcnow().strftime("%Y-%m-%d %H:%M UTC"),
|
||||
)
|
||||
|
||||
pdf = HTML(string=html).write_pdf()
|
||||
return pdf
|
||||
|
||||
|
||||
pdf_generator = PDFGenerator()
|
||||
@@ -0,0 +1,217 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, desc
|
||||
from app.models.customer import Message, Conversation
|
||||
from app.models.user import User
|
||||
from app.models.preference import PreferenceAnalysis
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserPreferenceService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def record_selection(self, user_id: str, message_id: str, selected_index: int) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Message).where(Message.id == message_id)
|
||||
)
|
||||
msg = result.scalar_one_or_none()
|
||||
if not msg:
|
||||
return False
|
||||
msg.selected_suggestion = selected_index
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def record_edit(self, user_id: str, message_id: str, edited_text: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Message).where(Message.id == message_id)
|
||||
)
|
||||
msg = result.scalar_one_or_none()
|
||||
if not msg:
|
||||
return False
|
||||
msg.user_edited = edited_text
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def analyze_preferences(self, user_id: str) -> Dict[str, Any]:
|
||||
user_conv_subq = select(Conversation.id).where(
|
||||
Conversation.user_id == user_id
|
||||
).subquery()
|
||||
|
||||
count_result = await self.db.execute(
|
||||
select(func.count(Message.id)).where(
|
||||
and_(
|
||||
Message.conversation_id.in_(select(user_conv_subq)),
|
||||
Message.selected_suggestion.isnot(None),
|
||||
)
|
||||
)
|
||||
)
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
if total < 3:
|
||||
return {"needs_more_data": True, "interaction_count": total}
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Message)
|
||||
.where(
|
||||
and_(
|
||||
Message.conversation_id.in_(select(user_conv_subq)),
|
||||
Message.selected_suggestion.isnot(None),
|
||||
)
|
||||
)
|
||||
.order_by(desc(Message.created_at))
|
||||
.limit(100)
|
||||
)
|
||||
messages = result.scalars().all()
|
||||
|
||||
tone_counts = {}
|
||||
edit_count = 0
|
||||
total_chars_saved = 0
|
||||
greeting_patterns = []
|
||||
signoff_patterns = []
|
||||
|
||||
for m in messages:
|
||||
suggestions = m.ai_suggestions or []
|
||||
selected = m.selected_suggestion
|
||||
if suggestions and selected is not None and selected < len(suggestions):
|
||||
tone = suggestions[selected].get("tone", "unknown")
|
||||
tone_counts[tone] = tone_counts.get(tone, 0) + 1
|
||||
|
||||
if m.user_edited:
|
||||
edit_count += 1
|
||||
if suggestions and selected is not None and selected < len(suggestions):
|
||||
original = suggestions[selected].get("reply", "")
|
||||
total_chars_saved += abs(len(original) - len(m.user_edited))
|
||||
|
||||
preferred_tone = max(tone_counts, key=tone_counts.get) if tone_counts else "professional"
|
||||
edit_ratio = edit_count / len(messages) if messages else 0
|
||||
avg_edit_size = total_chars_saved / edit_count if edit_count > 0 else 0
|
||||
|
||||
greeting_style = self._extract_greeting_style(messages)
|
||||
sign_off_style = self._extract_sign_off_style(messages)
|
||||
|
||||
preferences = {
|
||||
"preferred_tone": preferred_tone,
|
||||
"edit_ratio": edit_ratio,
|
||||
"avg_edit_size": avg_edit_size,
|
||||
"greeting_style": greeting_style,
|
||||
"sign_off_style": sign_off_style,
|
||||
"tone_distribution": tone_counts,
|
||||
"interaction_count": len(messages),
|
||||
"confidence": min(1.0, len(messages) / 20),
|
||||
}
|
||||
|
||||
existing = await self.db.execute(
|
||||
select(PreferenceAnalysis).where(PreferenceAnalysis.user_id == user_id)
|
||||
)
|
||||
analysis = existing.scalar_one_or_none()
|
||||
|
||||
if analysis:
|
||||
analysis.preferred_tone = preferred_tone
|
||||
analysis.greeting_style = greeting_style
|
||||
analysis.sign_off_style = sign_off_style
|
||||
analysis.analysis_data = preferences
|
||||
analysis.interaction_count = len(messages)
|
||||
analysis.confidence = preferences["confidence"]
|
||||
analysis.last_analysis_at = datetime.utcnow()
|
||||
else:
|
||||
analysis = PreferenceAnalysis(
|
||||
user_id=user_id,
|
||||
task_type="reply",
|
||||
preferred_tone=preferred_tone,
|
||||
greeting_style=greeting_style,
|
||||
sign_off_style=sign_off_style,
|
||||
analysis_data=preferences,
|
||||
interaction_count=len(messages),
|
||||
confidence=preferences["confidence"],
|
||||
last_analysis_at=datetime.utcnow(),
|
||||
)
|
||||
self.db.add(analysis)
|
||||
|
||||
await self.db.flush()
|
||||
|
||||
await self._update_user_settings(user_id, preferences)
|
||||
return preferences
|
||||
|
||||
async def get_preference_context(self, user_id: str, task_type: str = "reply") -> Optional[str]:
|
||||
result = await self.db.execute(
|
||||
select(PreferenceAnalysis).where(
|
||||
and_(
|
||||
PreferenceAnalysis.user_id == user_id,
|
||||
PreferenceAnalysis.task_type == task_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
analysis = result.scalar_one_or_none()
|
||||
if not analysis or analysis.confidence < 0.3:
|
||||
return None
|
||||
|
||||
parts = []
|
||||
if analysis.preferred_tone:
|
||||
parts.append(f"user's preferred tone: {analysis.preferred_tone}")
|
||||
if analysis.greeting_style:
|
||||
parts.append(f"user's typical greeting: {analysis.greeting_style}")
|
||||
if analysis.sign_off_style:
|
||||
parts.append(f"user's typical sign-off: {analysis.sign_off_style}")
|
||||
|
||||
if parts:
|
||||
return "This user prefers: " + "; ".join(parts) + "."
|
||||
return None
|
||||
|
||||
async def get_analysis(self, user_id: str) -> Dict[str, Any]:
|
||||
result = await self.db.execute(
|
||||
select(PreferenceAnalysis).where(PreferenceAnalysis.user_id == user_id)
|
||||
)
|
||||
analysis = result.scalar_one_or_none()
|
||||
if not analysis:
|
||||
return {"analyzed": False, "interaction_count": 0, "confidence": 0}
|
||||
|
||||
return {
|
||||
"analyzed": True,
|
||||
"preferred_tone": analysis.preferred_tone,
|
||||
"greeting_style": analysis.greeting_style,
|
||||
"sign_off_style": analysis.sign_off_style,
|
||||
"interaction_count": analysis.interaction_count,
|
||||
"confidence": analysis.confidence,
|
||||
"last_analysis_at": analysis.last_analysis_at.isoformat() if analysis.last_analysis_at else None,
|
||||
}
|
||||
|
||||
async def _update_user_settings(self, user_id: str, preferences: Dict[str, Any]):
|
||||
result = await self.db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user:
|
||||
settings = dict(user.settings or {})
|
||||
settings["preferred_tone"] = preferences.get("preferred_tone", settings.get("reply_tone", "professional"))
|
||||
settings["ai_learning"] = {
|
||||
"analyzed": True,
|
||||
"confidence": preferences.get("confidence", 0),
|
||||
"edit_ratio": preferences.get("edit_ratio", 0),
|
||||
"greeting_style": preferences.get("greeting_style", ""),
|
||||
"sign_off_style": preferences.get("sign_off_style", ""),
|
||||
}
|
||||
user.settings = settings
|
||||
await self.db.flush()
|
||||
|
||||
def _extract_greeting_style(self, messages: List[Message]) -> str:
|
||||
for m in messages:
|
||||
text = m.user_edited or (m.ai_suggestions[m.selected_suggestion].get("reply", "") if m.selected_suggestion is not None and m.ai_suggestions and m.selected_suggestion < len(m.ai_suggestions) else "")
|
||||
if text:
|
||||
first_word = text.strip().split()[0] if text.strip() else ""
|
||||
if first_word in ["Dear", "Hi", "Hello", "Hey", "To"]:
|
||||
return first_word
|
||||
return ""
|
||||
|
||||
def _extract_sign_off_style(self, messages: List[Message]) -> str:
|
||||
for m in messages:
|
||||
text = m.user_edited or (m.ai_suggestions[m.selected_suggestion].get("reply", "") if m.selected_suggestion is not None and m.ai_suggestions and m.selected_suggestion < len(m.ai_suggestions) else "")
|
||||
if text:
|
||||
words = text.strip().split()
|
||||
if len(words) >= 3:
|
||||
last_three = " ".join(words[-3:])
|
||||
for signoff in ["Best regards", "Best wishes", "Sincerely", "Cheers", "Regards", "Yours"]:
|
||||
if signoff in last_three:
|
||||
return signoff
|
||||
return ""
|
||||
@@ -0,0 +1,156 @@
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from app.config import settings
|
||||
from app.models.device import Device
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PushService:
|
||||
def __init__(self, db: Optional[AsyncSession] = None):
|
||||
self.db = db
|
||||
|
||||
@staticmethod
|
||||
def send_notification(user_id: str, title: str, content: str, payload: Optional[Dict[str, Any]] = None) -> bool:
|
||||
logger.info(f"[PUSH] user={user_id} title={title} content={content}")
|
||||
try:
|
||||
import asyncio
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
result = loop.run_until_complete(
|
||||
PushService._send_via_wechat(user_id, title, content, payload)
|
||||
)
|
||||
loop.close()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Push send failed (logged only): {e}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def send_bulk(user_ids: List[str], title: str, content: str, payload: Optional[Dict[str, Any]] = None) -> int:
|
||||
sent = 0
|
||||
for uid in user_ids:
|
||||
if PushService.send_notification(uid, title, content, payload):
|
||||
sent += 1
|
||||
return sent
|
||||
|
||||
async def send_async(self, user_id: str, title: str, content: str, payload: Optional[Dict[str, Any]] = None) -> bool:
|
||||
logger.info(f"[PUSH_ASYNC] user={user_id} title={title}")
|
||||
|
||||
result = await self._send_via_wechat(user_id, title, content, payload)
|
||||
|
||||
await self._save_in_app_notification(user_id, title, content, payload)
|
||||
|
||||
return result
|
||||
|
||||
async def register_device(self, user_id: str, client_id: str, platform: str = "weapp", push_token: Optional[str] = None, device_info: Optional[Dict] = None) -> Device:
|
||||
result = await self.db.execute(
|
||||
select(Device).where(
|
||||
and_(Device.user_id == user_id, Device.client_id == client_id)
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing:
|
||||
existing.platform = platform
|
||||
existing.push_token = push_token
|
||||
existing.device_info = device_info or {}
|
||||
existing.is_active = True
|
||||
await self.db.flush()
|
||||
return existing
|
||||
|
||||
device = Device(
|
||||
user_id=user_id,
|
||||
platform=platform,
|
||||
push_token=push_token,
|
||||
client_id=client_id,
|
||||
device_info=device_info or {},
|
||||
)
|
||||
self.db.add(device)
|
||||
await self.db.flush()
|
||||
return device
|
||||
|
||||
async def get_user_devices(self, user_id: str) -> List[Dict]:
|
||||
result = await self.db.execute(
|
||||
select(Device).where(
|
||||
and_(Device.user_id == user_id, Device.is_active == True)
|
||||
)
|
||||
)
|
||||
devices = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": str(d.id),
|
||||
"platform": d.platform,
|
||||
"client_id": d.client_id,
|
||||
"device_info": d.device_info,
|
||||
"created_at": d.created_at.isoformat() if d.created_at else None,
|
||||
}
|
||||
for d in devices
|
||||
]
|
||||
|
||||
async def unregister_device(self, user_id: str, client_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Device).where(
|
||||
and_(Device.user_id == user_id, Device.client_id == client_id)
|
||||
)
|
||||
)
|
||||
device = result.scalar_one_or_none()
|
||||
if not device:
|
||||
return False
|
||||
device.is_active = False
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _send_via_wechat(user_id: str, title: str, content: str, payload: Optional[Dict] = None) -> bool:
|
||||
if not settings.WECHAT_APP_ID or not settings.WECHAT_APP_SECRET:
|
||||
logger.debug("WeChat push not configured, falling back to log")
|
||||
return True
|
||||
|
||||
try:
|
||||
from app.services.wechat import wechat_service
|
||||
access_token = await wechat_service._get_access_token()
|
||||
if not access_token:
|
||||
logger.warning("Cannot get WeChat access token for push")
|
||||
return False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
"https://api.weixin.qq.com/cgi-bin/message/subscribe/send",
|
||||
params={"access_token": access_token},
|
||||
json={
|
||||
"touser": user_id,
|
||||
"template_id": settings.WECHAT_PUSH_TEMPLATE_ID or "",
|
||||
"data": {
|
||||
"thing1": {"value": title[:20]},
|
||||
"thing2": {"value": content[:20]},
|
||||
},
|
||||
"miniprogram_state": "formal",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("errcode", 0) != 0:
|
||||
logger.warning(f"WeChat push failed: {data}")
|
||||
return False
|
||||
logger.info(f"WeChat push sent to user {user_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"WeChat push error: {e}")
|
||||
return False
|
||||
|
||||
async def _save_in_app_notification(self, user_id: str, title: str, content: str, payload: Optional[Dict] = None):
|
||||
if not self.db:
|
||||
return
|
||||
try:
|
||||
from app.services.notification import NotificationService
|
||||
await NotificationService.create_notification(
|
||||
self.db, user_id, title, content,
|
||||
notification_type="push",
|
||||
reference_type=(payload or {}).get("reference_type"),
|
||||
reference_id=(payload or {}).get("reference_id"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save in-app notification: {e}")
|
||||
@@ -1,11 +1,13 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy import select, and_, or_
|
||||
from app.models.quotation import Quotation, QuotationItem
|
||||
from app.models.customer import Customer
|
||||
from app.models.user import Product
|
||||
from app.ai.router import get_ai_router
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -90,6 +92,135 @@ class QuotationService:
|
||||
await self.db.flush()
|
||||
return await self._to_dict(q)
|
||||
|
||||
async def generate_from_inquiry(
|
||||
self, user_id: str, inquiry_text: str, customer_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
ai = get_ai_router()
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"product_requests": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"product_name": {"type": "string"},
|
||||
"quantity": {"type": "integer"},
|
||||
"unit": {"type": "string"},
|
||||
"specifications": {"type": "string"},
|
||||
"target_price": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"payment_terms": {"type": "string"},
|
||||
"delivery_terms": {"type": "string"},
|
||||
"urgency": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
extract_result = await ai.extract(inquiry_text, schema)
|
||||
extracted = extract_result.get("data", {})
|
||||
product_requests = extracted.get("product_requests", [])
|
||||
|
||||
if not product_requests:
|
||||
schema_simple = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"product_name": {"type": "string"},
|
||||
"quantity": {"type": "integer"},
|
||||
"specifications": {"type": "string"},
|
||||
},
|
||||
}
|
||||
extract_result = await ai.extract(inquiry_text, schema_simple)
|
||||
extracted = extract_result.get("data", {})
|
||||
if extracted.get("product_name"):
|
||||
product_requests = [{
|
||||
"product_name": extracted["product_name"],
|
||||
"quantity": extracted.get("quantity", 1),
|
||||
"unit": "pcs",
|
||||
"specifications": extracted.get("specifications", ""),
|
||||
}]
|
||||
|
||||
product_result = await self.db.execute(
|
||||
select(Product).where(
|
||||
and_(
|
||||
Product.user_id == user_id,
|
||||
Product.is_active == True,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_products = product_result.scalars().all()
|
||||
|
||||
matched_products = []
|
||||
for req in product_requests:
|
||||
req_name = req.get("product_name", "").lower()
|
||||
best_match = None
|
||||
best_score = 0
|
||||
|
||||
for p in user_products:
|
||||
score = 0
|
||||
p_name = (p.name or "").lower()
|
||||
p_name_en = (p.name_en or "").lower()
|
||||
|
||||
if req_name in p_name or p_name in req_name:
|
||||
score += 3
|
||||
if req_name in p_name_en or p_name_en in req_name:
|
||||
score += 2
|
||||
|
||||
keywords = p.keywords or []
|
||||
for kw in keywords:
|
||||
if isinstance(kw, str) and kw.lower() in req_name:
|
||||
score += 1
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = p
|
||||
|
||||
if best_match and best_score > 0:
|
||||
unit_price = float(best_match.price) if best_match.price else 0
|
||||
quantity = req.get("quantity", 1)
|
||||
matched_products.append({
|
||||
"product_id": str(best_match.id),
|
||||
"product_name": best_match.name,
|
||||
"description": best_match.description_en or best_match.description,
|
||||
"quantity": quantity,
|
||||
"unit_price": unit_price,
|
||||
"total_price": unit_price * quantity,
|
||||
"unit": req.get("unit", "pcs"),
|
||||
"match_score": best_score,
|
||||
})
|
||||
else:
|
||||
matched_products.append({
|
||||
"product_id": None,
|
||||
"product_name": req.get("product_name", "Unknown"),
|
||||
"description": req.get("specifications", ""),
|
||||
"quantity": req.get("quantity", 1),
|
||||
"unit_price": 0,
|
||||
"total_price": 0,
|
||||
"unit": req.get("unit", "pcs"),
|
||||
"match_score": 0,
|
||||
})
|
||||
|
||||
subtotal = sum(p["total_price"] for p in matched_products)
|
||||
total = subtotal
|
||||
|
||||
suggested_quotation = {
|
||||
"title": f"Quotation - {', '.join(p['product_name'] for p in matched_products[:3])}",
|
||||
"items": matched_products,
|
||||
"subtotal": subtotal,
|
||||
"total": total,
|
||||
"payment_terms": extracted.get("payment_terms", "T/T"),
|
||||
"delivery_terms": extracted.get("delivery_terms", "FOB"),
|
||||
"lead_time": "15-20 days" if extracted.get("urgency") != "urgent" else "7-10 days",
|
||||
"notes": f"Generated from customer inquiry: {inquiry_text[:100]}..." if len(inquiry_text) > 100 else f"Generated from customer inquiry: {inquiry_text}",
|
||||
"extracted_data": extracted,
|
||||
"matched_count": len([p for p in matched_products if p["product_id"]]),
|
||||
"unmatched_count": len([p for p in matched_products if not p["product_id"]]),
|
||||
}
|
||||
|
||||
return suggested_quotation
|
||||
|
||||
async def generate_quotation_text(self, q: Quotation) -> str:
|
||||
items_result = await self.db.execute(
|
||||
select(QuotationItem).where(QuotationItem.quotation_id == q.id)
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from app.models.customer import Customer, Conversation, Message
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SilentPatternService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def analyze_silent_risk(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
cutoff_3d = datetime.utcnow() - timedelta(days=3)
|
||||
cutoff_7d = datetime.utcnow() - timedelta(days=7)
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
and_(
|
||||
Customer.user_id == user_id,
|
||||
Customer.status.in_(["lead", "negotiating"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
customers = result.scalars().all()
|
||||
|
||||
risk_scores = []
|
||||
for c in customers:
|
||||
score, reasons = await self._calculate_risk_score(c, cutoff_3d, cutoff_7d)
|
||||
if score > 0:
|
||||
risk_scores.append({
|
||||
"customer_id": str(c.id),
|
||||
"name": c.name,
|
||||
"company": c.company,
|
||||
"country": c.country,
|
||||
"status": c.status,
|
||||
"estimated_value": c.estimated_value,
|
||||
"last_contact_at": c.last_contact_at.isoformat() if c.last_contact_at else None,
|
||||
"silence_days": (datetime.utcnow() - c.last_contact_at).days if c.last_contact_at else 0,
|
||||
"risk_score": score,
|
||||
"risk_level": self._risk_level(score),
|
||||
"reasons": reasons,
|
||||
})
|
||||
|
||||
risk_scores.sort(key=lambda x: x["risk_score"], reverse=True)
|
||||
return risk_scores
|
||||
|
||||
async def _calculate_risk_score(
|
||||
self, customer: Customer, cutoff_3d: datetime, cutoff_7d: datetime
|
||||
) -> tuple:
|
||||
score = 0
|
||||
reasons = []
|
||||
|
||||
if not customer.last_contact_at:
|
||||
return (0, [])
|
||||
|
||||
silence_days = (datetime.utcnow() - customer.last_contact_at).days
|
||||
|
||||
if silence_days >= 7:
|
||||
score += 40
|
||||
reasons.append(f"沉默超过7天")
|
||||
elif silence_days >= 3:
|
||||
score += 20
|
||||
reasons.append(f"沉默超过3天")
|
||||
|
||||
conv_query = await self.db.execute(
|
||||
select(Conversation).where(
|
||||
and_(
|
||||
Conversation.customer_id == customer.id,
|
||||
Conversation.user_id == customer.user_id,
|
||||
)
|
||||
).order_by(Conversation.created_at.desc()).limit(1)
|
||||
)
|
||||
conv = conv_query.scalar_one_or_none()
|
||||
if not conv:
|
||||
return (score, reasons)
|
||||
|
||||
msg_count_query = await self.db.execute(
|
||||
select(func.count(Message.id)).where(
|
||||
and_(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.direction == "inbound",
|
||||
)
|
||||
)
|
||||
)
|
||||
inbound_count = msg_count_query.scalar() or 0
|
||||
|
||||
if inbound_count >= 5 and silence_days >= 3:
|
||||
score += 20
|
||||
reasons.append(f"前期沟通频繁({inbound_count}条)后突然沉默")
|
||||
|
||||
if customer.status == "lead" and silence_days >= 3:
|
||||
score += 15
|
||||
reasons.append("潜在客户阶段需及时跟进")
|
||||
|
||||
if customer.status == "negotiating" and silence_days >= 2:
|
||||
score += 25
|
||||
reasons.append("谈判阶段客户需保持热度")
|
||||
|
||||
recent_query = await self.db.execute(
|
||||
select(Message).where(
|
||||
and_(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.created_at >= cutoff_7d,
|
||||
)
|
||||
).order_by(Message.created_at.desc()).limit(3)
|
||||
)
|
||||
recent_msgs = recent_query.scalars().all()
|
||||
|
||||
if recent_msgs:
|
||||
last_inbound = None
|
||||
for m in recent_msgs:
|
||||
if m.direction == "inbound":
|
||||
last_inbound = m
|
||||
break
|
||||
if last_inbound and silence_days >= 1:
|
||||
content_lower = last_inbound.content.lower()
|
||||
closing_signals = ["i'll think", "let me check", "too expensive", "high price", "not now", "maybe later", "considering"]
|
||||
for signal in closing_signals:
|
||||
if signal in content_lower:
|
||||
score += 15
|
||||
reasons.append(f"客户回复含消极信号: \"{signal}\"")
|
||||
break
|
||||
|
||||
return (min(score, 100), reasons)
|
||||
|
||||
def _risk_level(self, score: int) -> str:
|
||||
if score >= 70:
|
||||
return "high"
|
||||
elif score >= 40:
|
||||
return "medium"
|
||||
elif score >= 20:
|
||||
return "low"
|
||||
return "minimal"
|
||||
|
||||
async def get_suggestions(self, user_id: str, customer_id: str) -> List[str]:
|
||||
score_result = await self.analyze_silent_risk(user_id)
|
||||
customer_scores = [s for s in score_result if s["customer_id"] == customer_id]
|
||||
if not customer_scores:
|
||||
return []
|
||||
|
||||
score = customer_scores[0]
|
||||
suggestions = []
|
||||
silence_days = score["silence_days"]
|
||||
|
||||
if silence_days >= 7:
|
||||
suggestions.extend([
|
||||
f"客户{score['name']}已沉默{silence_days}天,建议发送产品更新或行业资讯重新激活",
|
||||
"考虑提供限时优惠或样品折扣打动客户",
|
||||
])
|
||||
elif silence_days >= 3:
|
||||
suggestions.extend([
|
||||
f"客户{score['name']}沉默{silence_days}天,建议发送跟进消息询问是否有进一步需求",
|
||||
"可分享相关案例或成功故事保持客户兴趣",
|
||||
])
|
||||
|
||||
if "negotiating" in score.get("status", ""):
|
||||
suggestions.append("谈判阶段客户,建议主动提供更多产品细节或定制方案")
|
||||
|
||||
if "消极信号" in str(score.get("reasons", [])):
|
||||
suggestions.append("客户曾表达价格顾虑,建议重新审视报价或提供增值服务")
|
||||
|
||||
if not suggestions:
|
||||
suggestions.append("客户状态良好,建议保持定期跟进节奏")
|
||||
|
||||
return suggestions
|
||||
@@ -0,0 +1,201 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, or_
|
||||
from app.models.team import Team, TeamMember
|
||||
from app.models.user import User
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def create_team(self, owner_id: str, name: str, description: str = None) -> Dict[str, Any]:
|
||||
existing = await self.db.execute(
|
||||
select(Team).where(and_(Team.owner_id == owner_id, Team.is_active == True))
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("You already own an active team")
|
||||
|
||||
user_result = await self.db.execute(select(User).where(User.id == owner_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
max_members = 20 if user.tier == "enterprise" else (10 if user.tier == "pro" else 5)
|
||||
|
||||
team = Team(
|
||||
name=name,
|
||||
owner_id=owner_id,
|
||||
description=description,
|
||||
max_members=max_members,
|
||||
tier=user.tier,
|
||||
)
|
||||
self.db.add(team)
|
||||
await self.db.flush()
|
||||
|
||||
member = TeamMember(
|
||||
team_id=team.id,
|
||||
user_id=owner_id,
|
||||
role="owner",
|
||||
status="active",
|
||||
)
|
||||
self.db.add(member)
|
||||
await self.db.flush()
|
||||
|
||||
return await self._to_dict(team, include_members=True)
|
||||
|
||||
async def get_team(self, team_id: str, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
result = await self.db.execute(
|
||||
select(Team).where(Team.id == team_id)
|
||||
)
|
||||
team = result.scalar_one_or_none()
|
||||
if not team:
|
||||
return None
|
||||
|
||||
is_member = await self.db.execute(
|
||||
select(TeamMember).where(
|
||||
and_(TeamMember.team_id == team_id, TeamMember.user_id == user_id)
|
||||
)
|
||||
)
|
||||
if not is_member.scalar_one_or_none():
|
||||
return None
|
||||
|
||||
return await self._to_dict(team, include_members=True)
|
||||
|
||||
async def list_user_teams(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
member_result = await self.db.execute(
|
||||
select(TeamMember.team_id).where(TeamMember.user_id == user_id)
|
||||
)
|
||||
team_ids = [r[0] for r in member_result.all()]
|
||||
|
||||
if not team_ids:
|
||||
return []
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Team).where(Team.id.in_(team_ids), Team.is_active == True)
|
||||
)
|
||||
teams = result.scalars().all()
|
||||
|
||||
return [await self._to_dict(t) for t in teams]
|
||||
|
||||
async def invite_member(self, team_id: str, owner_id: str, user_id: str) -> Dict[str, Any]:
|
||||
team_result = await self.db.execute(
|
||||
select(Team).where(and_(Team.id == team_id, Team.owner_id == owner_id))
|
||||
)
|
||||
team = team_result.scalar_one_or_none()
|
||||
if not team:
|
||||
raise ValueError("Team not found or not authorized")
|
||||
|
||||
member_count = await self.db.execute(
|
||||
select(func.count(TeamMember.id)).where(
|
||||
and_(TeamMember.team_id == team_id, TeamMember.status == "active")
|
||||
)
|
||||
)
|
||||
if (member_count.scalar() or 0) >= team.max_members:
|
||||
raise ValueError("Team member limit reached")
|
||||
|
||||
existing = await self.db.execute(
|
||||
select(TeamMember).where(
|
||||
and_(TeamMember.team_id == team_id, TeamMember.user_id == user_id)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("User is already a member")
|
||||
|
||||
member = TeamMember(
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
role="member",
|
||||
invited_by=owner_id,
|
||||
status="active",
|
||||
)
|
||||
self.db.add(member)
|
||||
await self.db.flush()
|
||||
|
||||
return {"user_id": user_id, "role": "member", "status": "active"}
|
||||
|
||||
async def remove_member(self, team_id: str, owner_id: str, user_id: str) -> bool:
|
||||
team_result = await self.db.execute(
|
||||
select(Team).where(and_(Team.id == team_id, Team.owner_id == owner_id))
|
||||
)
|
||||
team = team_result.scalar_one_or_none()
|
||||
if not team:
|
||||
return False
|
||||
|
||||
result = await self.db.execute(
|
||||
select(TeamMember).where(
|
||||
and_(TeamMember.team_id == team_id, TeamMember.user_id == user_id)
|
||||
)
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
if not member or member.role == "owner":
|
||||
return False
|
||||
|
||||
await self.db.delete(member)
|
||||
return True
|
||||
|
||||
async def leave_team(self, team_id: str, user_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(TeamMember).where(
|
||||
and_(TeamMember.team_id == team_id, TeamMember.user_id == user_id)
|
||||
)
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
if not member or member.role == "owner":
|
||||
return False
|
||||
|
||||
await self.db.delete(member)
|
||||
return True
|
||||
|
||||
async def update_role(self, team_id: str, owner_id: str, user_id: str, role: str) -> bool:
|
||||
team_result = await self.db.execute(
|
||||
select(Team).where(and_(Team.id == team_id, Team.owner_id == owner_id))
|
||||
)
|
||||
if not team_result.scalar_one_or_none():
|
||||
return False
|
||||
|
||||
result = await self.db.execute(
|
||||
select(TeamMember).where(
|
||||
and_(TeamMember.team_id == team_id, TeamMember.user_id == user_id)
|
||||
)
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
if not member or member.role == "owner":
|
||||
return False
|
||||
|
||||
member.role = role
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
async def _to_dict(self, team: Team, include_members: bool = False) -> Dict[str, Any]:
|
||||
result = {
|
||||
"id": str(team.id),
|
||||
"name": team.name,
|
||||
"owner_id": str(team.owner_id),
|
||||
"description": team.description,
|
||||
"tier": team.tier,
|
||||
"is_active": team.is_active,
|
||||
"created_at": team.created_at.isoformat() if team.created_at else None,
|
||||
}
|
||||
|
||||
if include_members:
|
||||
members_result = await self.db.execute(
|
||||
select(TeamMember).where(TeamMember.team_id == team.id)
|
||||
)
|
||||
members = members_result.scalars().all()
|
||||
result["members"] = [
|
||||
{
|
||||
"user_id": str(m.user_id),
|
||||
"role": m.role,
|
||||
"status": m.status,
|
||||
"joined_at": m.joined_at.isoformat() if m.joined_at else None,
|
||||
}
|
||||
for m in members
|
||||
]
|
||||
result["member_count"] = len([m for m in members if m.status == "active"])
|
||||
|
||||
return result
|
||||
@@ -47,6 +47,7 @@ class TranslationService:
|
||||
async def generate_reply(
|
||||
self, inquiry: str, context: Optional[Dict[str, Any]] = None,
|
||||
tone: str = "professional", count: int = 3,
|
||||
preference_context: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
similar = await self.corpus.find_similar(inquiry, "reply")
|
||||
if similar and count > 1:
|
||||
@@ -57,7 +58,7 @@ class TranslationService:
|
||||
|
||||
for t in tones:
|
||||
try:
|
||||
result = await self.ai.reply(inquiry, context, t)
|
||||
result = await self.ai.reply(inquiry, context, t, preference_context)
|
||||
results.append({
|
||||
"reply": result.get("reply", ""),
|
||||
"tone": t,
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import edge_tts
|
||||
HAS_EDGE_TTS = True
|
||||
except ImportError:
|
||||
HAS_EDGE_TTS = False
|
||||
logger.warning("edge-tts not installed, TTS disabled")
|
||||
|
||||
VOICE_MAP = {
|
||||
"zh": "zh-CN-XiaoxiaoNeural",
|
||||
"en": "en-US-AriaNeural",
|
||||
"ja": "ja-JP-NanamiNeural",
|
||||
"ko": "ko-KR-SunHiNeural",
|
||||
"fr": "fr-FR-DeniseNeural",
|
||||
"de": "de-DE-KatjaNeural",
|
||||
"es": "es-ES-ElviraNeural",
|
||||
"pt": "pt-BR-FranciscaNeural",
|
||||
"ru": "ru-RU-SvetlanaNeural",
|
||||
"ar": "ar-SA-ZariyahNeural",
|
||||
}
|
||||
|
||||
SUPPORTED_LANGS = list(VOICE_MAP.keys())
|
||||
|
||||
|
||||
class TextToSpeechService:
|
||||
@staticmethod
|
||||
async def synthesize(text: str, lang: str = "en", rate: str = "0%", pitch: str = "0Hz") -> Optional[bytes]:
|
||||
if not HAS_EDGE_TTS:
|
||||
logger.warning("edge-tts not available")
|
||||
return None
|
||||
|
||||
voice = VOICE_MAP.get(lang, VOICE_MAP["en"])
|
||||
|
||||
try:
|
||||
communicate = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)
|
||||
audio_data = b""
|
||||
async for chunk in communicate.stream():
|
||||
if chunk["type"] == "audio":
|
||||
audio_data += chunk["data"]
|
||||
return audio_data if audio_data else None
|
||||
except Exception as e:
|
||||
logger.error(f"TTS failed: {e}")
|
||||
return None
|
||||
|
||||
|
||||
tts_service = TextToSpeechService()
|
||||
@@ -0,0 +1,80 @@
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
import logging
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WeChatService:
|
||||
def __init__(self):
|
||||
self.app_id = settings.WECHAT_APP_ID
|
||||
self.app_secret = settings.WECHAT_APP_SECRET
|
||||
self.api_base = "https://api.weixin.qq.com"
|
||||
|
||||
async def code2session(self, js_code: str) -> Optional[Dict[str, Any]]:
|
||||
if not self.app_id or not self.app_secret:
|
||||
logger.warning("WeChat not configured")
|
||||
return None
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.api_base}/sns/jscode2session",
|
||||
params={
|
||||
"appid": self.app_id,
|
||||
"secret": self.app_secret,
|
||||
"js_code": js_code,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
data = resp.json()
|
||||
|
||||
if data.get("errcode", 0) != 0:
|
||||
logger.error(f"WeChat code2session failed: {data}")
|
||||
return None
|
||||
|
||||
return {
|
||||
"openid": data.get("openid"),
|
||||
"session_key": data.get("session_key"),
|
||||
"unionid": data.get("unionid"),
|
||||
}
|
||||
|
||||
async def get_phone_number(self, code: str) -> Optional[str]:
|
||||
if not self.app_id or not self.app_secret:
|
||||
return None
|
||||
|
||||
access_token = await self._get_access_token()
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.api_base}/wxa/business/getuserphonenumber",
|
||||
params={"access_token": access_token},
|
||||
json={"code": code},
|
||||
timeout=10,
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("errcode", 0) != 0:
|
||||
logger.error(f"WeChat getPhoneNumber failed: {data}")
|
||||
return None
|
||||
|
||||
return data.get("phone_info", {}).get("phoneNumber")
|
||||
|
||||
async def _get_access_token(self) -> Optional[str]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.api_base}/cgi-bin/token",
|
||||
params={
|
||||
"grant_type": "client_credential",
|
||||
"appid": self.app_id,
|
||||
"secret": self.app_secret,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
data = resp.json()
|
||||
return data.get("access_token")
|
||||
|
||||
|
||||
wechat_service = WeChatService()
|
||||
@@ -85,6 +85,48 @@ class WhatsAppService:
|
||||
)
|
||||
return resp.status_code == 200
|
||||
|
||||
async def send_media(self, to: str, media_url: str, media_type: str = "image", caption: Optional[str] = None) -> bool:
|
||||
if not self.api_token or not self.phone_number_id:
|
||||
return False
|
||||
|
||||
body = {
|
||||
"messaging_product": "whatsapp",
|
||||
"to": to,
|
||||
"type": media_type,
|
||||
media_type: {"link": media_url},
|
||||
}
|
||||
if caption:
|
||||
body[media_type]["caption"] = caption
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.api_base}/messages",
|
||||
headers={"Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json"},
|
||||
json=body,
|
||||
timeout=30,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.error(f"WhatsApp media send failed: {resp.text}")
|
||||
return False
|
||||
return True
|
||||
|
||||
async def mark_as_read(self, message_id: str) -> bool:
|
||||
if not self.api_token:
|
||||
return False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.api_base}/messages",
|
||||
headers={"Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json"},
|
||||
json={
|
||||
"messaging_product": "whatsapp",
|
||||
"status": "read",
|
||||
"message_id": message_id,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
return resp.status_code == 200
|
||||
|
||||
def parse_webhook(self, body: Dict) -> Optional[Dict]:
|
||||
try:
|
||||
entry = body.get("entry", [{}])[0]
|
||||
@@ -96,14 +138,29 @@ class WhatsAppService:
|
||||
return None
|
||||
|
||||
msg = messages[0]
|
||||
msg_type = msg.get("type", "text")
|
||||
content = ""
|
||||
|
||||
if msg_type == "text":
|
||||
content = msg.get("text", {}).get("body", "")
|
||||
elif msg_type in ("image", "document", "audio", "video"):
|
||||
media = msg.get(msg_type, {})
|
||||
content = media.get("caption", "") or media.get("filename", "") or f"[{msg_type}]"
|
||||
|
||||
return {
|
||||
"from": msg.get("from"),
|
||||
"text": msg.get("text", {}).get("body", ""),
|
||||
"text": content,
|
||||
"msg_id": msg.get("id"),
|
||||
"timestamp": msg.get("timestamp"),
|
||||
"type": msg.get("type", "text"),
|
||||
"type": msg_type,
|
||||
"profile_name": value.get("contacts", [{}])[0].get("profile", {}).get("name"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse WhatsApp webhook: {e}")
|
||||
return None
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
+155
-57
@@ -10,6 +10,9 @@ logger = logging.getLogger(__name__)
|
||||
def check_silent_customers():
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.customer import Customer
|
||||
from app.models.user import User
|
||||
from app.services.push import PushService
|
||||
from app.services.notification import NotificationService
|
||||
|
||||
async def _check():
|
||||
async with AsyncSessionLocal() as db:
|
||||
@@ -27,12 +30,26 @@ def check_silent_customers():
|
||||
)
|
||||
customers = result.scalars().all()
|
||||
for c in customers:
|
||||
if days == 3:
|
||||
logger.info(f"Customer {c.name} silent for 3 days")
|
||||
elif days == 7:
|
||||
logger.info(f"Customer {c.name} silent for 7 days - upgrade")
|
||||
else:
|
||||
logger.info(f"Customer {c.name} silent for 14 days - recommend new approach")
|
||||
messages = {
|
||||
3: ("跟进提醒", f"客户 {c.name} 已沉默3天,建议发送跟进消息"),
|
||||
7: ("跟进升级", f"客户 {c.name} 已沉默1周,建议发送优惠或新产品信息"),
|
||||
14: ("跟进提示", f"客户 {c.name} 已沉默14天,建议换话题重新接触"),
|
||||
}
|
||||
title, content = messages.get(days, ("跟进提醒", f"客户 {c.name} 已沉默{days}天"))
|
||||
logger.info(f"Customer {c.name} silent for {days} days")
|
||||
|
||||
user_result = await db.execute(
|
||||
select(User).where(User.id == c.user_id)
|
||||
)
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user:
|
||||
PushService.send_notification(c.user_id, title, content)
|
||||
await NotificationService.create_notification(
|
||||
db, c.user_id, title, content,
|
||||
notification_type="customer_silent",
|
||||
reference_type="customer",
|
||||
reference_id=str(c.id),
|
||||
)
|
||||
|
||||
import asyncio
|
||||
asyncio.run(_check())
|
||||
@@ -59,6 +76,8 @@ def batch_translate_texts(texts: list, target_lang: str, user_id: str):
|
||||
def generate_quotation_pdf(quotation_id: str):
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.quotation import Quotation, QuotationItem
|
||||
from app.models.customer import Customer
|
||||
from app.services.pdf_generator import pdf_generator
|
||||
|
||||
async def _generate():
|
||||
async with AsyncSessionLocal() as db:
|
||||
@@ -74,62 +93,60 @@ def generate_quotation_pdf(quotation_id: str):
|
||||
)
|
||||
items = items_result.scalars().all()
|
||||
|
||||
pdf_content = generate_pdf_text(q, items)
|
||||
customer = None
|
||||
if q.customer_id:
|
||||
cust_result = await db.execute(
|
||||
select(Customer).where(Customer.id == q.customer_id)
|
||||
)
|
||||
customer = cust_result.scalar_one_or_none()
|
||||
|
||||
return {"pdf_content": pdf_content, "quotation_id": str(q.id)}
|
||||
data = {
|
||||
"quotation_number": f"{str(q.id)[:8].upper()}",
|
||||
"customer_name": customer.name if customer else "",
|
||||
"customer_company": customer.company if customer else "",
|
||||
"customer_country": customer.country if customer else "",
|
||||
"date": q.created_at.strftime("%Y-%m-%d") if q.created_at else "",
|
||||
"valid_until": q.valid_until or "",
|
||||
"currency": q.currency or "USD",
|
||||
"items": [
|
||||
{
|
||||
"product_name": i.product_name,
|
||||
"description": i.description,
|
||||
"quantity": i.quantity,
|
||||
"unit_price": i.unit_price,
|
||||
"total_price": i.total_price,
|
||||
"unit": i.unit or "pcs",
|
||||
}
|
||||
for i in items
|
||||
],
|
||||
"subtotal": q.subtotal or 0,
|
||||
"discount": q.discount or 0,
|
||||
"shipping": q.shipping or 0,
|
||||
"total": q.total or q.subtotal or 0,
|
||||
"payment_terms": q.payment_terms or "",
|
||||
"delivery_terms": q.delivery_terms or "",
|
||||
"lead_time": q.lead_time or "",
|
||||
"notes": q.notes or "",
|
||||
}
|
||||
|
||||
pdf_bytes = pdf_generator.generate_quotation(data)
|
||||
|
||||
if pdf_bytes:
|
||||
upload_dir = settings.UPLOAD_DIR
|
||||
pdf_path = os.path.join(upload_dir, f"quotation_{quotation_id}.pdf")
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
with open(pdf_path, "wb") as f:
|
||||
f.write(pdf_bytes)
|
||||
q.pdf_url = pdf_path
|
||||
await db.flush()
|
||||
return {"success": True, "pdf_path": pdf_path, "quotation_id": str(q.id)}
|
||||
else:
|
||||
return {"error": "PDF generation failed (weasyprint not available)", "quotation_id": str(q.id)}
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_generate())
|
||||
|
||||
|
||||
def generate_pdf_text(quotation, items):
|
||||
from datetime import datetime
|
||||
|
||||
lines = [
|
||||
"=" * 60,
|
||||
f"QUOTATION",
|
||||
f"#{str(quotation.id)[:8].upper()}",
|
||||
"=" * 60,
|
||||
f"Date: {datetime.utcnow().strftime('%Y-%m-%d')}",
|
||||
]
|
||||
|
||||
if quotation.valid_until:
|
||||
lines.append(f"Valid Until: {quotation.valid_until}")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"{'Item':<30} {'Qty':<8} {'Unit Price':<12} {'Total':<12}")
|
||||
lines.append("-" * 62)
|
||||
|
||||
for item in items:
|
||||
lines.append(
|
||||
f"{item.product_name:<30} {item.quantity:<8} ${item.unit_price:<10.2f} ${item.total_price:<10.2f}"
|
||||
)
|
||||
|
||||
lines.append("-" * 62)
|
||||
if quotation.subtotal:
|
||||
lines.append(f"{'Subtotal':>48} ${quotation.subtotal:<10.2f}")
|
||||
if quotation.discount:
|
||||
lines.append(f"{'Discount':>48} -${quotation.discount:<10.2f}")
|
||||
if quotation.shipping:
|
||||
lines.append(f"{'Shipping':>48} ${quotation.shipping:<10.2f}")
|
||||
lines.append(f"{'TOTAL':>48} ${quotation.total or quotation.subtotal or 0:<10.2f}")
|
||||
|
||||
lines.append("")
|
||||
if quotation.payment_terms:
|
||||
lines.append(f"Payment Terms: {quotation.payment_terms}")
|
||||
if quotation.delivery_terms:
|
||||
lines.append(f"Delivery Terms: {quotation.delivery_terms}")
|
||||
if quotation.lead_time:
|
||||
lines.append(f"Lead Time: {quotation.lead_time}")
|
||||
if quotation.notes:
|
||||
lines.append(f"Notes: {quotation.notes}")
|
||||
|
||||
lines.append("=" * 60)
|
||||
lines.append("Generated by TradeMate")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@shared_task
|
||||
def process_corpus_quality():
|
||||
from app.database import AsyncSessionLocal
|
||||
@@ -155,6 +172,71 @@ def process_corpus_quality():
|
||||
return asyncio.run(_process())
|
||||
|
||||
|
||||
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
|
||||
def process_customer_import(self, user_id: str, records: list):
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.customer import CustomerService
|
||||
|
||||
async def _import():
|
||||
async with AsyncSessionLocal() as db:
|
||||
svc = CustomerService(db)
|
||||
imported = 0
|
||||
errors = []
|
||||
for i, record in enumerate(records):
|
||||
try:
|
||||
await svc.create_customer(user_id, record)
|
||||
imported += 1
|
||||
except Exception as e:
|
||||
errors.append(f"Row {i+2}: {str(e)}")
|
||||
return {"imported": imported, "total": len(records), "errors": errors}
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_import())
|
||||
|
||||
|
||||
@shared_task
|
||||
def run_daily_corpus_training():
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.corpus_trainer import CorpusTrainer
|
||||
|
||||
async def _train():
|
||||
async with AsyncSessionLocal() as db:
|
||||
trainer = CorpusTrainer(db)
|
||||
result = await trainer.run_pipeline()
|
||||
logger.info(f"Daily corpus training complete: {result}")
|
||||
return result
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_train())
|
||||
|
||||
|
||||
@shared_task
|
||||
def update_customer_health_cache():
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.customer_health import CustomerHealthService
|
||||
from app.models.user import User
|
||||
from app.config import settings
|
||||
|
||||
async def _update():
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids = result.scalars().all()
|
||||
|
||||
svc = CustomerHealthService(db)
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
overview = await svc.get_health_overview(uid)
|
||||
scores = await svc.get_all_health_scores(uid)
|
||||
except Exception as e:
|
||||
logger.error(f"Health cache failed for user {uid}: {e}")
|
||||
|
||||
return f"Updated health cache for {len(user_ids)} users"
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_update())
|
||||
|
||||
|
||||
@shared_task
|
||||
def cleanup_old_sessions():
|
||||
import redis.asyncio as aioredis
|
||||
@@ -190,4 +272,20 @@ def send_followup_reminder(customer_id: str, user_id: str):
|
||||
return {"error": "Customer not found"}
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_send())
|
||||
return asyncio.run(_send())
|
||||
|
||||
|
||||
@shared_task
|
||||
def check_followup_engine():
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.followup_engine import FollowupEngine
|
||||
|
||||
async def _check():
|
||||
async with AsyncSessionLocal() as db:
|
||||
engine = FollowupEngine(db)
|
||||
result = await engine.scan_and_followup()
|
||||
logger.info(f"Followup engine check complete: {result}")
|
||||
return result
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_check())
|
||||
@@ -1,19 +1,22 @@
|
||||
fastapi==0.79.0
|
||||
uvicorn==0.19.0
|
||||
fastapi==0.100.0
|
||||
uvicorn==0.23.2
|
||||
sqlalchemy==1.4.48
|
||||
asyncpg==0.27.0
|
||||
pydantic==1.10.12
|
||||
pydantic-settings==1.1.2
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
python-multipart==0.0.6
|
||||
redis==4.5.5
|
||||
celery==5.2.7
|
||||
httpx==0.23.3
|
||||
openai==0.27.8
|
||||
openai==1.12.0
|
||||
anthropic==0.8.1
|
||||
jinja2==3.1.2
|
||||
alembic==1.11.3
|
||||
sentry-sdk==2.3.1
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
||||
pytest-cov==4.1.0
|
||||
pytest-cov==4.1.0
|
||||
weasyprint==60.2
|
||||
openpyxl==3.1.2
|
||||
edge-tts>=6.0.0
|
||||
@@ -17,8 +17,8 @@ class TestConfig:
|
||||
assert "translate" in settings.AI_ROUTING
|
||||
assert "reply" in settings.AI_ROUTING
|
||||
assert "marketing" in settings.AI_ROUTING
|
||||
assert settings.AI_ROUTING["translate"]["primary"] == "deepl"
|
||||
assert settings.AI_ROUTING["reply"]["primary"] == "openai"
|
||||
assert "extract" in settings.AI_ROUTING
|
||||
assert "primary" in settings.AI_ROUTING["translate"]
|
||||
|
||||
def test_free_tier_limits(self):
|
||||
assert settings.FREE_DAILY_TRANSLATE_CHARS == 5000
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from app.services.customer_health import CustomerHealthService
|
||||
|
||||
|
||||
class TestSilenceScore:
|
||||
def test_no_contact_returns_max_days(self):
|
||||
assert CustomerHealthService.silence_days(None) == 999
|
||||
|
||||
def test_recent_contact_returns_0_days(self):
|
||||
now = datetime.utcnow()
|
||||
assert CustomerHealthService.silence_days(now) == 0
|
||||
|
||||
def test_3_days_ago(self):
|
||||
dt = datetime.utcnow() - timedelta(days=3)
|
||||
assert CustomerHealthService.silence_days(dt) == 3
|
||||
|
||||
def test_silence_score_0_days(self):
|
||||
assert CustomerHealthService.calculate_silence_score(datetime.utcnow()) == 100
|
||||
|
||||
def test_silence_score_7_days(self):
|
||||
dt = datetime.utcnow() - timedelta(days=7)
|
||||
score = CustomerHealthService.calculate_silence_score(dt)
|
||||
assert score == 50
|
||||
|
||||
def test_silence_score_14_days(self):
|
||||
dt = datetime.utcnow() - timedelta(days=14)
|
||||
score = CustomerHealthService.calculate_silence_score(dt)
|
||||
assert score == 0
|
||||
|
||||
def test_silence_score_21_days_clamped(self):
|
||||
dt = datetime.utcnow() - timedelta(days=21)
|
||||
score = CustomerHealthService.calculate_silence_score(dt)
|
||||
assert score == 0
|
||||
|
||||
|
||||
class TestStatusWeight:
|
||||
def test_customer_status(self):
|
||||
assert CustomerHealthService.status_weight("customer") == 100
|
||||
|
||||
def test_negotiating_status(self):
|
||||
assert CustomerHealthService.status_weight("negotiating") == 70
|
||||
|
||||
def test_lead_status(self):
|
||||
assert CustomerHealthService.status_weight("lead") == 40
|
||||
|
||||
def test_lost_status(self):
|
||||
assert CustomerHealthService.status_weight("lost") == 10
|
||||
|
||||
def test_unknown_status_defaults(self):
|
||||
assert CustomerHealthService.status_weight("unknown") == 40
|
||||
|
||||
def test_none_status_defaults(self):
|
||||
assert CustomerHealthService.status_weight(None) == 40
|
||||
|
||||
|
||||
class TestGrade:
|
||||
def test_active_grade(self):
|
||||
assert CustomerHealthService.grade(100) == "active"
|
||||
assert CustomerHealthService.grade(80) == "active"
|
||||
assert CustomerHealthService.grade(85) == "active"
|
||||
|
||||
def test_watch_grade(self):
|
||||
assert CustomerHealthService.grade(79) == "watch"
|
||||
assert CustomerHealthService.grade(50) == "watch"
|
||||
assert CustomerHealthService.grade(65) == "watch"
|
||||
|
||||
def test_critical_grade(self):
|
||||
assert CustomerHealthService.grade(49) == "critical"
|
||||
assert CustomerHealthService.grade(0) == "critical"
|
||||
assert CustomerHealthService.grade(30) == "critical"
|
||||
|
||||
|
||||
class TestResponseScore:
|
||||
def test_both_none(self):
|
||||
r = CustomerHealthService.calc_response_score(None, None)
|
||||
assert r["score"] == 50
|
||||
assert r["trend"] == "stable"
|
||||
|
||||
def test_only_recent_exists(self):
|
||||
r = CustomerHealthService.calc_response_score(2.0, None)
|
||||
assert r["score"] == 90
|
||||
assert r["trend"] == "stable"
|
||||
|
||||
def test_improving_faster_response(self):
|
||||
r = CustomerHealthService.calc_response_score(2.0, 10.0)
|
||||
assert r["score"] == 90
|
||||
assert r["trend"] == "improving"
|
||||
|
||||
def test_declining_slower_response(self):
|
||||
r = CustomerHealthService.calc_response_score(10.0, 2.0)
|
||||
assert r["trend"] == "declining"
|
||||
|
||||
def test_fast_response_high_score(self):
|
||||
r = CustomerHealthService.calc_response_score(0.5, 5.0)
|
||||
assert r["score"] >= 95
|
||||
assert r["trend"] == "improving"
|
||||
|
||||
def test_very_slow_response_low_score(self):
|
||||
r = CustomerHealthService.calc_response_score(48.0, 2.0)
|
||||
assert r["score"] == 0
|
||||
assert r["trend"] == "declining"
|
||||
|
||||
|
||||
class TestSentimentScore:
|
||||
def test_empty_messages_neutral(self):
|
||||
r = CustomerHealthService.calc_sentiment_score([])
|
||||
assert r["score"] == 50
|
||||
assert r["label"] == "neutral"
|
||||
|
||||
def test_positive_message(self):
|
||||
r = CustomerHealthService.calc_sentiment_score(["yes I'm interested thanks"])
|
||||
assert r["score"] == 80
|
||||
assert r["label"] == "positive"
|
||||
|
||||
def test_negative_message(self):
|
||||
r = CustomerHealthService.calc_sentiment_score(["no not interested too expensive"])
|
||||
assert r["score"] == 20
|
||||
assert r["label"] == "negative"
|
||||
|
||||
def test_mixed_messages_neutral(self):
|
||||
r = CustomerHealthService.calc_sentiment_score(["good quality", "but price is high"])
|
||||
assert r["score"] == 50
|
||||
assert r["label"] == "neutral"
|
||||
|
||||
def test_more_positive_than_negative(self):
|
||||
r = CustomerHealthService.calc_sentiment_score([
|
||||
"great product",
|
||||
"yes please proceed",
|
||||
"but shipping is expensive",
|
||||
])
|
||||
assert r["score"] == 80
|
||||
assert r["label"] == "positive"
|
||||
|
||||
|
||||
class TestInquiryDepthScore:
|
||||
def test_empty_messages(self):
|
||||
r = CustomerHealthService.calc_inquiry_depth_score([])
|
||||
assert r["score"] == 0
|
||||
assert r["signal_count"] == 0
|
||||
|
||||
def test_no_signals(self):
|
||||
r = CustomerHealthService.calc_inquiry_depth_score(["hello", "how are you"])
|
||||
assert r["score"] == 0
|
||||
|
||||
def test_one_signal(self):
|
||||
r = CustomerHealthService.calc_inquiry_depth_score(["what is your price"])
|
||||
assert r["score"] == 50
|
||||
assert r["signal_count"] >= 1
|
||||
|
||||
def test_multiple_signals(self):
|
||||
r = CustomerHealthService.calc_inquiry_depth_score([
|
||||
"what is your MOQ and FOB price",
|
||||
"do you have certification",
|
||||
"what is the lead time",
|
||||
])
|
||||
assert r["score"] >= 75
|
||||
assert r["signal_count"] >= 3
|
||||
|
||||
def test_deduplicates_signals(self):
|
||||
r = CustomerHealthService.calc_inquiry_depth_score([
|
||||
"what is the price",
|
||||
"please send price and MOQ",
|
||||
])
|
||||
assert r["signal_count"] == 2
|
||||
|
||||
|
||||
class TestBusinessValueScore:
|
||||
def test_zero_value(self):
|
||||
r = CustomerHealthService.calc_business_value_score(0)
|
||||
assert r["score"] == 0
|
||||
|
||||
def test_small_value(self):
|
||||
r = CustomerHealthService.calc_business_value_score(500)
|
||||
assert r["score"] == 20
|
||||
|
||||
def test_medium_value(self):
|
||||
r = CustomerHealthService.calc_business_value_score(5000)
|
||||
assert r["score"] == 40
|
||||
|
||||
def test_large_value(self):
|
||||
r = CustomerHealthService.calc_business_value_score(50000)
|
||||
assert r["score"] == 80
|
||||
|
||||
def test_very_large_value(self):
|
||||
r = CustomerHealthService.calc_business_value_score(200000)
|
||||
assert r["score"] == 100
|
||||
|
||||
|
||||
class TestTotalScore:
|
||||
def test_perfect_health(self):
|
||||
dims = {
|
||||
"response_trend": {"score": 100},
|
||||
"sentiment": {"score": 100},
|
||||
"inquiry_depth": {"score": 100},
|
||||
"silence": {"score": 100},
|
||||
"business_value": {"score": 100},
|
||||
}
|
||||
r = CustomerHealthService.calc_total_score(dims)
|
||||
assert r["total_score"] == 100
|
||||
assert r["grade"] == "active"
|
||||
|
||||
def test_zero_health(self):
|
||||
dims = {
|
||||
"response_trend": {"score": 0},
|
||||
"sentiment": {"score": 0},
|
||||
"inquiry_depth": {"score": 0},
|
||||
"silence": {"score": 0},
|
||||
"business_value": {"score": 0},
|
||||
}
|
||||
r = CustomerHealthService.calc_total_score(dims)
|
||||
assert r["total_score"] == 0
|
||||
assert r["grade"] == "critical"
|
||||
|
||||
def test_mid_health(self):
|
||||
dims = {
|
||||
"response_trend": {"score": 60},
|
||||
"sentiment": {"score": 50},
|
||||
"inquiry_depth": {"score": 50},
|
||||
"silence": {"score": 40},
|
||||
"business_value": {"score": 50},
|
||||
}
|
||||
r = CustomerHealthService.calc_total_score(dims)
|
||||
assert 45 <= r["total_score"] <= 55
|
||||
|
||||
|
||||
class TestSuggestion:
|
||||
def test_active_suggestion(self):
|
||||
s = CustomerHealthService.suggestion("active", 1, "lead")
|
||||
assert "良好" in s
|
||||
|
||||
def test_watch_with_silence(self):
|
||||
s = CustomerHealthService.suggestion("watch", 5, "lead")
|
||||
assert "5天" in s
|
||||
assert "跟进" in s
|
||||
|
||||
def test_watch_no_silence(self):
|
||||
s = CustomerHealthService.suggestion("watch", 1, "lead")
|
||||
assert "关注" in s
|
||||
|
||||
def test_critical_lead(self):
|
||||
s = CustomerHealthService.suggestion("critical", 10, "lead")
|
||||
assert "10天" in s
|
||||
assert "跟进" in s
|
||||
|
||||
def test_critical_lost(self):
|
||||
s = CustomerHealthService.suggestion("critical", 20, "lost")
|
||||
assert "重新激活" in s
|
||||
|
||||
|
||||
class TestHealthOverview:
|
||||
def test_overview_empty(self):
|
||||
overview = CustomerHealthService._calculate_overview_static([])
|
||||
assert overview["total"] == 0
|
||||
assert overview["active"] == 0
|
||||
assert overview["watch"] == 0
|
||||
assert overview["critical"] == 0
|
||||
|
||||
def test_overview_mixed(self, monkeypatch):
|
||||
class Row:
|
||||
def __init__(self, status, days_ago):
|
||||
self.status = status
|
||||
self.last_contact_at = datetime.utcnow() - timedelta(days=days_ago)
|
||||
|
||||
rows = [
|
||||
Row("customer", 1),
|
||||
Row("lead", 7),
|
||||
Row("negotiating", 14),
|
||||
Row("lost", 30),
|
||||
]
|
||||
overview = CustomerHealthService._calculate_overview_static(rows)
|
||||
assert overview["total"] == 4
|
||||
assert overview["active"] == 1
|
||||
@@ -0,0 +1,95 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from app.core.security import create_access_token
|
||||
from app.models.user import User
|
||||
import uuid
|
||||
|
||||
|
||||
class TestAdminAPI:
|
||||
async def test_admin_dashboard_unauthorized(self, client: AsyncClient):
|
||||
response = await client.get("/api/v1/admin/dashboard")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_admin_dashboard_forbidden_non_admin(self, client: AsyncClient, test_user):
|
||||
token = create_access_token({"sub": str(test_user.id), "tier": "free", "role": "user"})
|
||||
response = await client.get(
|
||||
"/api/v1/admin/dashboard",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_admin_dashboard_success(self, client: AsyncClient, test_user):
|
||||
test_user.role = "admin"
|
||||
token = create_access_token({"sub": str(test_user.id), "tier": "free", "role": "admin"})
|
||||
response = await client.get(
|
||||
"/api/v1/admin/dashboard",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total_users" in data
|
||||
assert "paid_users" in data
|
||||
|
||||
async def test_admin_list_users(self, client: AsyncClient, test_user):
|
||||
test_user.role = "admin"
|
||||
token = create_access_token({"sub": str(test_user.id), "tier": "free", "role": "admin"})
|
||||
response = await client.get(
|
||||
"/api/v1/admin/users",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "items" in data
|
||||
assert "total" in data
|
||||
|
||||
async def test_admin_update_tier_forbidden_non_admin(self, client: AsyncClient, test_user):
|
||||
target_id = str(uuid.uuid4())
|
||||
token = create_access_token({"sub": str(test_user.id), "tier": "free", "role": "user"})
|
||||
response = await client.patch(
|
||||
f"/api/v1/admin/users/{target_id}/tier",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
json={"tier": "pro"},
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestRateLimit:
|
||||
async def test_health_not_rate_limited(self, client: AsyncClient):
|
||||
for _ in range(10):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_rate_limit_headers_present(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/customers", headers=auth_headers)
|
||||
assert "X-RateLimit-Remaining" in response.headers
|
||||
assert "X-RateLimit-Limit" in response.headers
|
||||
|
||||
|
||||
class TestUserRole:
|
||||
async def test_user_default_role(self, client: AsyncClient, test_user):
|
||||
assert test_user.role == "user"
|
||||
|
||||
async def test_user_info_contains_role(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "role" in data
|
||||
assert data["role"] == "user"
|
||||
|
||||
|
||||
class TestPrivacyTerms:
|
||||
async def test_privacy_page_exists(self):
|
||||
import os
|
||||
path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"uni-app", "src", "pages", "agreement", "privacy.vue",
|
||||
)
|
||||
assert os.path.exists(path), "privacy.vue not found"
|
||||
|
||||
async def test_terms_page_exists(self):
|
||||
import os
|
||||
path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"uni-app", "src", "pages", "agreement", "terms.vue",
|
||||
)
|
||||
assert os.path.exists(path), "terms.vue not found"
|
||||
@@ -0,0 +1,337 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from app.models.customer import Conversation, Message
|
||||
from app.models.quotation import Quotation, QuotationItem
|
||||
from app.models.user import Product
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestTranslateAPI:
|
||||
async def test_translate_unauthorized(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/v1/translate",
|
||||
json={"text": "Hello", "target_lang": "zh"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_translate_success(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.translation.TranslationService.translate") as mock:
|
||||
mock.return_value = {
|
||||
"translated_text": "你好",
|
||||
"source_lang": "en",
|
||||
"provider_used": "mock",
|
||||
"from_cache": False,
|
||||
}
|
||||
response = await client.post(
|
||||
"/api/v1/translate",
|
||||
headers=auth_headers,
|
||||
json={"text": "Hello", "target_lang": "zh"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["translated_text"] == "你好"
|
||||
|
||||
async def test_translate_with_context(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.translation.TranslationService.translate") as mock:
|
||||
mock.return_value = {
|
||||
"translated_text": "FOB 上海 价格",
|
||||
"source_lang": "en",
|
||||
"provider_used": "mock",
|
||||
}
|
||||
response = await client.post(
|
||||
"/api/v1/translate",
|
||||
headers=auth_headers,
|
||||
json={"text": "FOB Shanghai price", "target_lang": "zh", "context": "trade"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestReplyAPI:
|
||||
async def test_reply_unauthorized(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/v1/translate/reply",
|
||||
json={"inquiry": "How much?", "tone": "professional"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_reply_success(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.translation.TranslationService.generate_reply") as mock:
|
||||
mock.return_value = [
|
||||
{"reply": "Thank you for your inquiry.", "tone": "professional", "provider": "mock"},
|
||||
{"reply": "Thanks for reaching out!", "tone": "friendly", "provider": "mock"},
|
||||
]
|
||||
response = await client.post(
|
||||
"/api/v1/translate/reply",
|
||||
headers=auth_headers,
|
||||
json={"inquiry": "How much for 500 units?", "tone": "professional", "count": 2},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["suggestions"]) == 2
|
||||
assert data["count"] == 2
|
||||
|
||||
async def test_reply_with_context(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.translation.TranslationService.generate_reply") as mock:
|
||||
mock.return_value = [{"reply": "Our price is $10/unit.", "tone": "professional", "provider": "mock"}]
|
||||
response = await client.post(
|
||||
"/api/v1/translate/reply",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"inquiry": "Price?",
|
||||
"tone": "professional",
|
||||
"count": 1,
|
||||
"context": {"product": "Widget X", "price": "$10"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestExtractAPI:
|
||||
async def test_extract_unauthorized(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/v1/translate/extract",
|
||||
json={"text": "I want 500pcs of red widgets FOB Shanghai", "extract_type": "inquiry"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_extract_success(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.translation.TranslationService.extract_info") as mock:
|
||||
mock.return_value = {
|
||||
"intent": "purchase",
|
||||
"product_interest": "widgets",
|
||||
"quantity": "500",
|
||||
}
|
||||
response = await client.post(
|
||||
"/api/v1/translate/extract",
|
||||
headers=auth_headers,
|
||||
json={"text": "I want 500pcs of red widgets FOB Shanghai", "extract_type": "inquiry"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "extracted" in data
|
||||
|
||||
|
||||
class TestTTSAPI:
|
||||
async def test_tts_get_unauthorized(self, client: AsyncClient):
|
||||
response = await client.get("/api/v1/translate/tts?text=hello&lang=en")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_tts_get_empty_text(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/translate/tts?text=&lang=en", headers=auth_headers)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
class TestMarketingAPI:
|
||||
async def test_marketing_unauthorized(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/v1/marketing/generate",
|
||||
json={"product_name": "Widget", "description": "A great widget", "target": "US buyers"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_marketing_success(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.marketing.MarketingService.generate") as mock:
|
||||
mock.return_value = [
|
||||
{"content": "Buy our widget!", "style": "professional", "provider": "mock"},
|
||||
]
|
||||
response = await client.post(
|
||||
"/api/v1/marketing/generate",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"product_name": "Widget X",
|
||||
"description": "High quality widget",
|
||||
"category": "tools",
|
||||
"target": "US importers",
|
||||
"style": "professional",
|
||||
"count": 1,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["count"] >= 1
|
||||
assert "results" in data
|
||||
|
||||
async def test_marketing_keywords(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.marketing.MarketingService.generate_keywords") as mock:
|
||||
mock.return_value = ["widget", "tool", "quality"]
|
||||
response = await client.post(
|
||||
"/api/v1/marketing/keywords",
|
||||
headers=auth_headers,
|
||||
json={"product_name": "Widget", "description": "A widget", "count": 5},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "keywords" in response.json()
|
||||
|
||||
|
||||
class TestProductAPI:
|
||||
async def test_create_product(self, client: AsyncClient, auth_headers):
|
||||
response = await client.post(
|
||||
"/api/v1/products",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "Test Product",
|
||||
"description": "A test product",
|
||||
"category": "electronics",
|
||||
"price": "10.50",
|
||||
"price_unit": "USD",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Product"
|
||||
|
||||
async def test_list_products(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/products", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "items" in data
|
||||
|
||||
async def test_update_product(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
product = Product(
|
||||
user_id=test_user.id,
|
||||
name="Old Name",
|
||||
category="tools",
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(product)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/products/{product.id}",
|
||||
headers=auth_headers,
|
||||
json={"name": "New Name", "price": "20.00"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "New Name"
|
||||
|
||||
async def test_delete_product(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
product = Product(user_id=test_user.id, name="To Delete")
|
||||
db_session.add(product)
|
||||
await db_session.commit()
|
||||
pid = product.id
|
||||
|
||||
response = await client.delete(f"/api/v1/products/{pid}", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await client.get(f"/api/v1/products/{pid}", headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestQuotationAPI:
|
||||
async def test_create_quotation(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
from app.models.customer import Customer
|
||||
customer = Customer(user_id=test_user.id, name="Test Buyer")
|
||||
db_session.add(customer)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/quotations",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"customer_id": str(customer.id),
|
||||
"title": "Test Quote",
|
||||
"items": [
|
||||
{"product_name": "Widget", "quantity": 100, "unit_price": 10.0},
|
||||
],
|
||||
"currency": "USD",
|
||||
"payment_terms": "T/T",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["title"] == "Test Quote"
|
||||
assert len(data["items"]) == 1
|
||||
|
||||
async def test_list_quotations(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/quotations", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert "items" in response.json()
|
||||
|
||||
async def test_quotation_pdf_not_found(self, client: AsyncClient, auth_headers):
|
||||
import uuid
|
||||
response = await client.get(f"/api/v1/quotations/{uuid.uuid4()}/pdf", headers=auth_headers)
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_quotation_status_update(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
from app.models.customer import Customer
|
||||
customer = Customer(user_id=test_user.id, name="Status Test Buyer")
|
||||
db_session.add(customer)
|
||||
await db_session.commit()
|
||||
|
||||
q = Quotation(user_id=test_user.id, customer_id=customer.id, title="Status Test", status="draft")
|
||||
db_session.add(q)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/quotations/{q.id}/status",
|
||||
headers=auth_headers,
|
||||
json={"status": "sent"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "sent"
|
||||
|
||||
|
||||
class TestAnalyticsAPI:
|
||||
async def test_overview(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/analytics/overview", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "customers" in data
|
||||
assert "translations" in data
|
||||
assert "quotations" in data
|
||||
assert "messages" in data
|
||||
assert "marketing" in data
|
||||
|
||||
async def test_customer_analytics(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/analytics/customers", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_marketing_analytics(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/analytics/marketing", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total_events" in data
|
||||
|
||||
|
||||
class TestOnboardingAPI:
|
||||
async def test_onboarding_status(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/onboarding/status", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "completed" in data
|
||||
assert "product_count" in data
|
||||
|
||||
async def test_onboarding_create_product(self, client: AsyncClient, auth_headers):
|
||||
with patch("app.services.onboarding.OnboardingService.create_product") as mock:
|
||||
mock.return_value = {
|
||||
"id": "mock-id",
|
||||
"name": "Onboarded Product",
|
||||
"marketing_contents": [],
|
||||
"keywords": [],
|
||||
}
|
||||
response = await client.post(
|
||||
"/api/v1/onboarding/product",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "New Product",
|
||||
"description": "Desc",
|
||||
"category": "tools",
|
||||
"target": "US buyers",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Onboarded Product"
|
||||
|
||||
|
||||
class TestExportAPI:
|
||||
async def test_export_customers_csv(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/customers/export/csv", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv"
|
||||
assert "customers.csv" in response.headers["content-disposition"]
|
||||
|
||||
async def test_export_quotations_csv(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/quotations/export/csv", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv"
|
||||
@@ -0,0 +1,156 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from app.models.notification import Notification
|
||||
from app.models.feedback import Feedback
|
||||
|
||||
|
||||
class TestNotificationAPI:
|
||||
async def test_list_notifications_empty(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/notifications", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "items" in data
|
||||
assert data["items"] == []
|
||||
|
||||
async def test_unread_count_zero(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/notifications/unread-count", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["count"] == 0
|
||||
|
||||
async def test_create_and_list_notification(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
n = Notification(
|
||||
user_id=test_user.id,
|
||||
title="Test Title",
|
||||
content="Test Content",
|
||||
)
|
||||
db_session.add(n)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.get("/api/v1/notifications", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["items"]) >= 1
|
||||
assert data["items"][0]["title"] == "Test Title"
|
||||
|
||||
async def test_mark_read(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
n = Notification(user_id=test_user.id, title="Read Test", content="Content")
|
||||
db_session.add(n)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/notifications/{n.id}/read",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
count_resp = await client.get("/api/v1/notifications/unread-count", headers=auth_headers)
|
||||
assert count_resp.json()["count"] == 0
|
||||
|
||||
async def test_mark_all_read(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
for i in range(3):
|
||||
db_session.add(Notification(user_id=test_user.id, title=f"Notif {i}", content="Content"))
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.post("/api/v1/notifications/read-all", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["count"] == 3
|
||||
|
||||
async def test_delete_notification(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
n = Notification(user_id=test_user.id, title="Delete Me", content="Content")
|
||||
db_session.add(n)
|
||||
await db_session.commit()
|
||||
nid = n.id
|
||||
|
||||
response = await client.delete(f"/api/v1/notifications/{nid}", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
|
||||
list_resp = await client.get("/api/v1/notifications", headers=auth_headers)
|
||||
ids = [item["id"] for item in list_resp.json()["items"]]
|
||||
assert str(nid) not in ids
|
||||
|
||||
async def test_delete_not_found(self, client: AsyncClient, auth_headers):
|
||||
import uuid
|
||||
response = await client.delete(
|
||||
f"/api/v1/notifications/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_unread_count_after_read(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
for i in range(2):
|
||||
db_session.add(Notification(user_id=test_user.id, title=f"Unread {i}", content="C"))
|
||||
await db_session.commit()
|
||||
|
||||
resp = await client.get("/api/v1/notifications/unread-count", headers=auth_headers)
|
||||
assert resp.json()["count"] == 2
|
||||
|
||||
async def test_unread_only_filter(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
n1 = Notification(user_id=test_user.id, title="Read", content="C", is_read=True)
|
||||
n2 = Notification(user_id=test_user.id, title="Unread", content="C")
|
||||
db_session.add_all([n1, n2])
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/notifications?unread_only=true",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
for item in response.json()["items"]:
|
||||
assert item["is_read"] is False
|
||||
|
||||
|
||||
class TestFeedbackAPI:
|
||||
async def test_submit_feedback(self, client: AsyncClient, auth_headers):
|
||||
response = await client.post(
|
||||
"/api/v1/feedback",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"content": "Great app!",
|
||||
"category": "feature",
|
||||
"contact": "test@example.com",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
async def test_submit_feedback_minimal(self, client: AsyncClient, auth_headers):
|
||||
response = await client.post(
|
||||
"/api/v1/feedback",
|
||||
headers=auth_headers,
|
||||
json={"content": "Bug report"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
async def test_submit_feedback_unauthorized(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/v1/feedback",
|
||||
json={"content": "Test"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
class TestPaymentAPI:
|
||||
async def test_get_plans(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/payment/plans", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "plans" in data
|
||||
assert len(data["plans"]) >= 3
|
||||
|
||||
async def test_get_subscription_free(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/payment/subscription", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "plan" in data
|
||||
assert "status" in data
|
||||
|
||||
async def test_create_order(self, client: AsyncClient, auth_headers):
|
||||
response = await client.post(
|
||||
"/api/v1/payment/create-order",
|
||||
headers=auth_headers,
|
||||
json={"plan": "pro"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "prepay_id" in data or "order_id" in data or "url" in data
|
||||
@@ -0,0 +1,9 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from app.models.customer import Conversation, Message, Customer
|
||||
from app.models.preference import PreferenceAnalysis, MarketingEffect
|
||||
from app.models.user import Product
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from app.services.corpus_trainer import CorpusTrainer
|
||||
from app.models.corpus import CorpusEntry
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestCorpusTrainer:
|
||||
async def test_get_stats_empty(self, db_session):
|
||||
trainer = CorpusTrainer(db_session)
|
||||
stats = await trainer.get_stats()
|
||||
assert stats["total_entries"] == 0
|
||||
assert stats["with_embeddings"] == 0
|
||||
|
||||
async def test_get_stats_with_data(self, db_session):
|
||||
entries = [
|
||||
CorpusEntry(source_text="Hello", target_text="你好", task_type="translate", quality_score=0.8),
|
||||
CorpusEntry(source_text="Goodbye", target_text="再见", task_type="translate", quality_score=0.6),
|
||||
]
|
||||
for e in entries:
|
||||
db_session.add(e)
|
||||
await db_session.commit()
|
||||
|
||||
trainer = CorpusTrainer(db_session)
|
||||
stats = await trainer.get_stats()
|
||||
assert stats["total_entries"] == 2
|
||||
assert stats["by_task_type"]["translate"] == 2
|
||||
assert stats["high_quality"] == 1
|
||||
assert stats["low_quality"] == 0
|
||||
|
||||
async def test_score_entries(self, db_session):
|
||||
entries = [
|
||||
CorpusEntry(source_text="Hello world", target_text="你好世界", task_type="translate"),
|
||||
CorpusEntry(source_text="Hi", target_text="嗨", task_type="translate"),
|
||||
]
|
||||
for e in entries:
|
||||
db_session.add(e)
|
||||
await db_session.commit()
|
||||
|
||||
trainer = CorpusTrainer(db_session)
|
||||
result = await trainer.score_entries(batch_size=10)
|
||||
assert result["processed"] == 2
|
||||
assert result["updated"] == 2
|
||||
|
||||
for e in entries:
|
||||
await db_session.refresh(e)
|
||||
assert e.quality_score is not None
|
||||
assert 0.0 <= e.quality_score <= 1.0
|
||||
|
||||
async def test_deduplicate(self, db_session):
|
||||
from datetime import datetime
|
||||
e1 = CorpusEntry(
|
||||
source_text="Duplicate text", target_text="重复文本",
|
||||
task_type="translate", quality_score=0.8,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
e2 = CorpusEntry(
|
||||
source_text="Duplicate text", target_text="重复文本",
|
||||
task_type="translate", quality_score=0.7,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db_session.add_all([e1, e2])
|
||||
await db_session.commit()
|
||||
|
||||
trainer = CorpusTrainer(db_session)
|
||||
result = await trainer.deduplicate()
|
||||
assert result["duplicates_removed"] == 1
|
||||
|
||||
stats = await trainer.get_stats()
|
||||
assert stats["total_entries"] == 1
|
||||
|
||||
async def test_prune_low_quality(self, db_session):
|
||||
from datetime import timedelta
|
||||
old = datetime.utcnow() - timedelta(days=100)
|
||||
entry = CorpusEntry(
|
||||
source_text="x", target_text="y",
|
||||
task_type="translate", quality_score=0.1,
|
||||
created_at=old, usage_count=0,
|
||||
)
|
||||
db_session.add(entry)
|
||||
await db_session.commit()
|
||||
|
||||
trainer = CorpusTrainer(db_session)
|
||||
result = await trainer.prune_low_quality(min_score=0.2, max_age_days=30)
|
||||
assert result["pruned"] == 1
|
||||
|
||||
stats = await trainer.get_stats()
|
||||
assert stats["total_entries"] == 0
|
||||
|
||||
async def test_run_pipeline(self, db_session):
|
||||
trainer = CorpusTrainer(db_session)
|
||||
result = await trainer.run_pipeline()
|
||||
assert "deduplication" in result
|
||||
assert "scoring" in result
|
||||
assert "embeddings" in result
|
||||
assert "pruning" in result
|
||||
assert "stats" in result
|
||||
|
||||
def test_calculate_quality_score_with_rating(self, db_session):
|
||||
trainer = CorpusTrainer(db_session)
|
||||
entry = CorpusEntry(
|
||||
source_text="Good source text with enough length",
|
||||
target_text="Good target text with enough length",
|
||||
task_type="translate",
|
||||
user_rating=4,
|
||||
)
|
||||
score = trainer._calculate_quality_score(entry)
|
||||
assert 0.7 <= score <= 1.0
|
||||
|
||||
def test_calculate_quality_score_short_text(self, db_session):
|
||||
trainer = CorpusTrainer(db_session)
|
||||
entry = CorpusEntry(
|
||||
source_text="ab", target_text="cd",
|
||||
task_type="translate",
|
||||
)
|
||||
score = trainer._calculate_quality_score(entry)
|
||||
assert score < 0.5
|
||||
|
||||
def test_calculate_quality_score_with_usage(self, db_session):
|
||||
trainer = CorpusTrainer(db_session)
|
||||
entry = CorpusEntry(
|
||||
source_text="Good source text here with proper length",
|
||||
target_text="Good target text here with proper length",
|
||||
task_type="translate",
|
||||
usage_count=10,
|
||||
)
|
||||
score = trainer._calculate_quality_score(entry)
|
||||
assert score >= 0.6
|
||||
|
||||
async def test_embedding_generation_skipped_without_key(self, db_session):
|
||||
from app.config import settings
|
||||
original = settings.OPENAI_API_KEY
|
||||
settings.OPENAI_API_KEY = None
|
||||
|
||||
trainer = CorpusTrainer(db_session)
|
||||
embedding = await trainer._generate_embedding("test")
|
||||
assert embedding is None
|
||||
|
||||
settings.OPENAI_API_KEY = original
|
||||
Reference in New Issue
Block a user