Files
trade-assistant/backend/app/ai/providers/nvidia.py
T

51 lines
1.8 KiB
Python

from typing import Dict, Any, Optional, List
from app.ai.providers.openai import OpenAIProvider, SYSTEM_PROMPTS
import logging
import time
import httpx
logger = logging.getLogger(__name__)
class NvidiaProvider(OpenAIProvider):
    """Chat provider for NVIDIA's OpenAI-compatible inference endpoint.

    Reuses ``OpenAIProvider`` for client setup (pointed at NVIDIA's
    ``base_url`` with a dedicated ``httpx.AsyncClient``) and overrides
    :meth:`chat` with a short, low-temperature completion plus timing logs.
    """

    # Number of most-recent history messages forwarded to the model.
    # Must be positive: history[-0:] would forward the entire history.
    HISTORY_WINDOW = 10
    # Completion parameters sent with every chat request.
    MAX_TOKENS = 300
    TEMPERATURE = 0.3

    def __init__(
        self,
        api_key: str,
        model: str = "stepfun-ai/step-3.5-flash",
        base_url: str = "https://integrate.api.nvidia.com/v1",
        timeout: float = 20.0,
    ):
        """Create the provider.

        Args:
            api_key: NVIDIA API key, forwarded to ``OpenAIProvider``.
            model: Model identifier sent with each request.
            base_url: Root URL of the OpenAI-compatible endpoint.
            timeout: Timeout in seconds applied to the underlying
                ``httpx.AsyncClient`` (default keeps the previous 20s).
        """
        super().__init__(
            api_key=api_key,
            model=model,
            base_url=base_url,
            http_client=httpx.AsyncClient(timeout=httpx.Timeout(timeout)),
        )
        # NOTE(review): presumably exposed through the parent's `name`
        # attribute/property (read as `self.name` in chat) — confirm.
        self._name = f"nvidia-{model}"

    async def chat(
        self,
        message: str,
        history: Optional[list] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Send one user message (plus recent history) and return the reply.

        Args:
            message: The user's new message.
            history: Prior chat messages; only the last ``HISTORY_WINDOW``
                entries are sent. Assumed to be OpenAI-style message dicts
                (they are forwarded as-is) — TODO confirm with callers.
            system_prompt: Overrides ``SYSTEM_PROMPTS["chat"]`` when truthy.

        Returns:
            ``{"reply": str, "provider": self.name, "model": self.model}``.
            ``reply`` is the message content, falling back to the model's
            ``reasoning`` field when content is empty, and ``""`` if neither
            is present.
        """
        t0 = time.perf_counter()
        system = system_prompt or SYSTEM_PROMPTS["chat"]
        messages = [{"role": "system", "content": system}]
        if history:
            messages.extend(history[-self.HISTORY_WINDOW:])
        messages.append({"role": "user", "content": message})
        t1 = time.perf_counter()

        resp = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=self.MAX_TOKENS,
            temperature=self.TEMPERATURE,
        )
        t2 = time.perf_counter()

        reply = resp.choices[0].message
        content = reply.content or ""
        if not content:
            # Some reasoning models leave `content` empty and put the text in
            # an extra `reasoning` field; that field may also be present but
            # None, so normalise to "" rather than propagating None.
            content = getattr(reply, "reasoning", None) or ""
        t3 = time.perf_counter()

        # A `content` of None (e.g. in a forwarded history entry) counts as 0.
        chars_in = sum(len(m.get("content") or "") for m in messages)
        logger.info(
            "NVIDIA timing: build_msgs=%.1fs api_call=%.1fs process=%.1fs "
            "chars_in=%d chars_out=%d",
            t1 - t0,
            t2 - t1,
            t3 - t2,
            chars_in,
            len(content),
        )
        return {"reply": content, "provider": self.name, "model": self.model}