# Python source, 169 lines (5.0 KiB).
import pytest
|
|
import asyncio
|
|
from typing import AsyncGenerator
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
import sys
|
|
import os
|
|
|
|
# Put the project root (the parent of this file's directory) at the front of
# sys.path so the ``app`` package imported below can be resolved.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)),
)
|
|
|
|
# Mock aliyunsdkalimt before importing app.main
|
|
import types
|
|
|
|
# Fake ``aliyunsdkalimt`` package tree, registered in ``sys.modules`` so that
# ``app.main`` can import the translation request classes without the real
# Alibaba Cloud SDK installed. The ``__path__`` values mark the first two
# modules as packages; the paths themselves are placeholders.
aliyunsdkalimt = types.ModuleType('aliyunsdkalimt')
aliyunsdkalimt.__path__ = ['/tmp/mock_aliyunsdkalimt']
sys.modules['aliyunsdkalimt'] = aliyunsdkalimt

aliyunsdkalimt_request = types.ModuleType('aliyunsdkalimt.request')
aliyunsdkalimt_request.__path__ = ['/tmp/mock_aliyunsdkalimt/request']
sys.modules['aliyunsdkalimt.request'] = aliyunsdkalimt_request

aliyunsdkalimt_request_v20181012 = types.ModuleType('aliyunsdkalimt.request.v20181012')
sys.modules['aliyunsdkalimt.request.v20181012'] = aliyunsdkalimt_request_v20181012

# Bind each submodule as an attribute of its parent, as a real import would.
# Without this, a plain ``import aliyunsdkalimt.request.v20181012`` is served
# from sys.modules, but the dotted attribute access after it raises
# AttributeError.
aliyunsdkalimt.request = aliyunsdkalimt_request
aliyunsdkalimt_request.v20181012 = aliyunsdkalimt_request_v20181012
|
|
|
|
class _MockTranslateRequest:
    """Minimal stand-in for an Alibaba Cloud machine-translation request.

    It only records the values passed to its setters so tests can inspect
    what the application asked to translate. All fields start as ``None``.
    """

    def __init__(self):
        self.source_text = None
        self.source_language = None
        self.target_language = None
        self.scene = None

    def setSourceText(self, text):
        self.source_text = text

    def setSourceLanguage(self, lang):
        self.source_language = lang

    def setTargetLanguage(self, lang):
        self.target_language = lang

    def setScene(self, scene):
        self.scene = scene

    # Aliases using the ``set_<Param>`` naming of the real SDK request
    # classes, so application code written against either spelling works.
    set_SourceText = setSourceText
    set_SourceLanguage = setSourceLanguage
    set_TargetLanguage = setTargetLanguage
    set_Scene = setScene


class TranslateGeneralRequest(_MockTranslateRequest):
    """Mock of ``aliyunsdkalimt.request.v20181012.TranslateGeneralRequest``."""


class TranslateECommerceRequest(_MockTranslateRequest):
    """Mock of ``aliyunsdkalimt.request.v20181012.TranslateECommerceRequest``."""
|
|
|
|
# Publish the mock request classes on the fake ``v20181012`` module so that
# ``from aliyunsdkalimt.request.v20181012 import ...`` resolves to them.
vars(aliyunsdkalimt_request_v20181012).update(
    TranslateGeneralRequest=TranslateGeneralRequest,
    TranslateECommerceRequest=TranslateECommerceRequest,
)
|
|
|
|
# Mock AcsClient: fake ``aliyunsdkcore`` module objects (client and
# auth.credentials) that will host the mock ``AcsClient`` and
# ``AccessKeyCredential`` classes in place of the real Alibaba Cloud SDK.
aliyunsdkcore = types.ModuleType('aliyunsdkcore')
aliyunsdkcore_client = types.ModuleType('aliyunsdkcore.client')
aliyunsdkcore_auth = types.ModuleType('aliyunsdkcore.auth')
aliyunsdkcore_auth_credentials = types.ModuleType('aliyunsdkcore.auth.credentials')

# Bind each submodule as an attribute of its parent, mirroring what a real
# import does, so dotted access such as ``aliyunsdkcore.client.AcsClient``
# works after a plain ``import aliyunsdkcore.client``.
aliyunsdkcore.client = aliyunsdkcore_client
aliyunsdkcore.auth = aliyunsdkcore_auth
aliyunsdkcore_auth.credentials = aliyunsdkcore_auth_credentials
|
|
|
|
class AcsClient:
    """Stand-in for ``aliyunsdkcore.client.AcsClient`` that never calls the network.

    Every request is answered with the same canned JSON body, so translation
    code paths can run in tests without Alibaba Cloud credentials.
    """

    # Canned response body returned for every request.
    _MOCK_RESPONSE = b'{"TranslateResult": "mock translation"}'

    def __init__(self, *args, **kwargs):
        # Credential / region arguments are accepted and ignored.
        pass

    def do_action(self, request):
        """Return the canned JSON response bytes; *request* is ignored."""
        return self._MOCK_RESPONSE

    def do_action_with_exception(self, request):
        """Same as :meth:`do_action`.

        The real client exposes this as its recommended dispatch method, so the
        mock accepts it too and returns the identical canned response.
        """
        return self.do_action(request)
