Initial commit: TradeMate 外贸小助手 MVP
项目结构: - backend/ Python FastAPI 后端 - uni-app/ uni-app跨端前端 - docs/ 设计文档 - docker-compose.yml Docker编排 - nginx/scripts/systemd 运维配置 已完成功能: - 用户认证 (JWT) - 智能翻译 + 回复建议 - 营销素材生成 - 客户管理 + 沉默检测 - 报价单管理 - 产品库管理 - 汇率换算 - 推送通知 (uni-push) - WhatsApp Webhook框架 - Celery定时任务
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
# 应用配置
|
||||
APP_NAME=TradeMate
|
||||
SECRET_KEY=change-this-to-a-random-secret-key
|
||||
JWT_ALGORITHM=HS256
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# 数据库
|
||||
DATABASE_URL=postgresql+asyncpg://tradmate:tradmate@localhost:5432/tradmate
|
||||
|
||||
# Redis
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# Celery
|
||||
CELERY_BROKER_URL=redis://localhost:6379/1
|
||||
CELERY_RESULT_BACKEND=redis://localhost:6379/2
|
||||
|
||||
# AI 提供商(至少配置一个)
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
DEEPL_API_KEY=
|
||||
|
||||
# 本地模型(可选)
|
||||
LOCAL_MODEL_ENABLED=false
|
||||
LOCAL_MODEL_URL=http://localhost:8001
|
||||
|
||||
# WhatsApp Cloud API
|
||||
WHATSAPP_API_TOKEN=
|
||||
WHATSAPP_PHONE_NUMBER_ID=
|
||||
WHATSAPP_WEBHOOK_VERIFY_TOKEN=
|
||||
|
||||
# 微信小程序
|
||||
WECHAT_APP_ID=
|
||||
WECHAT_APP_SECRET=
|
||||
|
||||
# 汇率 API(免费层即可)
|
||||
EXCHANGE_RATE_API_KEY=
|
||||
|
||||
# 文件存储
|
||||
UPLOAD_DIR=./uploads
|
||||
MAX_UPLOAD_SIZE=10485760
|
||||
|
||||
# URL
|
||||
FRONTEND_URL=http://localhost:3000
|
||||
BACKEND_URL=http://localhost:8000
|
||||
@@ -0,0 +1,20 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
postgresql-client \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN mkdir -p uploads
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,39 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
version_path_separator = os
|
||||
sqlalchemy.url = postgresql+asyncpg://tradmate:tradmate@localhost:5432/tradmate
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1,61 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
|
||||
config = context.config
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
from app.database import Base
|
||||
from app.models import User, Product, Customer, Conversation, Message, Quotation, QuotationItem, CorpusEntry
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,25 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,189 @@
|
||||
"""initial schema
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2026-05-08
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = '001'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table('users',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('wechat_openid', sa.String(length=255), nullable=True),
|
||||
sa.Column('phone', sa.String(length=20), nullable=True),
|
||||
sa.Column('username', sa.String(length=100), nullable=True),
|
||||
sa.Column('password_hash', sa.String(length=255), nullable=True),
|
||||
sa.Column('tier', sa.String(length=50), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_phone'), 'users', ['phone'], unique=True)
|
||||
op.create_index(op.f('ix_users_wechat_openid'), 'users', ['wechat_openid'], unique=True)
|
||||
|
||||
op.create_table('products',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('name_en', sa.String(length=255), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('description_en', sa.Text(), nullable=True),
|
||||
sa.Column('category', sa.String(length=100), nullable=True),
|
||||
sa.Column('price', sa.String(length=50), nullable=True),
|
||||
sa.Column('price_unit', sa.String(length=20), nullable=True),
|
||||
sa.Column('moq', sa.String(length=50), nullable=True),
|
||||
sa.Column('keywords', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('specifications', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('images', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_products_user_id'), 'products', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('customers',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('company', sa.String(length=255), nullable=True),
|
||||
sa.Column('country', sa.String(length=100), nullable=True),
|
||||
sa.Column('phone', sa.String(length=50), nullable=True),
|
||||
sa.Column('email', sa.String(length=255), nullable=True),
|
||||
sa.Column('whatsapp_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('source', sa.String(length=100), nullable=True),
|
||||
sa.Column('tags', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('notes', sa.Text(), nullable=True),
|
||||
sa.Column('preference', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('status', sa.String(length=50), nullable=True),
|
||||
sa.Column('last_contact_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('silence_started_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('next_followup_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('estimated_value', sa.String(length=50), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_customers_user_id'), 'customers', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('conversations',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('customer_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('channel', sa.String(length=50), nullable=True),
|
||||
sa.Column('topic', sa.String(length=255), nullable=True),
|
||||
sa.Column('status', sa.String(length=50), nullable=True),
|
||||
sa.Column('message_count', sa.Integer(), nullable=True),
|
||||
sa.Column('last_message_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['customer_id'], ['customers.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_conversations_customer_id'), 'conversations', ['customer_id'], unique=False)
|
||||
op.create_index(op.f('ix_conversations_user_id'), 'conversations', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('messages',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('conversation_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('direction', sa.String(length=20), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('content_translated', sa.Text(), nullable=True),
|
||||
sa.Column('content_type', sa.String(length=50), nullable=True),
|
||||
sa.Column('ai_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('selected_suggestion', sa.Integer(), nullable=True),
|
||||
sa.Column('user_edited', sa.Text(), nullable=True),
|
||||
sa.Column('status', sa.String(length=50), nullable=True),
|
||||
sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_messages_conversation_id'), 'messages', ['conversation_id'], unique=False)
|
||||
|
||||
op.create_table('quotations',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('customer_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('title', sa.String(length=255), nullable=True),
|
||||
sa.Column('status', sa.String(length=50), nullable=True),
|
||||
sa.Column('currency', sa.String(length=10), nullable=True),
|
||||
sa.Column('exchange_rate', sa.Float(), nullable=True),
|
||||
sa.Column('payment_terms', sa.String(length=255), nullable=True),
|
||||
sa.Column('delivery_terms', sa.String(length=255), nullable=True),
|
||||
sa.Column('lead_time', sa.String(length=100), nullable=True),
|
||||
sa.Column('valid_until', sa.String(length=100), nullable=True),
|
||||
sa.Column('subtotal', sa.Float(), nullable=True),
|
||||
sa.Column('discount', sa.Float(), nullable=True),
|
||||
sa.Column('shipping', sa.Float(), nullable=True),
|
||||
sa.Column('total', sa.Float(), nullable=True),
|
||||
sa.Column('notes', sa.Text(), nullable=True),
|
||||
sa.Column('pdf_url', sa.Text(), nullable=True),
|
||||
sa.Column('sent_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['customer_id'], ['customers.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_quotations_user_id'), 'quotations', ['user_id'], unique=False)
|
||||
|
||||
op.create_table('quotation_items',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('quotation_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('product_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('quantity', sa.Integer(), nullable=False),
|
||||
sa.Column('unit_price', sa.Float(), nullable=False),
|
||||
sa.Column('total_price', sa.Float(), nullable=True),
|
||||
sa.Column('unit', sa.String(length=50), nullable=True),
|
||||
sa.ForeignKeyConstraint(['quotation_id'], ['quotations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_quotation_items_quotation_id'), 'quotation_items', ['quotation_id'], unique=False)
|
||||
|
||||
op.create_table('corpus_entries',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('source_text', sa.Text(), nullable=False),
|
||||
sa.Column('target_text', sa.Text(), nullable=False),
|
||||
sa.Column('source_lang', sa.String(length=20), nullable=True),
|
||||
sa.Column('target_lang', sa.String(length=20), nullable=True),
|
||||
sa.Column('task_type', sa.String(length=50), nullable=False),
|
||||
sa.Column('domain', sa.String(length=100), nullable=True),
|
||||
sa.Column('provider_used', sa.String(length=50), nullable=True),
|
||||
sa.Column('quality_score', sa.Float(), nullable=True),
|
||||
sa.Column('user_edited', sa.Boolean(), nullable=True),
|
||||
sa.Column('user_rating', sa.Integer(), nullable=True),
|
||||
sa.Column('usage_count', sa.Integer(), nullable=True),
|
||||
sa.Column('embedding', postgresql.Vector(length=768), nullable=True),
|
||||
sa.Column('metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('corpus_entries')
|
||||
op.drop_table('quotation_items')
|
||||
op.drop_table('quotations')
|
||||
op.drop_table('messages')
|
||||
op.drop_table('conversations')
|
||||
op.drop_table('customers')
|
||||
op.drop_table('products')
|
||||
op.drop_table('users')
|
||||
@@ -0,0 +1,3 @@
|
||||
from .router import get_ai_router
|
||||
|
||||
__all__ = ["get_ai_router"]
|
||||
@@ -0,0 +1,45 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
@abstractmethod
|
||||
async def translate(
|
||||
self, text: str, source_lang: Optional[str], target_lang: str,
|
||||
context: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reply(
|
||||
self, inquiry: str, context: Optional[Dict[str, Any]] = None,
|
||||
tone: str = "professional",
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_marketing(
|
||||
self, product_info: Dict[str, Any], target: str,
|
||||
style: str = "professional", language: str = "en",
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def extract_info(
|
||||
self, text: str, schema: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
pass
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
return False
|
||||
@@ -0,0 +1,6 @@
|
||||
from .openai import OpenAIProvider
|
||||
from .claude import ClaudeProvider
|
||||
from .deepl import DeepLProvider
|
||||
from .local import LocalProvider
|
||||
|
||||
__all__ = ["OpenAIProvider", "ClaudeProvider", "DeepLProvider", "LocalProvider"]
|
||||
@@ -0,0 +1,83 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from anthropic import AsyncAnthropic
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
SYSTEM_PROMPTS = {
|
||||
"marketing": "You are a world-class copywriter for international trade. Write persuasive, "
|
||||
"culturally-adapted marketing content that converts. You excel at storytelling "
|
||||
"and emotional appeal in business contexts.",
|
||||
"reply": "You are a senior international sales representative with 20 years of experience. "
|
||||
"Your replies are warm, professional, and strategically move the conversation "
|
||||
"toward closing the deal.",
|
||||
"translate": "You are a professional translator specializing in trade documents. "
|
||||
"Preserve all numbers, terms, and formatting. Translate meaning, not words.",
|
||||
"extract": "Extract structured data from text. Return ONLY valid JSON.",
|
||||
}
|
||||
|
||||
|
||||
class ClaudeProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "claude-sonnet-4-20250514"):
|
||||
self.client = AsyncAnthropic(api_key=api_key)
|
||||
self.model = model
|
||||
self._name = f"claude-sonnet"
|
||||
self._pricing = {"input": 0.003, "output": 0.015}
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["translate"]
|
||||
if context:
|
||||
system += f"\nContext: {context}"
|
||||
prompt = f"Translate to {target_lang}:\n\n{text}"
|
||||
content = await self._call(system, prompt)
|
||||
return {"translated_text": content, "provider": self.name}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["reply"]
|
||||
context_str = ""
|
||||
if context:
|
||||
for k, v in context.items():
|
||||
if v:
|
||||
context_str += f"{k}: {v}\n"
|
||||
prompt = f"{context_str}\nCustomer says:\n{inquiry}\n\nYour reply ({tone} tone):"
|
||||
content = await self._call(system, prompt)
|
||||
return {"reply": content, "provider": self.name}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["marketing"]
|
||||
info = json.dumps(product_info, ensure_ascii=False, indent=2)
|
||||
prompt = f"Product:\n{info}\n\nTarget: {target}\nStyle: {style}\nLanguage: {language}\n\nWrite marketing copy:"
|
||||
content = await self._call(system, prompt, max_tokens=1500)
|
||||
return {"content": content, "provider": self.name}
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["extract"]
|
||||
prompt = f"Schema:\n{json.dumps(schema, indent=2)}\n\nText:\n{text}\n\nJSON:"
|
||||
content = await self._call(system, prompt, max_tokens=1000)
|
||||
try:
|
||||
data = json.loads(content)
|
||||
return {"data": data, "confidence": 0.9, "provider": self.name}
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name, "error": "parse_failed"}
|
||||
|
||||
async def _call(self, system: str, prompt: str, max_tokens: int = 1000) -> str:
|
||||
resp = await self.client.messages.create(
|
||||
model=self.model,
|
||||
system=system,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.7,
|
||||
)
|
||||
return resp.content[0].text
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
return (self._pricing["input"] + self._pricing["output"]) / 2
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
return True
|
||||
@@ -0,0 +1,51 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import httpx
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
class DeepLProvider(AIProvider):
|
||||
def __init__(self, api_key: str, endpoint: str = "https://api.deepl.com/v2"):
|
||||
self.api_key = api_key
|
||||
self.endpoint = endpoint
|
||||
self._name = "deepl"
|
||||
self._cost_per_char = 0.000006
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
params = {
|
||||
"auth_key": self.api_key,
|
||||
"text": text,
|
||||
"target_lang": target_lang.upper()[:2],
|
||||
}
|
||||
if source_lang and source_lang != "auto":
|
||||
params["source_lang"] = source_lang.upper()[:2]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(f"{self.endpoint}/translate", data=params, timeout=15)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
t = data["translations"][0]
|
||||
return {
|
||||
"translated_text": t["text"],
|
||||
"provider": self.name,
|
||||
"detected_source_lang": t.get("detected_source_language", source_lang),
|
||||
"char_count": len(text),
|
||||
"cost": len(text) * self._cost_per_char,
|
||||
}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
raise NotImplementedError("DeepL does not support reply generation")
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
raise NotImplementedError("DeepL does not support marketing generation")
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
raise NotImplementedError("DeepL does not support info extraction")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
return self._cost_per_char * 1000
|
||||
@@ -0,0 +1,55 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json, httpx
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
class LocalProvider(AIProvider):
|
||||
def __init__(self, model_url: str = "http://localhost:8001", model_name: str = "gemma-3-8b"):
|
||||
self.model_url = model_url.rstrip("/")
|
||||
self.model_name = model_name
|
||||
self._name = f"local-{model_name}"
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
prompt = f"Translate{ f' from {source_lang}' if source_lang else ''} to {target_lang}:\n{text}\n\nTranslation:"
|
||||
result = await self._generate(prompt)
|
||||
return {"translated_text": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
ctx = ""
|
||||
if context:
|
||||
ctx = "\n".join(f"{k}: {v}" for k, v in context.items() if v)
|
||||
prompt = f"{ctx}\nCustomer: {inquiry}\n\nWrite a {tone} reply:"
|
||||
result = await self._generate(prompt)
|
||||
return {"reply": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
info = json.dumps(product_info, ensure_ascii=False)
|
||||
prompt = f"Product: {info}\nTarget: {target}\nStyle: {style}\nLanguage: {language}\n\nMarketing copy:"
|
||||
result = await self._generate(prompt, max_tokens=800)
|
||||
return {"content": result, "provider": self.name, "cost": 0.0}
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
prompt = f"Extract JSON from text matching schema:\nSchema: {json.dumps(schema)}\n\nText: {text}\n\nJSON:"
|
||||
result = await self._generate(prompt, max_tokens=500)
|
||||
try:
|
||||
return {"data": json.loads(result), "confidence": 0.7, "provider": self.name, "cost": 0.0}
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name, "cost": 0.0, "error": "parse_failed"}
|
||||
|
||||
async def _generate(self, prompt: str, max_tokens: int = 500) -> str:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.model_url}/v1/completions",
|
||||
json={"model": self.model_name, "prompt": prompt, "max_tokens": max_tokens, "temperature": 0.7, "stream": False},
|
||||
timeout=60,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["choices"][0]["text"].strip()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
return 0.0
|
||||
@@ -0,0 +1,102 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import json
|
||||
from openai import AsyncOpenAI
|
||||
from app.ai.base import AIProvider
|
||||
|
||||
|
||||
SYSTEM_PROMPTS = {
|
||||
"translate": "You are a professional translator specialized in foreign trade and e-commerce. "
|
||||
"Accurately translate business terms like MOQ, FOB, CIF, lead time, etc. "
|
||||
"Return ONLY the translated text, no explanations.",
|
||||
"reply": "You are an experienced foreign trade sales expert. Write professional, "
|
||||
"clear business replies. Be concise but warm. Include relevant details "
|
||||
"naturally. Return ONLY the reply text, no explanations.",
|
||||
"marketing": "You are a creative copywriter for international trade. Write compelling "
|
||||
"marketing content that drives action. Adapt to the target audience's culture. "
|
||||
"Return ONLY the copy, no explanations.",
|
||||
"extract": "You extract structured data from text. Return ONLY valid JSON matching the requested schema.",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o"):
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
self.model = model
|
||||
self._name = f"openai-{model}"
|
||||
self._pricing = {
|
||||
"gpt-4o": {"input": 0.01, "output": 0.03},
|
||||
"gpt-4o-mini": {"input": 0.0015, "output": 0.006},
|
||||
}
|
||||
self._cheap_model = "gpt-4o-mini" if model == "gpt-4o" else model
|
||||
|
||||
async def translate(self, text: str, source_lang: Optional[str], target_lang: str, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["translate"]
|
||||
if context:
|
||||
system += f"\nContext: this is about {context}"
|
||||
if source_lang and source_lang != "auto":
|
||||
system += f"\nSource language: {source_lang}"
|
||||
|
||||
content = await self._call(system, f"Translate to {target_lang}:\n\n{text}", model=self._cheap_model)
|
||||
return {"translated_text": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["reply"] + f"\nTone: {tone}"
|
||||
|
||||
context_str = ""
|
||||
if context:
|
||||
if context.get("product"):
|
||||
context_str += f"Product: {context['product']}\n"
|
||||
if context.get("price"):
|
||||
context_str += f"Price: {context['price']}\n"
|
||||
if context.get("customer_history"):
|
||||
context_str += f"Customer history: {context['customer_history']}\n"
|
||||
if context.get("conversation_history"):
|
||||
context_str += f"Previous messages: {context['conversation_history']}\n"
|
||||
|
||||
prompt = f"{context_str}\nCustomer inquiry:\n{inquiry}\n\nWrite a reply:"
|
||||
content = await self._call(system, prompt)
|
||||
return {"reply": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def generate_marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["marketing"] + f"\nStyle: {style}\nTarget audience: {target}\nLanguage: {language}"
|
||||
|
||||
product_str = json.dumps(product_info, ensure_ascii=False, indent=2)
|
||||
prompt = f"Product information:\n{product_str}\n\nGenerate marketing copy:"
|
||||
content = await self._call(system, prompt)
|
||||
return {"content": content, "provider": self.name, "model": self.model}
|
||||
|
||||
async def extract_info(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
system = SYSTEM_PROMPTS["extract"]
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
prompt = f"Schema:\n{schema_str}\n\nText:\n{text}\n\nExtracted JSON:"
|
||||
content = await self._call(system, prompt, response_format={"type": "json_object"})
|
||||
try:
|
||||
data = json.loads(content)
|
||||
return {"data": data, "confidence": 0.9, "provider": self.name}
|
||||
except json.JSONDecodeError:
|
||||
return {"data": {}, "confidence": 0.0, "provider": self.name, "error": "parse_failed"}
|
||||
|
||||
async def _call(self, system: str, prompt: str, max_tokens: int = 1000, response_format: Optional[Dict] = None, model: Optional[str] = None) -> str:
|
||||
kwargs = {
|
||||
"model": model or self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
resp = await self.client.chat.completions.create(**kwargs)
|
||||
return resp.choices[0].message.content
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cost_per_1k_tokens(self) -> float:
|
||||
p = self._pricing.get(self.model, {"input": 0.01, "output": 0.03})
|
||||
return (p["input"] + p["output"]) / 2
|
||||
@@ -0,0 +1,110 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.ai.base import AIProvider
|
||||
from app.ai.providers import OpenAIProvider, ClaudeProvider, DeepLProvider, LocalProvider
|
||||
from app.config import settings
|
||||
from app.ai.trade_corpus import TradeCorpus
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIRouter:
|
||||
def __init__(self):
|
||||
self.providers: Dict[str, AIProvider] = {}
|
||||
self.routing_rules = settings.AI_ROUTING
|
||||
self.corpus = TradeCorpus()
|
||||
self._init_providers()
|
||||
|
||||
def _init_providers(self):
|
||||
if settings.OPENAI_API_KEY:
|
||||
try:
|
||||
self.providers["openai"] = OpenAIProvider(api_key=settings.OPENAI_API_KEY)
|
||||
logger.info("OpenAI provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenAI init failed: {e}")
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
try:
|
||||
self.providers["anthropic"] = ClaudeProvider(api_key=settings.ANTHROPIC_API_KEY)
|
||||
logger.info("Claude provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"Claude init failed: {e}")
|
||||
|
||||
if settings.DEEPL_API_KEY:
|
||||
try:
|
||||
self.providers["deepl"] = DeepLProvider(api_key=settings.DEEPL_API_KEY)
|
||||
logger.info("DeepL provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"DeepL init failed: {e}")
|
||||
|
||||
if settings.LOCAL_MODEL_ENABLED:
|
||||
try:
|
||||
self.providers["local"] = LocalProvider(model_url=settings.LOCAL_MODEL_URL)
|
||||
logger.info("Local provider ready")
|
||||
except Exception as e:
|
||||
logger.warning(f"Local init failed: {e}")
|
||||
|
||||
def get_providers_for_task(self, task_type: str) -> List[AIProvider]:
|
||||
rules = self.routing_rules.get(
|
||||
task_type,
|
||||
{"primary": "openai", "fallback": ["local"]},
|
||||
)
|
||||
ordered = []
|
||||
seen = set()
|
||||
|
||||
primary = rules.get("primary")
|
||||
if primary and primary in self.providers:
|
||||
ordered.append(self.providers[primary])
|
||||
seen.add(primary)
|
||||
|
||||
for name in rules.get("fallback", []):
|
||||
if name in self.providers and name not in seen:
|
||||
ordered.append(self.providers[name])
|
||||
seen.add(name)
|
||||
|
||||
if not ordered:
|
||||
ordered = list(self.providers.values())
|
||||
logger.warning(f"No preferred providers for '{task_type}', using all available")
|
||||
|
||||
return ordered
|
||||
|
||||
async def execute(self, task_type: str, method: str, *args, **kwargs) -> Dict[str, Any]:
|
||||
providers = self.get_providers_for_task(task_type)
|
||||
last_error = None
|
||||
|
||||
for provider in providers:
|
||||
try:
|
||||
method_fn = getattr(provider, method)
|
||||
result = await method_fn(*args, **kwargs)
|
||||
result["provider_used"] = provider.name
|
||||
return result
|
||||
except NotImplementedError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"{provider.name} failed for {task_type}: {e}")
|
||||
last_error = e
|
||||
continue
|
||||
|
||||
raise Exception(f"All providers failed for '{task_type}'. Last error: {last_error}")
|
||||
|
||||
async def translate(self, text: str, target_lang: str, source_lang: Optional[str] = None, context: Optional[str] = None) -> Dict[str, Any]:
|
||||
return await self.execute("translate", "translate", text, source_lang, target_lang, context)
|
||||
|
||||
async def reply(self, inquiry: str, context: Optional[Dict[str, Any]] = None, tone: str = "professional") -> Dict[str, Any]:
|
||||
return await self.execute("reply", "reply", inquiry, context, tone)
|
||||
|
||||
async def marketing(self, product_info: Dict[str, Any], target: str, style: str = "professional", language: str = "en") -> Dict[str, Any]:
|
||||
return await self.execute("marketing", "generate_marketing", product_info, target, style, language)
|
||||
|
||||
async def extract(self, text: str, schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return await self.execute("extract", "extract_info", text, schema)
|
||||
|
||||
|
||||
_router_instance = None
|
||||
|
||||
|
||||
def get_ai_router() -> AIRouter:
|
||||
global _router_instance
|
||||
if _router_instance is None:
|
||||
_router_instance = AIRouter()
|
||||
return _router_instance
|
||||
@@ -0,0 +1,87 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy import select, text
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TradeCorpus:
|
||||
def __init__(self):
|
||||
self._ready = False
|
||||
|
||||
async def record(
|
||||
self,
|
||||
source_text: str,
|
||||
target_text: str,
|
||||
task_type: str,
|
||||
provider: str,
|
||||
source_lang: Optional[str] = None,
|
||||
target_lang: Optional[str] = None,
|
||||
quality_score: float = 0.5,
|
||||
user_edited: bool = False,
|
||||
metadata: Optional[Dict] = None,
|
||||
):
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
entry = CorpusEntry(
|
||||
source_text=source_text[:2000],
|
||||
target_text=target_text[:2000],
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang,
|
||||
task_type=task_type,
|
||||
provider_used=provider,
|
||||
quality_score=quality_score,
|
||||
user_edited=user_edited,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to record corpus entry: {e}")
|
||||
|
||||
async def find_similar(self, text: str, task_type: str, top_k: int = 3) -> List[Dict[str, Any]]:
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(CorpusEntry)
|
||||
.where(CorpusEntry.task_type == task_type)
|
||||
.where(CorpusEntry.quality_score >= 0.6)
|
||||
.order_by(CorpusEntry.quality_score.desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
entries = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"source": e.source_text,
|
||||
"target": e.target_text,
|
||||
"score": e.quality_score,
|
||||
}
|
||||
for e in entries
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Corpus search failed: {e}")
|
||||
return []
|
||||
|
||||
async def rate_entry(self, entry_id: str, rating: int):
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
result = await session.execute(
|
||||
select(CorpusEntry).where(CorpusEntry.id == entry_id)
|
||||
)
|
||||
entry = result.scalar_one_or_none()
|
||||
if entry:
|
||||
entry.user_rating = rating
|
||||
entry.quality_score = rating / 5.0
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rate corpus entry: {e}")
|
||||
@@ -0,0 +1,153 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Annotated
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from datetime import datetime
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
phone: str
|
||||
password: str
|
||||
username: str = ""
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
user: dict
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register(data: RegisterRequest, db: Annotated[AsyncSession, Depends(get_db)]):
|
||||
existing = await db.execute(select(User).where(User.phone == data.phone))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Phone already registered")
|
||||
|
||||
user = User(
|
||||
phone=data.phone,
|
||||
username=data.username or data.phone,
|
||||
password_hash=hash_password(data.password),
|
||||
tier="free",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"phone": user.phone,
|
||||
"username": user.username,
|
||||
"tier": user.tier,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(
|
||||
form: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
result = await db.execute(select(User).where(User.phone == form.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not verify_password(form.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=create_access_token({"sub": str(user.id), "tier": user.tier}),
|
||||
refresh_token=create_refresh_token({"sub": str(user.id)}),
|
||||
user={
|
||||
"id": str(user.id),
|
||||
"phone": user.phone,
|
||||
"username": user.username,
|
||||
"tier": user.tier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh(data: RefreshRequest):
|
||||
payload = decode_token(data.refresh_token)
|
||||
if not payload or payload.get("type") != "refresh":
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
|
||||
return {
|
||||
"access_token": create_access_token({"sub": payload["sub"]}),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(
|
||||
authorization: str = None,
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
result = await db.execute(select(User).where(User.id == payload["sub"]))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"phone": user.phone,
|
||||
"username": user.username,
|
||||
"tier": user.tier,
|
||||
"settings": user.settings,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class SettingsUpdate(BaseModel):
|
||||
preferred_translate_provider: str = None
|
||||
reply_tone: str = None
|
||||
timezone: str = None
|
||||
languages: list = None
|
||||
|
||||
|
||||
@router.patch("/settings")
|
||||
async def update_settings(
|
||||
data: SettingsUpdate,
|
||||
authorization: str = None,
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
result = await db.execute(select(User).where(User.id == payload["sub"]))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
settings = user.settings or {}
|
||||
for key, value in data.dict(exclude_unset=True).items():
|
||||
if value is not None:
|
||||
settings[key] = value
|
||||
|
||||
user.settings = settings
|
||||
await db.flush()
|
||||
|
||||
return {"settings": user.settings}
|
||||
@@ -0,0 +1,99 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from app.database import get_db
|
||||
from app.services.customer import CustomerService
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_customers(
|
||||
status: Optional[str] = None,
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
return await service.list_customers(user_id, status, page, size)
|
||||
|
||||
|
||||
@router.get("/silent")
|
||||
async def get_silent(
|
||||
days: int = Query(3, ge=1),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
customers = await service.get_silent_customers(user_id, days)
|
||||
return {
|
||||
"customers": customers,
|
||||
"count": len(customers),
|
||||
"silence_days": days,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{customer_id}")
|
||||
async def get_customer(
|
||||
customer_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
customer = await service.get_customer(user_id, customer_id)
|
||||
if not customer:
|
||||
raise HTTPException(status_code=404, detail="Customer not found")
|
||||
return customer
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_customer(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
customer = await service.create_customer(user_id, data)
|
||||
return customer
|
||||
|
||||
|
||||
@router.patch("/{customer_id}")
|
||||
async def update_customer(
|
||||
customer_id: str,
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
customer = await service.update_customer(user_id, customer_id, data)
|
||||
if not customer:
|
||||
raise HTTPException(status_code=404, detail="Customer not found")
|
||||
return customer
|
||||
|
||||
|
||||
@router.delete("/{customer_id}")
|
||||
async def delete_customer(
|
||||
customer_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
deleted = await service.delete_customer(user_id, customer_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Customer not found")
|
||||
return {"message": "Customer deleted"}
|
||||
|
||||
|
||||
@router.get("/{customer_id}/conversation")
|
||||
async def get_conversation(
|
||||
customer_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(50, ge=1, le=200),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = CustomerService(db)
|
||||
return await service.get_conversation(user_id, customer_id, page, size)
|
||||
@@ -0,0 +1,13 @@
|
||||
from fastapi import HTTPException, Depends
|
||||
from app.core.security import decode_token
|
||||
|
||||
|
||||
async def get_current_user_id(authorization: str = None) -> str:
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing or invalid token")
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
|
||||
return payload.get("sub")
|
||||
@@ -0,0 +1,54 @@
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ExchangeRateResponse(BaseModel):
|
||||
from_currency: str
|
||||
to_currency: str
|
||||
rate: float
|
||||
updated_at: str
|
||||
|
||||
|
||||
EXCHANGE_RATES = {
|
||||
("USD", "CNY"): 7.24,
|
||||
("EUR", "CNY"): 7.85,
|
||||
("GBP", "CNY"): 9.15,
|
||||
("CNY", "USD"): 0.138,
|
||||
("USD", "EUR"): 0.92,
|
||||
("EUR", "USD"): 1.09,
|
||||
("GBP", "USD"): 1.27,
|
||||
("USD", "GBP"): 0.79,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/convert")
|
||||
async def convert_currency(
|
||||
from_currency: str = "USD",
|
||||
to_currency: str = "CNY",
|
||||
amount: float = 1.0,
|
||||
):
|
||||
rate = EXCHANGE_RATES.get((from_currency, to_currency), 1.0)
|
||||
return {
|
||||
"from_currency": from_currency,
|
||||
"to_currency": to_currency,
|
||||
"amount": amount,
|
||||
"converted": round(amount * rate, 2),
|
||||
"rate": rate,
|
||||
"updated_at": "2026-05-08T00:00:00Z",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/rates")
|
||||
async def get_rates(base: str = "USD"):
|
||||
rates = {}
|
||||
for (from_curr, to_curr), rate in EXCHANGE_RATES.items():
|
||||
if from_curr == base:
|
||||
rates[to_curr] = rate
|
||||
|
||||
return {
|
||||
"base": base,
|
||||
"rates": rates,
|
||||
"updated_at": "2026-05-08T00:00:00Z",
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from app.services.marketing import MarketingService
|
||||
from app.core.security import decode_token
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MarketingRequest(BaseModel):
|
||||
product_name: str
|
||||
description: str
|
||||
category: Optional[str] = None
|
||||
price: Optional[str] = None
|
||||
keywords: Optional[list] = None
|
||||
target: str = "US importers"
|
||||
style: str = "professional"
|
||||
language: str = "en"
|
||||
count: int = 3
|
||||
|
||||
|
||||
class KeywordsRequest(BaseModel):
|
||||
product_name: str
|
||||
description: str
|
||||
category: Optional[str] = None
|
||||
language: str = "en"
|
||||
count: int = 10
|
||||
|
||||
|
||||
class CompetitorRequest(BaseModel):
|
||||
product_name: str
|
||||
description: str
|
||||
category: Optional[str] = None
|
||||
market: str = "US"
|
||||
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate_marketing(data: MarketingRequest, authorization: str = None):
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
service = MarketingService()
|
||||
product_info = {
|
||||
"name": data.product_name,
|
||||
"description": data.description,
|
||||
"category": data.category,
|
||||
"price": data.price,
|
||||
"keywords": data.keywords,
|
||||
}
|
||||
results = await service.generate(product_info, data.target, data.style, data.language, data.count)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"product": data.product_name,
|
||||
"target": data.target,
|
||||
"count": len(results),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/keywords")
|
||||
async def generate_keywords(data: KeywordsRequest, authorization: str = None):
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
service = MarketingService()
|
||||
product_info = {
|
||||
"name": data.product_name,
|
||||
"description": data.description,
|
||||
"category": data.category,
|
||||
}
|
||||
keywords = await service.generate_keywords(product_info, data.language, data.count)
|
||||
|
||||
return {"keywords": keywords, "product": data.product_name}
|
||||
|
||||
|
||||
@router.post("/competitor-analysis")
|
||||
async def competitor_analysis(data: CompetitorRequest, authorization: str = None):
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
service = MarketingService()
|
||||
product_info = {
|
||||
"name": data.product_name,
|
||||
"description": data.description,
|
||||
"category": data.category,
|
||||
}
|
||||
analysis = await service.analyze_competitors(product_info, data.market)
|
||||
|
||||
return {"analysis": analysis, "product": data.product_name, "market": data.market}
|
||||
@@ -0,0 +1,101 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from app.database import get_db
|
||||
from app.services.product import ProductService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ProductCreate(BaseModel):
|
||||
name: str
|
||||
name_en: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
description_en: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
price: Optional[str] = None
|
||||
price_unit: Optional[str] = "USD"
|
||||
moq: Optional[str] = None
|
||||
keywords: Optional[list] = []
|
||||
specifications: Optional[dict] = {}
|
||||
images: Optional[list] = []
|
||||
|
||||
|
||||
class ProductUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
name_en: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
description_en: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
price: Optional[str] = None
|
||||
price_unit: Optional[str] = None
|
||||
moq: Optional[str] = None
|
||||
keywords: Optional[list] = None
|
||||
specifications: Optional[dict] = None
|
||||
images: Optional[list] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_products(
|
||||
category: Optional[str] = None,
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = ProductService(db)
|
||||
return await service.list_products(user_id, category, page, size)
|
||||
|
||||
|
||||
@router.get("/{product_id}")
|
||||
async def get_product(
|
||||
product_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = ProductService(db)
|
||||
product = await service.get_product(user_id, product_id)
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
return product
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_product(
|
||||
data: ProductCreate,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = ProductService(db)
|
||||
product = await service.create_product(user_id, data.dict())
|
||||
return product
|
||||
|
||||
|
||||
@router.patch("/{product_id}")
|
||||
async def update_product(
|
||||
product_id: str,
|
||||
data: ProductUpdate,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = ProductService(db)
|
||||
product = await service.update_product(user_id, product_id, data.dict(exclude_unset=True))
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
return product
|
||||
|
||||
|
||||
@router.delete("/{product_id}")
|
||||
async def delete_product(
|
||||
product_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = ProductService(db)
|
||||
deleted = await service.delete_product(user_id, product_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
return {"message": "Product deleted"}
|
||||
@@ -0,0 +1,147 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import decode_token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class DeviceRegister(BaseModel):
|
||||
client_id: str
|
||||
platform: Optional[str] = None
|
||||
device_info: Optional[dict] = None
|
||||
|
||||
|
||||
class PushMessage(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
payload: Optional[dict] = None
|
||||
target_type: str = "all"
|
||||
target_value: Optional[str] = None
|
||||
|
||||
|
||||
class PushResponse(BaseModel):
|
||||
success: bool
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# 模拟存储的设备信息(实际应存数据库)
|
||||
devices_db = {}
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register_device(
|
||||
data: DeviceRegister,
|
||||
authorization: str = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if user_id not in devices_db:
|
||||
devices_db[user_id] = []
|
||||
|
||||
existing = [d for d in devices_db[user_id] if d.get("client_id") == data.client_id]
|
||||
if not existing:
|
||||
devices_db[user_id].append({
|
||||
"client_id": data.client_id,
|
||||
"platform": data.platform,
|
||||
"device_info": data.device_info,
|
||||
})
|
||||
|
||||
return {"success": True, "message": "Device registered"}
|
||||
|
||||
|
||||
@router.post("/send")
|
||||
async def send_push(
|
||||
message: PushMessage,
|
||||
authorization: str = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload.get("sub")
|
||||
|
||||
user_devices = devices_db.get(user_id, [])
|
||||
if not user_devices:
|
||||
return PushResponse(success=False, error="No devices registered")
|
||||
|
||||
# 实际项目中这里调用 uni-push/极光等API
|
||||
# 模拟返回成功
|
||||
message_id = f"msg_{user_id}_{int(payload.get('iat', 0))}"
|
||||
|
||||
print(f"Push message to user {user_id}: {message.title} - {message.content}")
|
||||
|
||||
return PushResponse(success=True, message_id=message_id)
|
||||
|
||||
|
||||
@router.post("/send-to-customer")
|
||||
async def send_to_customer(
|
||||
customer_id: str,
|
||||
title: str,
|
||||
content: str,
|
||||
payload: Optional[dict] = None,
|
||||
authorization: str = None,
|
||||
):
|
||||
"""
|
||||
针对特定客户的推送通知
|
||||
例如:客户沉默提醒、报价提醒等
|
||||
"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload_data = decode_token(authorization[7:])
|
||||
if not payload_data:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload_data.get("sub")
|
||||
|
||||
# 这里可以添加针对客户的特定逻辑
|
||||
notification = {
|
||||
"type": "customer_alert",
|
||||
"customer_id": customer_id,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"payload": payload or {}
|
||||
}
|
||||
|
||||
print(f"Customer notification for user {user_id}, customer {customer_id}: {title}")
|
||||
|
||||
return PushResponse(success=True, message_id=f"alert_{customer_id}")
|
||||
|
||||
|
||||
@router.get("/devices")
|
||||
async def list_devices(
|
||||
authorization: str = None,
|
||||
):
|
||||
"""列出用户已注册的设备"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
return {"error": "Unauthorized"}, 401
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
if not payload:
|
||||
return {"error": "Invalid token"}, 401
|
||||
|
||||
user_id = payload.get("sub")
|
||||
user_devices = devices_db.get(user_id, [])
|
||||
|
||||
return {
|
||||
"devices": user_devices,
|
||||
"count": len(user_devices)
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated, Optional
|
||||
from app.database import get_db
|
||||
from app.services.quotation import QuotationService
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_quotation(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = QuotationService(db)
|
||||
try:
|
||||
quotation = await service.create_quotation(user_id, data)
|
||||
return quotation
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_quotations(
|
||||
page: int = Query(1, ge=1),
|
||||
size: int = Query(20, ge=1, le=100),
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = QuotationService(db)
|
||||
return await service.list_quotations(user_id, page, size)
|
||||
|
||||
|
||||
@router.get("/{quotation_id}")
|
||||
async def get_quotation(
|
||||
quotation_id: str,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = QuotationService(db)
|
||||
quotation = await service.get_quotation(user_id, quotation_id)
|
||||
if not quotation:
|
||||
raise HTTPException(status_code=404, detail="Quotation not found")
|
||||
return quotation
|
||||
|
||||
|
||||
@router.patch("/{quotation_id}/status")
|
||||
async def update_quotation_status(
|
||||
quotation_id: str,
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
db: Annotated[AsyncSession, Depends(get_db)] = None,
|
||||
):
|
||||
service = QuotationService(db)
|
||||
quotation = await service.update_status(user_id, quotation_id, data.get("status", "draft"))
|
||||
if not quotation:
|
||||
raise HTTPException(status_code=404, detail="Quotation not found")
|
||||
return quotation
|
||||
@@ -0,0 +1,86 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from app.services.translation import TranslationService
|
||||
from app.core.security import decode_token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class TranslateRequest(BaseModel):
|
||||
text: str
|
||||
target_lang: str
|
||||
source_lang: Optional[str] = "auto"
|
||||
context: Optional[str] = None
|
||||
|
||||
|
||||
class ReplyRequest(BaseModel):
|
||||
inquiry: str
|
||||
tone: str = "professional"
|
||||
count: int = 3
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ExtractRequest(BaseModel):
|
||||
text: str
|
||||
extract_type: str = "auto"
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def translate_text(data: TranslateRequest, authorization: str = None):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
payload = decode_token(authorization[7:])
|
||||
user_id = payload.get("sub") if payload else None
|
||||
|
||||
service = TranslationService()
|
||||
result = await service.translate(
|
||||
text=data.text,
|
||||
target_lang=data.target_lang,
|
||||
source_lang=data.source_lang,
|
||||
context=data.context,
|
||||
user_id=user_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/reply")
|
||||
async def generate_reply(data: ReplyRequest, authorization: str = None):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
service = TranslationService()
|
||||
results = await service.generate_reply(
|
||||
inquiry=data.inquiry,
|
||||
context=data.context,
|
||||
tone=data.tone,
|
||||
count=data.count,
|
||||
)
|
||||
return {"suggestions": results, "inquiry": data.inquiry, "count": len(results)}
|
||||
|
||||
|
||||
@router.post("/extract")
|
||||
async def extract_info(data: ExtractRequest, authorization: str = None):
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
service = TranslationService()
|
||||
result = await service.extract_info(data.text, data.extract_type)
|
||||
return {"extracted": result, "type": data.extract_type}
|
||||
|
||||
|
||||
@router.post("/feedback")
|
||||
async def feedback(data: dict, authorization: str = None):
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing token")
|
||||
|
||||
from app.ai.trade_corpus import TradeCorpus
|
||||
corpus = TradeCorpus()
|
||||
|
||||
entry_id = data.get("entry_id")
|
||||
rating = data.get("rating")
|
||||
if entry_id and rating:
|
||||
await corpus.rate_entry(entry_id, rating)
|
||||
|
||||
return {"status": "ok"}
|
||||
@@ -0,0 +1,62 @@
|
||||
from fastapi import APIRouter, Request, HTTPException, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Annotated
|
||||
from app.database import get_db
|
||||
from app.services.whatsapp import WhatsAppService
|
||||
from app.services.customer import CustomerService
|
||||
from app.services.translation import TranslationService
|
||||
from app.core.security import decode_token
|
||||
from app.api.v1.deps import get_current_user_id
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/webhook")
|
||||
async def verify_webhook(
|
||||
hub_mode: str = None,
|
||||
hub_verify_token: str = None,
|
||||
hub_challenge: str = None,
|
||||
):
|
||||
svc = WhatsAppService()
|
||||
result = svc.verify_webhook(hub_mode, hub_verify_token, hub_challenge)
|
||||
if result:
|
||||
return int(result)
|
||||
raise HTTPException(status_code=403, detail="Verification failed")
|
||||
|
||||
|
||||
@router.post("/webhook")
|
||||
async def handle_webhook(request: Request, db: Annotated[AsyncSession, Depends(get_db)] = None):
|
||||
svc = WhatsAppService()
|
||||
body = await request.json()
|
||||
|
||||
msg_data = svc.parse_webhook(body)
|
||||
if not msg_data:
|
||||
return {"status": "ok"}
|
||||
|
||||
# TODO: Route to correct user based on WhatsApp number
|
||||
# For MVP, handle as generic incoming message
|
||||
return {"status": "ok", "message": "received"}
|
||||
|
||||
|
||||
@router.post("/send")
|
||||
async def send_message(
|
||||
data: dict,
|
||||
user_id: str = Depends(get_current_user_id),
|
||||
):
|
||||
text = data.get("text")
|
||||
to = data.get("to")
|
||||
if not text or not to:
|
||||
raise HTTPException(status_code=400, detail="text and to are required")
|
||||
|
||||
svc = WhatsAppService()
|
||||
sent = await svc.send_text(to, text)
|
||||
if not sent:
|
||||
raise HTTPException(status_code=500, detail="Failed to send WhatsApp message")
|
||||
|
||||
return {"status": "sent", "to": to}
|
||||
|
||||
|
||||
@router.get("/qr")
|
||||
async def get_qr():
|
||||
return {"message": "WhatsApp QR login not available via API. Use WhatsApp Cloud API instead."}
|
||||
@@ -0,0 +1,23 @@
|
||||
from celery import Celery
|
||||
from app.config import settings
|
||||
|
||||
celery_app = Celery(
|
||||
"tradmate",
|
||||
broker=settings.CELERY_BROKER_URL,
|
||||
backend=settings.CELERY_RESULT_BACKEND,
|
||||
include=[
|
||||
"app.workers.tasks",
|
||||
],
|
||||
)
|
||||
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
timezone="UTC",
|
||||
enable_utc=True,
|
||||
task_track_started=True,
|
||||
task_time_limit=300,
|
||||
worker_prefetch_multiplier=4,
|
||||
worker_max_tasks_per_child=1000,
|
||||
)
|
||||
@@ -0,0 +1,73 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
ENV_FILE = PROJECT_ROOT / ".env"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = {"env_file": str(ENV_FILE), "extra": "ignore"}
|
||||
|
||||
APP_NAME: str = "TradeMate"
|
||||
|
||||
SECRET_KEY: str
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
|
||||
DATABASE_URL: str
|
||||
DB_ECHO: bool = False
|
||||
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/1"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2"
|
||||
|
||||
OPENAI_API_KEY: Optional[str] = None
|
||||
ANTHROPIC_API_KEY: Optional[str] = None
|
||||
DEEPL_API_KEY: Optional[str] = None
|
||||
|
||||
LOCAL_MODEL_ENABLED: bool = False
|
||||
LOCAL_MODEL_URL: str = "http://localhost:8001"
|
||||
|
||||
WHATSAPP_API_TOKEN: Optional[str] = None
|
||||
WHATSAPP_PHONE_NUMBER_ID: Optional[str] = None
|
||||
WHATSAPP_WEBHOOK_VERIFY_TOKEN: Optional[str] = None
|
||||
|
||||
WECHAT_APP_ID: Optional[str] = None
|
||||
WECHAT_APP_SECRET: Optional[str] = None
|
||||
|
||||
EXCHANGE_RATE_API_KEY: Optional[str] = None
|
||||
|
||||
UPLOAD_DIR: str = "./uploads"
|
||||
MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024
|
||||
|
||||
FRONTEND_URL: str = "http://localhost:3000"
|
||||
BACKEND_URL: str = "http://localhost:8000"
|
||||
|
||||
AI_ROUTING: dict = {
|
||||
"translate": {"primary": "deepl", "fallback": ["openai", "local"]},
|
||||
"reply": {"primary": "openai", "fallback": ["anthropic", "local"]},
|
||||
"marketing": {"primary": "anthropic", "fallback": ["openai", "local"]},
|
||||
"extract": {"primary": "openai", "fallback": ["anthropic"]},
|
||||
"quotation": {"primary": "openai", "fallback": ["anthropic"]},
|
||||
}
|
||||
|
||||
FREE_DAILY_TRANSLATE_CHARS: int = 5000
|
||||
FREE_DAILY_REPLIES: int = 20
|
||||
FREE_DAILY_MARKETING: int = 5
|
||||
FREE_MAX_CUSTOMERS: int = 5
|
||||
FREE_MAX_PRODUCTS: int = 1
|
||||
FREE_DAILY_QUOTATIONS: int = 3
|
||||
|
||||
PRO_DAILY_TRANSLATE_CHARS: int = 50000
|
||||
PRO_DAILY_REPLIES: int = 200
|
||||
PRO_DAILY_MARKETING: int = 50
|
||||
PRO_MAX_CUSTOMERS: int = 100
|
||||
PRO_MAX_PRODUCTS: int = 20
|
||||
PRO_DAILY_QUOTATIONS: int = 30
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -0,0 +1,58 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class TradeMateException(Exception):
|
||||
def __init__(self, code: int, message: str, detail: str = None):
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class NotFoundError(TradeMateException):
|
||||
def __init__(self, resource: str = "Resource"):
|
||||
super().__init__(404, f"{resource} not found")
|
||||
|
||||
|
||||
class UnauthorizedError(TradeMateException):
|
||||
def __init__(self, detail: str = "Authentication required"):
|
||||
super().__init__(401, "Unauthorized", detail)
|
||||
|
||||
|
||||
class ForbiddenError(TradeMateException):
|
||||
def __init__(self, detail: str = "Insufficient permissions"):
|
||||
super().__init__(403, "Forbidden", detail)
|
||||
|
||||
|
||||
class QuotaExceededError(TradeMateException):
|
||||
def __init__(self, feature: str):
|
||||
super().__init__(429, "Quota exceeded", f"Daily limit reached for {feature}. Upgrade to Pro for more.")
|
||||
|
||||
|
||||
class TierRestrictionError(TradeMateException):
|
||||
def __init__(self, feature: str, required_tier: str):
|
||||
super().__init__(
|
||||
402,
|
||||
"Upgrade required",
|
||||
f"{feature} requires {required_tier} plan",
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI):
|
||||
@app.exception_handler(TradeMateException)
|
||||
async def handle_tradmate_exception(request: Request, exc: TradeMateException):
|
||||
return JSONResponse(
|
||||
status_code=exc.code,
|
||||
content={
|
||||
"error": exc.message,
|
||||
"detail": exc.detail,
|
||||
"code": exc.code,
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def handle_generic_exception(request: Request, exc: Exception):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "Internal server error", "detail": str(exc) if app.debug else "An unexpected error occurred"},
|
||||
)
|
||||
@@ -0,0 +1,118 @@
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.config import settings
|
||||
from app.core.security import decode_token
|
||||
import redis.asyncio as aioredis
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_user_tier_from_token(request: Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
return "anonymous"
|
||||
payload = decode_token(auth[7:])
|
||||
if not payload:
|
||||
return "anonymous"
|
||||
request.state.user_id = payload.get("sub")
|
||||
request.state.user_tier = payload.get("tier", "free")
|
||||
return request.state.user_tier
|
||||
|
||||
|
||||
class TierMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.url.path.startswith("/api/v1"):
|
||||
tier = get_user_tier_from_token(request)
|
||||
tier_config = {
|
||||
"free": {
|
||||
"max_products": settings.FREE_MAX_PRODUCTS,
|
||||
"max_customers": settings.FREE_MAX_CUSTOMERS,
|
||||
},
|
||||
"pro": {
|
||||
"max_products": settings.PRO_MAX_PRODUCTS,
|
||||
"max_customers": settings.PRO_MAX_CUSTOMERS,
|
||||
},
|
||||
"enterprise": {
|
||||
"max_products": 9999,
|
||||
"max_customers": 99999,
|
||||
},
|
||||
}
|
||||
request.state.tier_config = tier_config.get(tier, tier_config["free"])
|
||||
else:
|
||||
request.state.user_id = None
|
||||
request.state.user_tier = "anonymous"
|
||||
request.state.tier_config = {}
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
class QuotaMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.url.path.startswith("/api/v1"):
|
||||
return await call_next(request)
|
||||
|
||||
if request.state.user_tier in ("anonymous",):
|
||||
return await call_next(request)
|
||||
|
||||
user_id = request.state.user_id
|
||||
tier = request.state.user_tier
|
||||
|
||||
if tier == "enterprise":
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
if method == "GET":
|
||||
return await call_next(request)
|
||||
|
||||
quota_map = {
|
||||
"/api/v1/translate": {
|
||||
"free": settings.FREE_DAILY_TRANSLATE_CHARS,
|
||||
"pro": settings.PRO_DAILY_TRANSLATE_CHARS,
|
||||
},
|
||||
"/api/v1/translate/reply": {
|
||||
"free": settings.FREE_DAILY_REPLIES,
|
||||
"pro": settings.PRO_DAILY_REPLIES,
|
||||
},
|
||||
"/api/v1/marketing": {
|
||||
"free": settings.FREE_DAILY_MARKETING,
|
||||
"pro": settings.PRO_DAILY_MARKETING,
|
||||
},
|
||||
"/api/v1/quotations": {
|
||||
"free": settings.FREE_DAILY_QUOTATIONS,
|
||||
"pro": settings.PRO_DAILY_QUOTATIONS,
|
||||
},
|
||||
}
|
||||
|
||||
matched_key = None
|
||||
for prefix, limits in quota_map.items():
|
||||
if path.startswith(prefix):
|
||||
matched_key = prefix
|
||||
break
|
||||
|
||||
if not matched_key:
|
||||
return await call_next(request)
|
||||
|
||||
limit = quota_map[matched_key].get(tier)
|
||||
if limit is None:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
r = aioredis.from_url(settings.REDIS_URL)
|
||||
key = f"quota:{user_id}:{matched_key}:{datetime.utcnow().strftime('%Y%m%d')}"
|
||||
current = await r.incr(key)
|
||||
await r.expire(key, 86400)
|
||||
if current > limit:
|
||||
from app.core.exceptions import QuotaExceededError
|
||||
raise QuotaExceededError(matched_key)
|
||||
request.state.quota_remaining = limit - current
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Quota check failed: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
@@ -0,0 +1,38 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from app.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
|
||||
return pwd_context.verify(plain, hashed)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + (
|
||||
expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
)
|
||||
to_encode.update({"exp": expire, "type": "access"})
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
|
||||
try:
|
||||
return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
except JWTError:
|
||||
return None
|
||||
@@ -0,0 +1,33 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from app.config import settings
|
||||
|
||||
async_engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
echo=settings.DB_ECHO,
|
||||
pool_size=20,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
AsyncSessionLocal = sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
@@ -0,0 +1,53 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.config import settings
|
||||
from app.core.exceptions import register_exception_handlers
|
||||
from app.core.middleware import TierMiddleware, QuotaMiddleware
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[settings.FRONTEND_URL, "*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.add_middleware(TierMiddleware)
|
||||
app.add_middleware(QuotaMiddleware)
|
||||
|
||||
register_exception_handlers(app)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "app": settings.APP_NAME, "version": "1.0.0"}
|
||||
|
||||
|
||||
from app.api.v1 import auth, marketing, translate, customer, quotation, whatsapp, product, exchange, push
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
|
||||
app.include_router(marketing.router, prefix="/api/v1/marketing", tags=["marketing"])
|
||||
app.include_router(translate.router, prefix="/api/v1/translate", tags=["translate"])
|
||||
app.include_router(customer.router, prefix="/api/v1/customers", tags=["customers"])
|
||||
app.include_router(quotation.router, prefix="/api/v1/quotations", tags=["quotations"])
|
||||
app.include_router(whatsapp.router, prefix="/api/v1/whatsapp", tags=["whatsapp"])
|
||||
app.include_router(product.router, prefix="/api/v1/products", tags=["products"])
|
||||
app.include_router(exchange.router, prefix="/api/v1/exchange", tags=["exchange"])
|
||||
app.include_router(push.router, prefix="/api/v1/push", tags=["push"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
@@ -0,0 +1,11 @@
|
||||
from .user import User, Product
|
||||
from .customer import Customer, Conversation, Message
|
||||
from .quotation import Quotation, QuotationItem
|
||||
from .corpus import CorpusEntry
|
||||
|
||||
__all__ = [
|
||||
"User", "Product",
|
||||
"Customer", "Conversation", "Message",
|
||||
"Quotation", "QuotationItem",
|
||||
"CorpusEntry",
|
||||
]
|
||||
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Text, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class CorpusEntry(Base):
|
||||
__tablename__ = "corpus_entries"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
source_text = Column(Text, nullable=False)
|
||||
target_text = Column(Text, nullable=False)
|
||||
source_lang = Column(String(20))
|
||||
target_lang = Column(String(20))
|
||||
task_type = Column(String(50), nullable=False)
|
||||
domain = Column(String(100), default="general")
|
||||
provider_used = Column(String(50))
|
||||
quality_score = Column(Float, default=0.5)
|
||||
user_edited = Column(Boolean, default=False)
|
||||
user_rating = Column(Integer)
|
||||
usage_count = Column(Integer, default=0)
|
||||
embedding = Column(Vector(768))
|
||||
metadata = Column(JSONB, default={})
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -0,0 +1,72 @@
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Customer(Base):
|
||||
__tablename__ = "customers"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
company = Column(String(255))
|
||||
country = Column(String(100))
|
||||
phone = Column(String(50))
|
||||
email = Column(String(255))
|
||||
whatsapp_id = Column(String(255))
|
||||
source = Column(String(100))
|
||||
tags = Column(JSONB, default=[])
|
||||
notes = Column(Text)
|
||||
preference = Column(JSONB, default={})
|
||||
status = Column(String(50), default="lead")
|
||||
last_contact_at = Column(DateTime)
|
||||
silence_started_at = Column(DateTime)
|
||||
next_followup_at = Column(DateTime)
|
||||
estimated_value = Column(String(50))
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
user = relationship("User", back_populates="customers")
|
||||
conversations = relationship("Conversation", back_populates="customer", cascade="all, delete-orphan")
|
||||
quotations = relationship("Quotation", back_populates="customer")
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
customer_id = Column(UUID(as_uuid=True), ForeignKey("customers.id"), nullable=False, index=True)
|
||||
channel = Column(String(50), default="whatsapp")
|
||||
topic = Column(String(255))
|
||||
status = Column(String(50), default="active")
|
||||
message_count = Column(Integer, default=0)
|
||||
last_message_at = Column(DateTime)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
user = relationship("User", back_populates="conversations")
|
||||
customer = relationship("Customer", back_populates="conversations")
|
||||
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan", order_by="Message.created_at")
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
conversation_id = Column(UUID(as_uuid=True), ForeignKey("conversations.id"), nullable=False, index=True)
|
||||
direction = Column(String(20), nullable=False)
|
||||
content = Column(Text, nullable=False)
|
||||
content_translated = Column(Text)
|
||||
content_type = Column(String(50), default="text")
|
||||
ai_suggestions = Column(JSONB)
|
||||
selected_suggestion = Column(Integer)
|
||||
user_edited = Column(Text)
|
||||
status = Column(String(50), default="sent")
|
||||
metadata = Column(JSONB, default={})
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
conversation = relationship("Conversation", back_populates="messages")
|
||||
@@ -0,0 +1,50 @@
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text, ForeignKey, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class Quotation(Base):
|
||||
__tablename__ = "quotations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
customer_id = Column(UUID(as_uuid=True), ForeignKey("customers.id"), nullable=False)
|
||||
title = Column(String(255))
|
||||
status = Column(String(50), default="draft")
|
||||
currency = Column(String(10), default="USD")
|
||||
exchange_rate = Column(Float)
|
||||
payment_terms = Column(String(255))
|
||||
delivery_terms = Column(String(255))
|
||||
lead_time = Column(String(100))
|
||||
valid_until = Column(String(100))
|
||||
subtotal = Column(Float)
|
||||
discount = Column(Float, default=0)
|
||||
shipping = Column(Float, default=0)
|
||||
total = Column(Float)
|
||||
notes = Column(Text)
|
||||
pdf_url = Column(Text)
|
||||
sent_at = Column(DateTime)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
user = relationship("User", back_populates="quotations")
|
||||
customer = relationship("Customer", back_populates="quotations")
|
||||
items = relationship("QuotationItem", back_populates="quotation", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class QuotationItem(Base):
|
||||
__tablename__ = "quotation_items"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
quotation_id = Column(UUID(as_uuid=True), ForeignKey("quotations.id"), nullable=False, index=True)
|
||||
product_name = Column(String(255), nullable=False)
|
||||
description = Column(Text)
|
||||
quantity = Column(Integer, nullable=False)
|
||||
unit_price = Column(Float, nullable=False)
|
||||
total_price = Column(Float)
|
||||
unit = Column(String(50), default="pcs")
|
||||
|
||||
quotation = relationship("Quotation", back_populates="items")
|
||||
@@ -0,0 +1,54 @@
|
||||
from sqlalchemy import Column, String, Boolean, Integer, DateTime, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.database import Base
|
||||
import uuid
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
wechat_openid = Column(String(255), unique=True, index=True)
|
||||
phone = Column(String(20), unique=True, index=True)
|
||||
username = Column(String(100))
|
||||
password_hash = Column(String(255))
|
||||
tier = Column(String(50), default="free")
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
settings = Column(JSONB, default={
|
||||
"preferred_translate_provider": "auto",
|
||||
"reply_tone": "professional",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"languages": ["zh", "en"],
|
||||
})
|
||||
|
||||
products = relationship("Product", back_populates="user", cascade="all, delete-orphan")
|
||||
customers = relationship("Customer", back_populates="user", cascade="all, delete-orphan")
|
||||
conversations = relationship("Conversation", back_populates="user", cascade="all, delete-orphan")
|
||||
quotations = relationship("Quotation", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = "products"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
name_en = Column(String(255))
|
||||
description = Column(Text)
|
||||
description_en = Column(Text)
|
||||
category = Column(String(100))
|
||||
price = Column(String(50))
|
||||
price_unit = Column(String(20), default="USD")
|
||||
moq = Column(String(50))
|
||||
keywords = Column(JSONB, default=[])
|
||||
specifications = Column(JSONB, default={})
|
||||
images = Column(JSONB, default=[])
|
||||
is_active = Column(Boolean, default=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
user = relationship("User", back_populates="products")
|
||||
@@ -0,0 +1,204 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from app.models.customer import Customer, Conversation, Message
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomerService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def list_customers(self, user_id: str, status: Optional[str] = None, page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
query = select(Customer).where(Customer.user_id == user_id)
|
||||
count_query = select(func.count()).select_from(Customer).where(Customer.user_id == user_id)
|
||||
|
||||
if status:
|
||||
query = query.where(Customer.status == status)
|
||||
count_query = count_query.where(Customer.status == status)
|
||||
|
||||
query = query.order_by(Customer.updated_at.desc()).offset((page - 1) * size).limit(size)
|
||||
|
||||
total = await self.db.execute(count_query)
|
||||
result = await self.db.execute(query)
|
||||
customers = result.scalars().all()
|
||||
|
||||
return {
|
||||
"items": [self._to_dict(c) for c in customers],
|
||||
"total": total.scalar(),
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
async def get_customer(self, user_id: str, customer_id: str) -> Optional[Dict]:
|
||||
result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.id == customer_id, Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
c = result.scalar_one_or_none()
|
||||
return self._to_dict(c) if c else None
|
||||
|
||||
async def create_customer(self, user_id: str, data: Dict[str, Any]) -> Dict:
|
||||
c = Customer(user_id=user_id, **data)
|
||||
self.db.add(c)
|
||||
await self.db.flush()
|
||||
return self._to_dict(c)
|
||||
|
||||
async def update_customer(self, user_id: str, customer_id: str, data: Dict[str, Any]) -> Optional[Dict]:
|
||||
result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.id == customer_id, Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
c = result.scalar_one_or_none()
|
||||
if not c:
|
||||
return None
|
||||
for k, v in data.items():
|
||||
if hasattr(c, k):
|
||||
setattr(c, k, v)
|
||||
await self.db.flush()
|
||||
return self._to_dict(c)
|
||||
|
||||
async def delete_customer(self, user_id: str, customer_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.id == customer_id, Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
c = result.scalar_one_or_none()
|
||||
if not c:
|
||||
return False
|
||||
await self.db.delete(c)
|
||||
return True
|
||||
|
||||
async def get_silent_customers(self, user_id: str, days: int = 3) -> List[Dict]:
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
result = await self.db.execute(
|
||||
select(Customer)
|
||||
.where(
|
||||
and_(
|
||||
Customer.user_id == user_id,
|
||||
Customer.status.in_(["lead", "negotiating"]),
|
||||
Customer.last_contact_at.isnot(None),
|
||||
Customer.last_contact_at < cutoff,
|
||||
)
|
||||
)
|
||||
.order_by(Customer.last_contact_at.asc())
|
||||
)
|
||||
return [self._to_dict(c) for c in result.scalars().all()]
|
||||
|
||||
async def record_contact(self, user_id: str, customer_id: str):
|
||||
now = datetime.utcnow()
|
||||
result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.id == customer_id, Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
c = result.scalar_one_or_none()
|
||||
if c:
|
||||
c.last_contact_at = now
|
||||
c.silence_started_at = None
|
||||
c.next_followup_at = now + timedelta(days=3)
|
||||
await self.db.flush()
|
||||
|
||||
async def get_conversation(self, user_id: str, customer_id: str, page: int = 1, size: int = 50) -> Dict[str, Any]:
|
||||
conv_query = select(Conversation).where(
|
||||
and_(Conversation.user_id == user_id, Conversation.customer_id == customer_id)
|
||||
).order_by(Conversation.created_at.desc()).limit(1)
|
||||
|
||||
conv_result = await self.db.execute(conv_query)
|
||||
conv = conv_result.scalar_one_or_none()
|
||||
if not conv:
|
||||
return {"messages": [], "total": 0, "conversation_id": None}
|
||||
|
||||
msg_query = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conv.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.offset((page - 1) * size)
|
||||
.limit(size)
|
||||
)
|
||||
msg_result = await self.db.execute(msg_query)
|
||||
messages = msg_result.scalars().all()
|
||||
|
||||
return {
|
||||
"conversation_id": str(conv.id),
|
||||
"messages": [
|
||||
{
|
||||
"id": str(m.id),
|
||||
"direction": m.direction,
|
||||
"content": m.content,
|
||||
"content_translated": m.content_translated,
|
||||
"ai_suggestions": m.ai_suggestions,
|
||||
"selected_suggestion": m.selected_suggestion,
|
||||
"created_at": m.created_at.isoformat() if m.created_at else None,
|
||||
}
|
||||
for m in messages
|
||||
],
|
||||
"total": conv.message_count,
|
||||
}
|
||||
|
||||
async def save_message(
|
||||
self, user_id: str, customer_id: str, direction: str, content: str,
|
||||
translation: Optional[str] = None, suggestions: Optional[List] = None,
|
||||
) -> Dict:
|
||||
conv_result = await self.db.execute(
|
||||
select(Conversation).where(
|
||||
and_(Conversation.user_id == user_id, Conversation.customer_id == customer_id)
|
||||
).order_by(Conversation.created_at.desc()).limit(1)
|
||||
)
|
||||
conv = conv_result.scalar_one_or_none()
|
||||
|
||||
if not conv:
|
||||
conv = Conversation(
|
||||
user_id=user_id,
|
||||
customer_id=customer_id,
|
||||
channel="whatsapp",
|
||||
status="active",
|
||||
)
|
||||
self.db.add(conv)
|
||||
await self.db.flush()
|
||||
|
||||
msg = Message(
|
||||
conversation_id=conv.id,
|
||||
direction=direction,
|
||||
content=content,
|
||||
content_translated=translation,
|
||||
ai_suggestions=suggestions,
|
||||
)
|
||||
self.db.add(msg)
|
||||
conv.message_count = (conv.message_count or 0) + 1
|
||||
conv.last_message_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
|
||||
await self.record_contact(user_id, customer_id)
|
||||
|
||||
return {
|
||||
"message_id": str(msg.id),
|
||||
"conversation_id": str(conv.id),
|
||||
"direction": direction,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def _to_dict(self, c: Customer) -> Dict:
|
||||
if not c:
|
||||
return {}
|
||||
return {
|
||||
"id": str(c.id),
|
||||
"name": c.name,
|
||||
"company": c.company,
|
||||
"country": c.country,
|
||||
"phone": c.phone,
|
||||
"email": c.email,
|
||||
"whatsapp_id": c.whatsapp_id,
|
||||
"source": c.source,
|
||||
"tags": c.tags,
|
||||
"status": c.status,
|
||||
"last_contact_at": c.last_contact_at.isoformat() if c.last_contact_at else None,
|
||||
"silence_days": (datetime.utcnow() - c.last_contact_at).days if c.last_contact_at else 0,
|
||||
"created_at": c.created_at.isoformat() if c.created_at else None,
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.ai.router import get_ai_router
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketingService:
|
||||
def __init__(self):
|
||||
self.ai = get_ai_router()
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
product_info: Dict[str, Any],
|
||||
target: str,
|
||||
style: str = "professional",
|
||||
language: str = "en",
|
||||
count: int = 3,
|
||||
) -> List[Dict[str, Any]]:
|
||||
results = []
|
||||
styles = self._get_style_variants(style, count)
|
||||
|
||||
for s in styles:
|
||||
try:
|
||||
result = await self.ai.marketing(product_info, target, s, language)
|
||||
results.append({
|
||||
"content": result.get("content", ""),
|
||||
"style": s,
|
||||
"provider": result.get("provider_used", "unknown"),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Marketing generation failed for style '{s}': {e}")
|
||||
results.append({"content": "", "style": s, "error": str(e)})
|
||||
|
||||
return results
|
||||
|
||||
async def generate_keywords(
|
||||
self, product_info: Dict[str, Any], language: str = "en", count: int = 10
|
||||
) -> List[str]:
|
||||
try:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keywords": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
}
|
||||
},
|
||||
}
|
||||
text = f"Product: {product_info.get('name', '')}. {product_info.get('description', '')}"
|
||||
result = await self.ai.extract(text, schema)
|
||||
keywords = result.get("data", {}).get("keywords", [])
|
||||
return keywords[:count]
|
||||
except Exception as e:
|
||||
logger.warning(f"Keyword generation failed: {e}")
|
||||
return []
|
||||
|
||||
def _get_style_variants(self, base_style: str, count: int) -> List[str]:
|
||||
all_styles = ["professional", "friendly", "urgent", "benefit_focused", "storytelling"]
|
||||
if base_style in all_styles:
|
||||
all_styles.remove(base_style)
|
||||
all_styles.insert(0, base_style)
|
||||
return all_styles[:count]
|
||||
|
||||
async def analyze_competitors(
|
||||
self, product_info: Dict[str, Any], market: str = "US"
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
text = f"Product: {product_info.get('name', '')} in {market} market. Category: {product_info.get('category', '')}. Description: {product_info.get('description', '')}"
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"price_range": {"type": "string"},
|
||||
"key_selling_points": {"type": "array", "items": {"type": "string"}},
|
||||
"common_keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"market_trends": {"type": "string"},
|
||||
"suggestions": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
result = await self.ai.extract(text, schema)
|
||||
return result.get("data", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"Competitor analysis failed: {e}")
|
||||
return {}
|
||||
@@ -0,0 +1,100 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from app.models.user import Product
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProductService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def list_products(self, user_id: str, category: Optional[str] = None, page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
query = select(Product).where(Product.user_id == user_id, Product.is_active == True)
|
||||
count_query = select(func.count()).select_from(Product).where(Product.user_id == user_id, Product.is_active == True)
|
||||
|
||||
if category:
|
||||
query = query.where(Product.category == category)
|
||||
count_query = count_query.where(Product.category == category)
|
||||
|
||||
query = query.order_by(Product.updated_at.desc()).offset((page - 1) * size).limit(size)
|
||||
|
||||
total = await self.db.execute(count_query)
|
||||
result = await self.db.execute(query)
|
||||
products = result.scalars().all()
|
||||
|
||||
return {
|
||||
"items": [self._to_dict(p) for p in products],
|
||||
"total": total.scalar(),
|
||||
"page": page,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
async def get_product(self, user_id: str, product_id: str) -> Optional[Dict]:
|
||||
result = await self.db.execute(
|
||||
select(Product).where(
|
||||
and_(Product.id == product_id, Product.user_id == user_id)
|
||||
)
|
||||
)
|
||||
p = result.scalar_one_or_none()
|
||||
return self._to_dict(p) if p else None
|
||||
|
||||
async def create_product(self, user_id: str, data: Dict[str, Any]) -> Dict:
|
||||
p = Product(user_id=user_id, **data)
|
||||
self.db.add(p)
|
||||
await self.db.flush()
|
||||
return self._to_dict(p)
|
||||
|
||||
async def update_product(self, user_id: str, product_id: str, data: Dict[str, Any]) -> Optional[Dict]:
|
||||
result = await self.db.execute(
|
||||
select(Product).where(
|
||||
and_(Product.id == product_id, Product.user_id == user_id)
|
||||
)
|
||||
)
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
return None
|
||||
|
||||
for k, v in data.items():
|
||||
if v is not None and hasattr(p, k):
|
||||
setattr(p, k, v)
|
||||
|
||||
await self.db.flush()
|
||||
return self._to_dict(p)
|
||||
|
||||
async def delete_product(self, user_id: str, product_id: str) -> bool:
|
||||
result = await self.db.execute(
|
||||
select(Product).where(
|
||||
and_(Product.id == product_id, Product.user_id == user_id)
|
||||
)
|
||||
)
|
||||
p = result.scalar_one_or_none()
|
||||
if not p:
|
||||
return False
|
||||
p.is_active = False
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
def _to_dict(self, p: Product) -> Dict:
|
||||
if not p:
|
||||
return {}
|
||||
return {
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"name_en": p.name_en,
|
||||
"description": p.description,
|
||||
"description_en": p.description_en,
|
||||
"category": p.category,
|
||||
"price": p.price,
|
||||
"price_unit": p.price_unit,
|
||||
"moq": p.moq,
|
||||
"keywords": p.keywords or [],
|
||||
"specifications": p.specifications or {},
|
||||
"images": p.images or [],
|
||||
"is_active": p.is_active,
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
"updated_at": p.updated_at.isoformat() if p.updated_at else None,
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
from app.models.quotation import Quotation, QuotationItem
|
||||
from app.models.customer import Customer
|
||||
from app.models.user import Product
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuotationService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def create_quotation(self, user_id: str, data: Dict[str, Any]) -> Dict:
|
||||
items_data = data.pop("items", [])
|
||||
|
||||
if data.get("customer_id"):
|
||||
cust_result = await self.db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.id == data["customer_id"], Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
if not cust_result.scalar_one_or_none():
|
||||
raise ValueError("Customer not found")
|
||||
|
||||
q = Quotation(user_id=user_id, **data)
|
||||
self.db.add(q)
|
||||
await self.db.flush()
|
||||
|
||||
total = 0
|
||||
for item_data in items_data:
|
||||
item_total = item_data.get("quantity", 0) * item_data.get("unit_price", 0)
|
||||
item = QuotationItem(
|
||||
quotation_id=q.id,
|
||||
product_name=item_data["product_name"],
|
||||
description=item_data.get("description"),
|
||||
quantity=item_data["quantity"],
|
||||
unit_price=item_data["unit_price"],
|
||||
total_price=item_total,
|
||||
unit=item_data.get("unit", "pcs"),
|
||||
)
|
||||
self.db.add(item)
|
||||
total += item_total
|
||||
|
||||
q.subtotal = total
|
||||
q.total = total - (data.get("discount", 0)) + (data.get("shipping", 0))
|
||||
await self.db.flush()
|
||||
|
||||
return await self._to_dict(q)
|
||||
|
||||
async def get_quotation(self, user_id: str, quotation_id: str) -> Optional[Dict]:
|
||||
result = await self.db.execute(
|
||||
select(Quotation).where(
|
||||
and_(Quotation.id == quotation_id, Quotation.user_id == user_id)
|
||||
)
|
||||
)
|
||||
q = result.scalar_one_or_none()
|
||||
return await self._to_dict(q) if q else None
|
||||
|
||||
async def list_quotations(self, user_id: str, page: int = 1, size: int = 20) -> Dict[str, Any]:
|
||||
from sqlalchemy import func
|
||||
query = select(Quotation).where(Quotation.user_id == user_id).order_by(Quotation.created_at.desc()).offset((page - 1) * size).limit(size)
|
||||
count_query = select(func.count()).select_from(Quotation).where(Quotation.user_id == user_id)
|
||||
|
||||
total = await self.db.execute(count_query)
|
||||
result = await self.db.execute(query)
|
||||
quotations = result.scalars().all()
|
||||
|
||||
items = []
|
||||
for q in quotations:
|
||||
items.append(await self._to_dict(q))
|
||||
|
||||
return {"items": items, "total": total.scalar(), "page": page, "size": size}
|
||||
|
||||
async def update_status(self, user_id: str, quotation_id: str, status: str) -> Optional[Dict]:
|
||||
result = await self.db.execute(
|
||||
select(Quotation).where(
|
||||
and_(Quotation.id == quotation_id, Quotation.user_id == user_id)
|
||||
)
|
||||
)
|
||||
q = result.scalar_one_or_none()
|
||||
if not q:
|
||||
return None
|
||||
q.status = status
|
||||
if status == "sent":
|
||||
q.sent_at = datetime.utcnow()
|
||||
await self.db.flush()
|
||||
return await self._to_dict(q)
|
||||
|
||||
async def generate_quotation_text(self, q: Quotation) -> str:
|
||||
items_result = await self.db.execute(
|
||||
select(QuotationItem).where(QuotationItem.quotation_id == q.id)
|
||||
)
|
||||
items = items_result.scalars().all()
|
||||
|
||||
lines = [f"QUOTATION", f"", f"Date: {datetime.utcnow().strftime('%Y-%m-%d')}"]
|
||||
if q.valid_until:
|
||||
lines.append(f"Valid until: {q.valid_until}")
|
||||
lines.append(f"")
|
||||
lines.append(f"{'Item':<30} {'Qty':<10} {'Unit Price':<15} {'Total':<15}")
|
||||
lines.append("-" * 70)
|
||||
|
||||
for item in items:
|
||||
lines.append(f"{item.product_name:<30} {item.quantity:<10} ${item.unit_price:<12.2f} ${item.total_price:<10.2f}")
|
||||
|
||||
lines.append("-" * 70)
|
||||
if q.subtotal:
|
||||
lines.append(f"{'Subtotal':>55} ${q.subtotal:<10.2f}")
|
||||
if q.discount:
|
||||
lines.append(f"{'Discount':>55} -${q.discount:<9.2f}")
|
||||
if q.shipping:
|
||||
lines.append(f"{'Shipping':>55} ${q.shipping:<10.2f}")
|
||||
lines.append(f"{'TOTAL':>55} ${q.total or q.subtotal or 0:<10.2f}")
|
||||
lines.append(f"")
|
||||
if q.payment_terms:
|
||||
lines.append(f"Payment: {q.payment_terms}")
|
||||
if q.delivery_terms:
|
||||
lines.append(f"Delivery: {q.delivery_terms}")
|
||||
if q.lead_time:
|
||||
lines.append(f"Lead time: {q.lead_time}")
|
||||
if q.notes:
|
||||
lines.append(f"")
|
||||
lines.append(f"Notes: {q.notes}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _to_dict(self, q: Quotation) -> Dict:
|
||||
items_result = await self.db.execute(
|
||||
select(QuotationItem).where(QuotationItem.quotation_id == q.id)
|
||||
)
|
||||
items = items_result.scalars().all()
|
||||
|
||||
return {
|
||||
"id": str(q.id),
|
||||
"customer_id": str(q.customer_id) if q.customer_id else None,
|
||||
"title": q.title,
|
||||
"status": q.status,
|
||||
"currency": q.currency,
|
||||
"exchange_rate": q.exchange_rate,
|
||||
"payment_terms": q.payment_terms,
|
||||
"delivery_terms": q.delivery_terms,
|
||||
"lead_time": q.lead_time,
|
||||
"valid_until": q.valid_until,
|
||||
"subtotal": q.subtotal,
|
||||
"discount": q.discount,
|
||||
"shipping": q.shipping,
|
||||
"total": q.total,
|
||||
"notes": q.notes,
|
||||
"items": [
|
||||
{
|
||||
"product_name": i.product_name,
|
||||
"description": i.description,
|
||||
"quantity": i.quantity,
|
||||
"unit_price": i.unit_price,
|
||||
"total_price": i.total_price,
|
||||
"unit": i.unit,
|
||||
}
|
||||
for i in items
|
||||
],
|
||||
"text": await self.generate_quotation_text(q),
|
||||
"sent_at": q.sent_at.isoformat() if q.sent_at else None,
|
||||
"created_at": q.created_at.isoformat() if q.created_at else None,
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.ai.router import get_ai_router
|
||||
from app.ai.trade_corpus import TradeCorpus
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationService:
|
||||
def __init__(self):
|
||||
self.ai = get_ai_router()
|
||||
self.corpus = TradeCorpus()
|
||||
|
||||
async def translate(
|
||||
self, text: str, target_lang: str, source_lang: Optional[str] = None,
|
||||
context: Optional[str] = None, user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
similar = await self.corpus.find_similar(text, "translate")
|
||||
if similar:
|
||||
best = similar[0]
|
||||
if len(best["source"]) > 20 and self._similarity_ratio(text, best["source"]) > 0.85:
|
||||
return {
|
||||
"translated_text": best["target"],
|
||||
"source_lang": source_lang or "auto",
|
||||
"provider_used": "corpus_cache",
|
||||
"from_cache": True,
|
||||
}
|
||||
|
||||
result = await self.ai.translate(text, target_lang, source_lang, context)
|
||||
translated = result.get("translated_text", "")
|
||||
provider = result.get("provider_used", "unknown")
|
||||
|
||||
await self.corpus.record(
|
||||
source_text=text,
|
||||
target_text=translated,
|
||||
task_type="translate",
|
||||
provider=provider,
|
||||
source_lang=source_lang,
|
||||
target_lang=target_lang,
|
||||
metadata={"user_id": user_id} if user_id else None,
|
||||
)
|
||||
|
||||
result["source_lang"] = result.get("detected_source_lang", source_lang or "auto")
|
||||
result["from_cache"] = False
|
||||
return result
|
||||
|
||||
async def generate_reply(
|
||||
self, inquiry: str, context: Optional[Dict[str, Any]] = None,
|
||||
tone: str = "professional", count: int = 3,
|
||||
) -> List[Dict[str, Any]]:
|
||||
similar = await self.corpus.find_similar(inquiry, "reply")
|
||||
if similar and count > 1:
|
||||
pass
|
||||
|
||||
results = []
|
||||
tones = self._get_tones(tone, count)
|
||||
|
||||
for t in tones:
|
||||
try:
|
||||
result = await self.ai.reply(inquiry, context, t)
|
||||
results.append({
|
||||
"reply": result.get("reply", ""),
|
||||
"tone": t,
|
||||
"provider": result.get("provider_used", "unknown"),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Reply generation failed for tone '{t}': {e}")
|
||||
results.append({"reply": "", "tone": t, "error": str(e)})
|
||||
|
||||
return results
|
||||
|
||||
async def extract_info(self, text: str, extract_type: str = "auto") -> Dict[str, Any]:
|
||||
schemas = {
|
||||
"product": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"product_name": {"type": "string"},
|
||||
"quantity": {"type": "string"},
|
||||
"price": {"type": "string"},
|
||||
"currency": {"type": "string"},
|
||||
"delivery_terms": {"type": "string"},
|
||||
"target_country": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"inquiry": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"intent": {"type": "string"},
|
||||
"product_interest": {"type": "string"},
|
||||
"quantity": {"type": "string"},
|
||||
"budget": {"type": "string"},
|
||||
"urgency": {"type": "string"},
|
||||
"contact_info": {"type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema = schemas.get(extract_type, schemas["inquiry"])
|
||||
result = await self.ai.extract(text, schema)
|
||||
return result.get("data", {})
|
||||
|
||||
def _get_tones(self, base: str, count: int) -> List[str]:
|
||||
tones = ["professional", "friendly", "formal"]
|
||||
if base in tones:
|
||||
tones.remove(base)
|
||||
tones.insert(0, base)
|
||||
return tones[:count]
|
||||
|
||||
def _similarity_ratio(self, a: str, b: str) -> float:
|
||||
if not a or not b:
|
||||
return 0.0
|
||||
set_a, set_b = set(a.lower().split()), set(b.lower().split())
|
||||
if not set_a or not set_b:
|
||||
return 0.0
|
||||
return len(set_a & set_b) / len(set_a | set_b)
|
||||
@@ -0,0 +1,109 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import httpx
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhatsAppService:
|
||||
def __init__(self):
|
||||
self.api_token = settings.WHATSAPP_API_TOKEN
|
||||
self.phone_number_id = settings.WHATSAPP_PHONE_NUMBER_ID
|
||||
self.api_base = f"https://graph.facebook.com/v18.0/{self.phone_number_id}"
|
||||
|
||||
def verify_webhook(self, mode: str, token: str, challenge: str) -> Optional[str]:
|
||||
if mode == "subscribe" and token == settings.WHATSAPP_WEBHOOK_VERIFY_TOKEN:
|
||||
return challenge
|
||||
return None
|
||||
|
||||
def verify_signature(self, body: bytes, signature: str) -> bool:
|
||||
if not signature:
|
||||
return False
|
||||
expected = hmac.new(
|
||||
settings.WHATSAPP_API_TOKEN.encode(),
|
||||
body,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(f"sha256={expected}", signature)
|
||||
|
||||
async def send_text(self, to: str, text: str) -> bool:
|
||||
if not self.api_token or not self.phone_number_id:
|
||||
logger.warning("WhatsApp not configured")
|
||||
return False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.api_base}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"messaging_product": "whatsapp",
|
||||
"to": to,
|
||||
"type": "text",
|
||||
"text": {"body": text},
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.error(f"WhatsApp send failed: {resp.text}")
|
||||
return False
|
||||
return True
|
||||
|
||||
async def send_template(self, to: str, template_name: str, params: Dict[str, str]) -> bool:
|
||||
if not self.api_token or not self.phone_number_id:
|
||||
return False
|
||||
|
||||
components = [
|
||||
{
|
||||
"type": "body",
|
||||
"parameters": [
|
||||
{"type": "text", "text": v} for v in params.values()
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.api_base}/messages",
|
||||
headers={"Authorization": f"Bearer {self.api_token}", "Content-Type": "application/json"},
|
||||
json={
|
||||
"messaging_product": "whatsapp",
|
||||
"to": to,
|
||||
"type": "template",
|
||||
"template": {
|
||||
"name": template_name,
|
||||
"language": {"code": "en"},
|
||||
"components": components,
|
||||
},
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
return resp.status_code == 200
|
||||
|
||||
def parse_webhook(self, body: Dict) -> Optional[Dict]:
|
||||
try:
|
||||
entry = body.get("entry", [{}])[0]
|
||||
change = entry.get("changes", [{}])[0]
|
||||
value = change.get("value", {})
|
||||
messages = value.get("messages", [])
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
msg = messages[0]
|
||||
return {
|
||||
"from": msg.get("from"),
|
||||
"text": msg.get("text", {}).get("body", ""),
|
||||
"msg_id": msg.get("id"),
|
||||
"timestamp": msg.get("timestamp"),
|
||||
"type": msg.get("type", "text"),
|
||||
"profile_name": value.get("contacts", [{}])[0].get("profile", {}).get("name"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse WhatsApp webhook: {e}")
|
||||
return None
|
||||
@@ -0,0 +1,193 @@
|
||||
from datetime import datetime, timedelta
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select, and_
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task
|
||||
def check_silent_customers():
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.customer import Customer
|
||||
|
||||
async def _check():
|
||||
async with AsyncSessionLocal() as db:
|
||||
now = datetime.utcnow()
|
||||
for days in [3, 7, 14]:
|
||||
cutoff = now - timedelta(days=days)
|
||||
result = await db.execute(
|
||||
select(Customer).where(
|
||||
and_(
|
||||
Customer.status.in_(["lead", "negotiating"]),
|
||||
Customer.last_contact_at.isnot(None),
|
||||
Customer.last_contact_at < cutoff,
|
||||
)
|
||||
)
|
||||
)
|
||||
customers = result.scalars().all()
|
||||
for c in customers:
|
||||
if days == 3:
|
||||
logger.info(f"Customer {c.name} silent for 3 days")
|
||||
elif days == 7:
|
||||
logger.info(f"Customer {c.name} silent for 7 days - upgrade")
|
||||
else:
|
||||
logger.info(f"Customer {c.name} silent for 14 days - recommend new approach")
|
||||
|
||||
import asyncio
|
||||
asyncio.run(_check())
|
||||
return "Checked silent customers"
|
||||
|
||||
|
||||
@shared_task
|
||||
def batch_translate_texts(texts: list, target_lang: str, user_id: str):
|
||||
from app.services.translation import TranslationService
|
||||
|
||||
async def _translate():
|
||||
service = TranslationService()
|
||||
results = []
|
||||
for text in texts:
|
||||
result = await service.translate(text, target_lang, user_id=user_id)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_translate())
|
||||
|
||||
|
||||
@shared_task
|
||||
def generate_quotation_pdf(quotation_id: str):
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.quotation import Quotation, QuotationItem
|
||||
|
||||
async def _generate():
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(Quotation).where(Quotation.id == quotation_id)
|
||||
)
|
||||
q = result.scalar_one_or_none()
|
||||
if not q:
|
||||
return {"error": "Quotation not found"}
|
||||
|
||||
items_result = await db.execute(
|
||||
select(QuotationItem).where(QuotationItem.quotation_id == q.id)
|
||||
)
|
||||
items = items_result.scalars().all()
|
||||
|
||||
pdf_content = generate_pdf_text(q, items)
|
||||
|
||||
return {"pdf_content": pdf_content, "quotation_id": str(q.id)}
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_generate())
|
||||
|
||||
|
||||
def generate_pdf_text(quotation, items):
|
||||
from datetime import datetime
|
||||
|
||||
lines = [
|
||||
"=" * 60,
|
||||
f"QUOTATION",
|
||||
f"#{str(quotation.id)[:8].upper()}",
|
||||
"=" * 60,
|
||||
f"Date: {datetime.utcnow().strftime('%Y-%m-%d')}",
|
||||
]
|
||||
|
||||
if quotation.valid_until:
|
||||
lines.append(f"Valid Until: {quotation.valid_until}")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"{'Item':<30} {'Qty':<8} {'Unit Price':<12} {'Total':<12}")
|
||||
lines.append("-" * 62)
|
||||
|
||||
for item in items:
|
||||
lines.append(
|
||||
f"{item.product_name:<30} {item.quantity:<8} ${item.unit_price:<10.2f} ${item.total_price:<10.2f}"
|
||||
)
|
||||
|
||||
lines.append("-" * 62)
|
||||
if quotation.subtotal:
|
||||
lines.append(f"{'Subtotal':>48} ${quotation.subtotal:<10.2f}")
|
||||
if quotation.discount:
|
||||
lines.append(f"{'Discount':>48} -${quotation.discount:<10.2f}")
|
||||
if quotation.shipping:
|
||||
lines.append(f"{'Shipping':>48} ${quotation.shipping:<10.2f}")
|
||||
lines.append(f"{'TOTAL':>48} ${quotation.total or quotation.subtotal or 0:<10.2f}")
|
||||
|
||||
lines.append("")
|
||||
if quotation.payment_terms:
|
||||
lines.append(f"Payment Terms: {quotation.payment_terms}")
|
||||
if quotation.delivery_terms:
|
||||
lines.append(f"Delivery Terms: {quotation.delivery_terms}")
|
||||
if quotation.lead_time:
|
||||
lines.append(f"Lead Time: {quotation.lead_time}")
|
||||
if quotation.notes:
|
||||
lines.append(f"Notes: {quotation.notes}")
|
||||
|
||||
lines.append("=" * 60)
|
||||
lines.append("Generated by TradeMate")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@shared_task
|
||||
def process_corpus_quality():
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.corpus import CorpusEntry
|
||||
|
||||
async def _process():
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(CorpusEntry).where(
|
||||
and_(
|
||||
CorpusEntry.quality_score < 0.5,
|
||||
CorpusEntry.usage_count > 5,
|
||||
)
|
||||
).limit(100)
|
||||
)
|
||||
entries = result.scalars().all()
|
||||
for e in entries:
|
||||
e.quality_score = min(1.0, e.quality_score + 0.1)
|
||||
await db.commit()
|
||||
return f"Processed {len(entries)} entries"
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_process())
|
||||
|
||||
|
||||
@shared_task
|
||||
def cleanup_old_sessions():
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
async def _cleanup():
|
||||
r = await aioredis.from_url(settings.REDIS_URL)
|
||||
keys = await r.keys("session:*")
|
||||
if keys:
|
||||
await r.delete(*keys)
|
||||
return f"Cleaned up {len(keys)} sessions"
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_cleanup())
|
||||
|
||||
|
||||
@shared_task
|
||||
def send_followup_reminder(customer_id: str, user_id: str):
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.customer import Customer
|
||||
from app.services.customer import CustomerService
|
||||
|
||||
async def _send():
|
||||
async with AsyncSessionLocal() as db:
|
||||
result = await db.execute(
|
||||
select(Customer).where(
|
||||
and_(Customer.id == customer_id, Customer.user_id == user_id)
|
||||
)
|
||||
)
|
||||
c = result.scalar_one_or_none()
|
||||
if c:
|
||||
logger.info(f"Sending followup reminder for customer {c.name}")
|
||||
return {"customer_id": str(c.id), "customer_name": c.name}
|
||||
return {"error": "Customer not found"}
|
||||
|
||||
import asyncio
|
||||
return asyncio.run(_send())
|
||||
@@ -0,0 +1,10 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
asyncio_mode = auto
|
||||
addopts = -v --tb=short --cov=app --cov-report=term-missing
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::PendingDeprecationWarning
|
||||
@@ -0,0 +1,19 @@
|
||||
fastapi==0.79.0
|
||||
uvicorn==0.19.0
|
||||
sqlalchemy==1.4.48
|
||||
asyncpg==0.27.0
|
||||
pydantic==1.10.12
|
||||
pydantic-settings==1.1.2
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
python-multipart==0.0.6
|
||||
redis==4.5.5
|
||||
celery==5.2.7
|
||||
httpx==0.23.3
|
||||
openai==0.27.8
|
||||
anthropic==0.8.1
|
||||
jinja2==3.1.2
|
||||
alembic==1.11.3
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
||||
pytest-cov==4.1.0
|
||||
@@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.main import app
|
||||
from app.database import Base, get_db
|
||||
from app.models.user import User
|
||||
from app.core.security import hash_password
|
||||
|
||||
|
||||
TEST_DATABASE_URL = "postgresql+asyncpg://admin:dWFNi67nHNbPbjmP@localhost:5432/foreign_trade_test"
|
||||
|
||||
test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
||||
TestAsyncSessionLocal = sessionmaker(
|
||||
test_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with TestAsyncSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://test"
|
||||
) as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(db_session: AsyncSession) -> User:
|
||||
user = User(
|
||||
phone="13800138000",
|
||||
username="test_user",
|
||||
password_hash=hash_password("test123456"),
|
||||
tier="free",
|
||||
)
|
||||
db_session.add(user)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_headers(test_user: User) -> dict:
|
||||
from app.core.security import create_access_token
|
||||
token = create_access_token({"sub": str(test_user.id), "tier": test_user.tier})
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
@@ -0,0 +1,94 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestAuthAPI:
|
||||
async def test_health_endpoint(self, client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["app"] == "TradeMate"
|
||||
|
||||
async def test_register_new_user(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"phone": "13900139001",
|
||||
"password": "test123456",
|
||||
"username": "newuser",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["phone"] == "13900139001"
|
||||
assert data["username"] == "newuser"
|
||||
assert data["tier"] == "free"
|
||||
|
||||
async def test_register_duplicate_phone(self, client: AsyncClient, test_user):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"phone": "13800138000",
|
||||
"password": "test123456",
|
||||
"username": "duplicate",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "already registered" in response.json()["detail"]
|
||||
|
||||
async def test_login_success(self, client: AsyncClient, test_user):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={
|
||||
"username": "13800138000",
|
||||
"password": "test123456",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
async def test_login_wrong_password(self, client: AsyncClient, test_user):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={
|
||||
"username": "13800138000",
|
||||
"password": "wrongpassword",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_login_nonexistent_user(self, client: AsyncClient):
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={
|
||||
"username": "13999999999",
|
||||
"password": "test123456",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_get_current_user(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["phone"] == "13800138000"
|
||||
assert data["username"] == "test_user"
|
||||
|
||||
async def test_get_user_unauthorized(self, client: AsyncClient):
|
||||
response = await client.get("/api/v1/auth/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_refresh_token(self, client: AsyncClient, test_user):
|
||||
from app.core.security import create_refresh_token
|
||||
refresh = create_refresh_token({"sub": str(test_user.id)})
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "access_token" in response.json()
|
||||
@@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class TestConfig:
|
||||
def test_app_name(self):
|
||||
assert settings.APP_NAME == "TradeMate"
|
||||
|
||||
def test_jwt_algorithm(self):
|
||||
assert settings.JWT_ALGORITHM == "HS256"
|
||||
|
||||
def test_token_expiration(self):
|
||||
assert settings.ACCESS_TOKEN_EXPIRE_MINUTES == 60
|
||||
assert settings.REFRESH_TOKEN_EXPIRE_DAYS == 30
|
||||
|
||||
def test_ai_routing_config(self):
|
||||
assert "translate" in settings.AI_ROUTING
|
||||
assert "reply" in settings.AI_ROUTING
|
||||
assert "marketing" in settings.AI_ROUTING
|
||||
assert settings.AI_ROUTING["translate"]["primary"] == "deepl"
|
||||
assert settings.AI_ROUTING["reply"]["primary"] == "openai"
|
||||
|
||||
def test_free_tier_limits(self):
|
||||
assert settings.FREE_DAILY_TRANSLATE_CHARS == 5000
|
||||
assert settings.FREE_DAILY_REPLIES == 20
|
||||
assert settings.FREE_DAILY_MARKETING == 5
|
||||
assert settings.FREE_MAX_CUSTOMERS == 5
|
||||
assert settings.FREE_MAX_PRODUCTS == 1
|
||||
assert settings.FREE_DAILY_QUOTATIONS == 3
|
||||
|
||||
def test_pro_tier_limits(self):
|
||||
assert settings.PRO_DAILY_TRANSLATE_CHARS == 50000
|
||||
assert settings.PRO_DAILY_REPLIES == 200
|
||||
assert settings.PRO_MAX_CUSTOMERS == 100
|
||||
assert settings.PRO_MAX_PRODUCTS == 20
|
||||
|
||||
def test_database_url_configured(self):
|
||||
assert settings.DATABASE_URL is not None
|
||||
assert "foreign_trade" in settings.DATABASE_URL
|
||||
|
||||
def test_redis_url_configured(self):
|
||||
assert settings.REDIS_URL is not None
|
||||
@@ -0,0 +1,147 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from app.models.customer import Customer
|
||||
import uuid
|
||||
|
||||
|
||||
class TestCustomerAPI:
|
||||
async def test_list_customers_empty(self, client: AsyncClient, auth_headers):
|
||||
response = await client.get("/api/v1/customers", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "items" in data
|
||||
assert data["items"] == []
|
||||
assert data["total"] == 0
|
||||
|
||||
async def test_create_customer(self, client: AsyncClient, auth_headers):
|
||||
response = await client.post(
|
||||
"/api/v1/customers",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name": "John Smith",
|
||||
"company": "ABC Corp",
|
||||
"country": "USA",
|
||||
"phone": "+1234567890",
|
||||
"whatsapp_id": "john123",
|
||||
"email": "john@abc.com",
|
||||
"status": "lead",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "John Smith"
|
||||
assert data["company"] == "ABC Corp"
|
||||
assert data["country"] == "USA"
|
||||
|
||||
async def test_create_customer_minimal(self, client: AsyncClient, auth_headers):
|
||||
response = await client.post(
|
||||
"/api/v1/customers",
|
||||
headers=auth_headers,
|
||||
json={"name": "Minimal Customer"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Minimal Customer"
|
||||
|
||||
async def test_list_customers_with_data(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
customer = Customer(
|
||||
user_id=test_user.id,
|
||||
name="Test Customer",
|
||||
company="Test Co",
|
||||
country="China",
|
||||
status="lead",
|
||||
)
|
||||
db_session.add(customer)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.get("/api/v1/customers", headers=auth_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["items"]) == 1
|
||||
assert data["items"][0]["name"] == "Test Customer"
|
||||
|
||||
async def test_get_customer(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
customer = Customer(
|
||||
user_id=test_user.id,
|
||||
name="Get Test",
|
||||
company="Get Co",
|
||||
status="negotiating",
|
||||
)
|
||||
db_session.add(customer)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.get(
|
||||
f"/api/v1/customers/{customer.id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Get Test"
|
||||
|
||||
async def test_get_customer_not_found(self, client: AsyncClient, auth_headers):
|
||||
fake_id = str(uuid.uuid4())
|
||||
response = await client.get(
|
||||
f"/api/v1/customers/{fake_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
async def test_update_customer(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
customer = Customer(
|
||||
user_id=test_user.id,
|
||||
name="Original Name",
|
||||
status="lead",
|
||||
)
|
||||
db_session.add(customer)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/customers/{customer.id}",
|
||||
headers=auth_headers,
|
||||
json={"name": "Updated Name", "status": "negotiating"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Updated Name"
|
||||
assert data["status"] == "negotiating"
|
||||
|
||||
async def test_delete_customer(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
customer = Customer(
|
||||
user_id=test_user.id,
|
||||
name="To Delete",
|
||||
)
|
||||
db_session.add(customer)
|
||||
await db_session.commit()
|
||||
customer_id = customer.id
|
||||
|
||||
response = await client.delete(
|
||||
f"/api/v1/customers/{customer_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
get_response = await client.get(
|
||||
f"/api/v1/customers/{customer_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert get_response.status_code == 404
|
||||
|
||||
async def test_get_silent_customers(self, client: AsyncClient, auth_headers, db_session, test_user):
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
customer = Customer(
|
||||
user_id=test_user.id,
|
||||
name="Silent Customer",
|
||||
status="lead",
|
||||
last_contact_at=datetime.utcnow() - timedelta(days=5),
|
||||
)
|
||||
db_session.add(customer)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/customers/silent?days=3",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["count"] >= 1
|
||||
@@ -0,0 +1,45 @@
|
||||
import pytest
|
||||
from app.core.exceptions import (
|
||||
TradeMateException,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
ForbiddenError,
|
||||
QuotaExceededError,
|
||||
TierRestrictionError,
|
||||
)
|
||||
|
||||
|
||||
class TestExceptions:
|
||||
def test_trade_mate_exception(self):
|
||||
exc = TradeMateException(400, "Bad Request", "Details")
|
||||
assert exc.code == 400
|
||||
assert exc.message == "Bad Request"
|
||||
assert exc.detail == "Details"
|
||||
|
||||
def test_not_found_error(self):
|
||||
exc = NotFoundError("User")
|
||||
assert exc.code == 404
|
||||
assert "User" in exc.message
|
||||
assert "not found" in exc.message
|
||||
|
||||
def test_unauthorized_error(self):
|
||||
exc = UnauthorizedError()
|
||||
assert exc.code == 401
|
||||
assert exc.message == "Unauthorized"
|
||||
|
||||
def test_forbidden_error(self):
|
||||
exc = ForbiddenError()
|
||||
assert exc.code == 403
|
||||
assert exc.message == "Forbidden"
|
||||
|
||||
def test_quota_exceeded_error(self):
|
||||
exc = QuotaExceededError("translation")
|
||||
assert exc.code == 429
|
||||
assert "Quota exceeded" in exc.message
|
||||
assert "translation" in exc.detail
|
||||
|
||||
def test_tier_restriction_error(self):
|
||||
exc = TierRestrictionError("Advanced Feature", "Pro")
|
||||
assert exc.code == 402
|
||||
assert "Upgrade required" in exc.message
|
||||
assert "Pro" in exc.detail
|
||||
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
from app.core.security import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
)
|
||||
|
||||
|
||||
class TestSecurity:
|
||||
def test_hash_password(self):
|
||||
pwd = "test123456"
|
||||
hashed = hash_password(pwd)
|
||||
assert hashed != pwd
|
||||
assert verify_password(pwd, hashed)
|
||||
|
||||
def test_verify_password_wrong(self):
|
||||
pwd = "test123456"
|
||||
hashed = hash_password(pwd)
|
||||
assert not verify_password("wrongpassword", hashed)
|
||||
|
||||
def test_create_access_token(self):
|
||||
data = {"sub": "test-user-id", "tier": "free"}
|
||||
token = create_access_token(data)
|
||||
assert token is not None
|
||||
assert isinstance(token, str)
|
||||
|
||||
def test_decode_token_valid(self):
|
||||
data = {"sub": "test-user-id", "tier": "pro"}
|
||||
token = create_access_token(data)
|
||||
decoded = decode_token(token)
|
||||
assert decoded is not None
|
||||
assert decoded["sub"] == "test-user-id"
|
||||
assert decoded["tier"] == "pro"
|
||||
|
||||
def test_decode_token_invalid(self):
|
||||
decoded = decode_token("invalid-token")
|
||||
assert decoded is None
|
||||
|
||||
def test_create_refresh_token(self):
|
||||
data = {"sub": "test-user-id"}
|
||||
token = create_refresh_token(data)
|
||||
assert token is not None
|
||||
decoded = decode_token(token)
|
||||
assert decoded is not None
|
||||
assert decoded["type"] == "refresh"
|
||||
Reference in New Issue
Block a user