3

新しい (ベータ版) 1.4 sqlalchemy を試していますが非同期 API とpytest.

まず、zzzeekunittest例を に変換してみました。pytest

import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

# a model
class Thing(Base):
    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session")
def engine_fixture():
    engine = create_engine("postgresql://postgres:changethis@db/app_test", echo=True)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    yield engine

    Base.metadata.drop_all(engine)


@pytest.fixture
def session(engine_fixture):
    conn = engine_fixture.connect()
    trans = conn.begin()
    session = Session(bind=conn)

    def _fixture(session):
        session.add_all([Thing(), Thing(), Thing()])
        session.commit()

    # load fixture data within the scope of the transaction
    _fixture(session)

    # start the session in a SAVEPOINT...
    session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    @event.listens_for(session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    # same teardown from the docs
    session.close()
    trans.rollback()
    conn.close()


def _test_thing(session, extra_rollback=0):

    rows = session.query(Thing).all()
    assert len(rows) == 3

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = session.query(Thing).all()
        assert len(rows) == 6

        session.rollback()

    # after rollbacks, still @ 3 rows
    rows = session.query(Thing).all()
    assert len(rows) == 3

    session.add_all([Thing(), Thing()])
    session.commit()

    rows = session.query(Thing).all()
    assert len(rows) == 5

    session.add(Thing())
    rows = session.query(Thing).all()
    assert len(rows) == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = session.query(Thing).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        session.rollback()

    rows = session.query(Thing).all()
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6


def test_thing_one_pytest(session):
    # run zero rollbacks
    _test_thing(session, 0)


def test_thing_two_pytest(session):
    # run two extra rollbacks
    _test_thing(session, 2)

次に、バージョン0.14.0asyncioを使用してAPIに切り替えてみましたpytest-asyncio

import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

Base = declarative_base()

# a model
class Thing(Base):
    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    # setup
    sync_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test", echo=True
    )
    Base.metadata.drop_all(sync_engine)
    Base.metadata.create_all(sync_engine)

    yield sync_engine

    # teardown
    Base.metadata.drop_all(sync_engine)


@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
    # setup
    engine = create_async_engine(
        "postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
    )

    yield engine


@pytest.fixture(scope="function")
async def session(async_engine):
    conn = await async_engine.connect()
    trans = await conn.begin()
    session = AsyncSession(bind=conn)

    async def _fixture(session: AsyncSession):
        session.add_all([Thing(), Thing(), Thing()])
        await session.commit()

    # load fixture data within the scope of the transaction
    await _fixture(session)

    # start the session in a SAVEPOINT...
    await session.begin_nested()

    # then each time that SAVEPOINT ends, reopen it
    # NOTE: no async listeners yet
    @event.listens_for(session.sync_session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    yield session

    # same teardown from the docs
    await session.close()
    await trans.rollback()
    await conn.close()


async def _test_thing(session: AsyncSession, extra_rollback=0):

    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = (await session.execute(select(Thing))).all()
        assert len(rows) == 6

        await session.rollback()

    # after rollbacks, still @ 3 rows
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 3

    session.add_all([Thing(), Thing()])
    await session.commit()

    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 5

    session.add(Thing())
    rows = (await session.execute(select(Thing))).all()
    assert len(rows) == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        rows = (await session.execute(select(Thing))).all()
        if elem > 0:
            # b.c. we rolled back that other "thing" too
            assert len(rows) == 8
        else:
            assert len(rows) == 9
        await session.rollback()

    rows = (await session.execute(select(Thing))).all()
    if extra_rollback:
        assert len(rows) == 5
    else:
        assert len(rows) == 6


@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    # run zero rollbacks
    await _test_thing(session, 0)


@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    # run two extra rollbacks
    await _test_thing(session, 2)

"FAILED test_thing_two_pytest - assert 8 == 3"ただし、最初のテスト後のトランザクション ロールバックは、フェーズteardownで作成された SAVEPOINT に復元されないため、これは失敗します。setup

sqlalchemy の内部構造に関する私の知識はそれほど多くないため、テスト スイートのパフォーマンスにとって非常に重要であるため、これをセットアップする際に助けを求めています。

asyncイベントリスナーの欠落と定義だけrestart_savepointでは十分でAsyncSession.sync_sessionはなく、1.4 API の安定したリリースを待つ必要があるのでしょうか?

ありがとう!

4

1 に答える 1