|
|
|
|
class AccessKeyCredential:
    """Stand-in for ``aliyunsdkcore.auth.credentials.AccessKeyCredential``.

    Accepts any constructor arguments and stores nothing.
    """

    def __init__(self, *args, **kwargs):
        """Ignore every positional and keyword argument."""
|
|
|
|
# Hang the mock classes on their fake modules, then register the whole
# ``aliyunsdkcore`` tree in sys.modules before ``app`` gets imported.
aliyunsdkcore_client.AcsClient = AcsClient
aliyunsdkcore_auth_credentials.AccessKeyCredential = AccessKeyCredential

sys.modules.update({
    'aliyunsdkcore': aliyunsdkcore,
    'aliyunsdkcore.client': aliyunsdkcore_client,
    'aliyunsdkcore.auth': aliyunsdkcore_auth,
    'aliyunsdkcore.auth.credentials': aliyunsdkcore_auth_credentials,
})
|
|
|
|
from app.main import app
|
|
from app.database import Base, get_db
|
|
from app.models.user import User
|
|
from app.core.security import hash_password
|
|
|
|
|
|
# Connection URL for the test database. Can be overridden via the
# TEST_DATABASE_URL environment variable (e.g. in CI) without editing this file.
# NOTE(security): the fallback below embeds a database password in source
# control; consider removing it from the repository and rotating it.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://admin:dWFNi67nHNbPbjmP@localhost:5432/foreign_trade_test",
)

# Engine and session factory shared by the database fixtures.
test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
TestAsyncSessionLocal = sessionmaker(
    test_engine,
    class_=AsyncSession,
    # Keep ORM attributes loaded after commit so fixtures can read them
    # without an extra round-trip.
    expire_on_commit=False,
)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def event_loop():
    """Provide a fresh asyncio event loop for each test module.

    Presumably overrides pytest-asyncio's built-in ``event_loop`` fixture.
    The loop is closed once every test in the module has finished.
    """
    module_loop = asyncio.new_event_loop()
    yield module_loop
    module_loop.close()
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to a freshly created schema.

    Every table in ``Base.metadata`` is created before the test and dropped
    afterwards, so each test starts from empty tables.
    """
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    try:
        async with TestAsyncSessionLocal() as session:
            yield session
    finally:
        try:
            # Drop the schema even if closing the session failed, so rows
            # cannot leak into the next test.
            async with test_engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
        finally:
            # ``test_engine`` lives for the whole run, but ``event_loop`` is
            # recreated per module. asyncpg connections belong to the loop that
            # opened them, so empty the pool to avoid handing a connection from
            # a closed loop to the next module.
            await test_engine.dispose()
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTPX client that calls the ASGI ``app`` in-process.

    While the client is alive, the app's ``get_db`` dependency is overridden to
    hand out the test's ``db_session``. All dependency overrides are cleared
    on teardown.
    """
    async def _provide_test_session():
        yield db_session

    app.dependency_overrides[get_db] = _provide_test_session

    in_process_transport = ASGITransport(app=app)
    async with AsyncClient(transport=in_process_transport, base_url="http://test") as http_client:
        yield http_client

    app.dependency_overrides.clear()
|
|
|
|
|
|
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
    """Persist and return a free-tier user for authenticated tests.

    Login details: phone ``13800138000``, password ``test123456``.
    """
    hashed_password = hash_password("test123456")
    new_user = User(
        phone="13800138000",
        username="test_user",
        password_hash=hashed_password,
        tier="free",
    )

    db_session.add(new_user)
    await db_session.commit()
    # Reload the row so any database-populated columns are present on the instance.
    await db_session.refresh(new_user)
    return new_user
|
|
|
|
|
|
@pytest.fixture
async def auth_headers(test_user: User) -> dict:
    """Return an ``Authorization`` header with a bearer token for ``test_user``.

    The token carries the user's id (as ``sub``) and tier.
    """
    from app.core.security import create_access_token

    claims = {"sub": str(test_user.id), "tier": test_user.tier}
    access_token = create_access_token(claims)
    return {"Authorization": f"Bearer {access_token}"}
|
|
|
|
|
|
# Mark all async test functions with pytest.mark.asyncio
def pytest_collection_modifyitems(items):
    """pytest hook: add the ``asyncio`` marker to each collected ``async def`` test.

    Items without a ``function`` attribute are left untouched.
    ``inspect.iscoroutinefunction`` replaces ``asyncio.iscoroutinefunction``,
    which is deprecated since Python 3.14.
    """
    import inspect

    for item in items:
        test_func = getattr(item, 'function', None)
        if test_func is not None and inspect.iscoroutinefunction(test_func):
            item.add_marker(pytest.mark.asyncio)
|