7c9885f704
- alibaba.py: remove translated text content from log, only log char count - nvidia.py: remove messages content reference from timing log - push.py: replace full content with length, sanitize WeChat error response
52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
from typing import Dict, Any, Optional, List
|
|
from app.ai.providers.openai import OpenAIProvider, SYSTEM_PROMPTS
|
|
import logging
|
|
import time
|
|
import httpx
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class NvidiaProvider(OpenAIProvider):
    """Chat provider that reuses ``OpenAIProvider`` against NVIDIA's endpoint.

    The constructor forwards credentials, model and base URL to the parent
    class together with an ``httpx.AsyncClient`` configured with a 20-second
    timeout. ``chat`` builds the message list, issues one chat-completion
    request and logs coarse timing for each phase.

    NOTE(review): ``self.client``, ``self.model`` and ``self.name`` are not
    defined in this class; they are presumably provided by ``OpenAIProvider``
    (with ``name`` reading ``_name``) -- confirm against the base class.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "stepfun-ai/step-3.5-flash",
        base_url: str = "https://integrate.api.nvidia.com/v1",
    ):
        """Initialise the provider.

        Args:
            api_key: API key forwarded to the parent provider.
            model: Model identifier sent with every request.
            base_url: Root URL of the OpenAI-compatible API.
        """
        super().__init__(
            api_key=api_key,
            model=model,
            base_url=base_url,
            http_client=httpx.AsyncClient(timeout=httpx.Timeout(20.0)),
        )
        self._name = f"nvidia-{model}"

    async def chat(
        self,
        message: str,
        history: Optional[list] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Send one user message and return the model's reply.

        Args:
            message: The user's message text.
            history: Earlier message dicts; only the last 10 are forwarded,
                unchanged. ``None`` or empty means no history.
            system_prompt: System prompt override. When falsy,
                ``SYSTEM_PROMPTS["chat"]`` is used.

        Returns:
            A dict with keys ``"reply"`` (always a ``str``, possibly empty),
            ``"provider"`` and ``"model"``.

        Raises:
            IndexError: If the response contains no choices.
            Exception: Whatever the underlying client raises on request
                failure is propagated unchanged.
        """
        # perf_counter is monotonic, so durations cannot go negative or jump
        # when the system wall clock is adjusted.
        t0 = time.perf_counter()

        system = system_prompt or SYSTEM_PROMPTS["chat"]
        messages = [{"role": "system", "content": system}]
        if history:
            messages.extend(history[-10:])
        messages.append({"role": "user", "content": message})
        t1 = time.perf_counter()

        # Prompts that mention JSON get a larger completion budget.
        max_tokens = 800 if "JSON" in (system or "").upper() else 300
        resp = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=0.3,
        )
        t2 = time.perf_counter()

        reply_msg = resp.choices[0].message
        content = reply_msg.content or ""
        if not content:
            # Fall back to the "reasoning" field when the main content is
            # empty. The attribute may be absent or None, so coerce to ""
            # to keep the reply a string and len() below safe.
            content = getattr(reply_msg, "reasoning", None) or ""
        t3 = time.perf_counter()

        # Only sizes and timings are logged, never message text.
        logger.info(
            "NVIDIA timing: build_msgs=%.1fs api_call=%.1fs process=%.1fs "
            "chars_out=%d",
            t1 - t0,
            t2 - t1,
            t3 - t2,
            len(content),
        )

        return {"reply": content, "provider": self.name, "model": self.model}
|