新しい (ベータ版) 1.4 sqlalchemy を試していますが、非同期 API とpytest
.
まず、zzzeekのunittest
例を に変換してみました。pytest
import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
# a model
class Thing(Base):
__tablename__ = "thing"
id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session")
def engine_fixture():
engine = create_engine("postgresql://postgres:changethis@db/app_test", echo=True)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
@pytest.fixture
def session(engine_fixture):
conn = engine_fixture.connect()
trans = conn.begin()
session = Session(bind=conn)
def _fixture(session):
session.add_all([Thing(), Thing(), Thing()])
session.commit()
# load fixture data within the scope of the transaction
_fixture(session)
# start the session in a SAVEPOINT...
session.begin_nested()
# then each time that SAVEPOINT ends, reopen it
@event.listens_for(session, "after_transaction_end")
def restart_savepoint(session, transaction):
if transaction.nested and not transaction._parent.nested:
session.begin_nested()
yield session
# same teardown from the docs
session.close()
trans.rollback()
conn.close()
def _test_thing(session, extra_rollback=0):
rows = session.query(Thing).all()
assert len(rows) == 3
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = session.query(Thing).all()
assert len(rows) == 6
session.rollback()
# after rollbacks, still @ 3 rows
rows = session.query(Thing).all()
assert len(rows) == 3
session.add_all([Thing(), Thing()])
session.commit()
rows = session.query(Thing).all()
assert len(rows) == 5
session.add(Thing())
rows = session.query(Thing).all()
assert len(rows) == 6
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = session.query(Thing).all()
if elem > 0:
# b.c. we rolled back that other "thing" too
assert len(rows) == 8
else:
assert len(rows) == 9
session.rollback()
rows = session.query(Thing).all()
if extra_rollback:
assert len(rows) == 5
else:
assert len(rows) == 6
def test_thing_one_pytest(session):
# run zero rollbacks
_test_thing(session, 0)
def test_thing_two_pytest(session):
# run two extra rollbacks
_test_thing(session, 2)
次に、バージョン0.14.0asyncio
を使用してAPIに切り替えてみましたpytest-asyncio
import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
Base = declarative_base()
# a model
class Thing(Base):
__tablename__ = "thing"
id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session", autouse=True)
def meta_migration():
# setup
sync_engine = create_engine(
"postgresql://postgres:changethis@db/app_test", echo=True
)
Base.metadata.drop_all(sync_engine)
Base.metadata.create_all(sync_engine)
yield sync_engine
# teardown
Base.metadata.drop_all(sync_engine)
@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
# setup
engine = create_async_engine(
"postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
)
yield engine
@pytest.fixture(scope="function")
async def session(async_engine):
conn = await async_engine.connect()
trans = await conn.begin()
session = AsyncSession(bind=conn)
async def _fixture(session: AsyncSession):
session.add_all([Thing(), Thing(), Thing()])
await session.commit()
# load fixture data within the scope of the transaction
await _fixture(session)
# start the session in a SAVEPOINT...
await session.begin_nested()
# then each time that SAVEPOINT ends, reopen it
# NOTE: no async listeners yet
@event.listens_for(session.sync_session, "after_transaction_end")
def restart_savepoint(session, transaction):
if transaction.nested and not transaction._parent.nested:
session.begin_nested()
yield session
# same teardown from the docs
await session.close()
await trans.rollback()
await conn.close()
async def _test_thing(session: AsyncSession, extra_rollback=0):
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 3
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 6
await session.rollback()
# after rollbacks, still @ 3 rows
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 3
session.add_all([Thing(), Thing()])
await session.commit()
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 5
session.add(Thing())
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 6
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = (await session.execute(select(Thing))).all()
if elem > 0:
# b.c. we rolled back that other "thing" too
assert len(rows) == 8
else:
assert len(rows) == 9
await session.rollback()
rows = (await session.execute(select(Thing))).all()
if extra_rollback:
assert len(rows) == 5
else:
assert len(rows) == 6
@pytest.mark.asyncio
async def test_thing_one_pytest(session):
# run zero rollbacks
await _test_thing(session, 0)
@pytest.mark.asyncio
async def test_thing_two_pytest(session):
# run two extra rollbacks
await _test_thing(session, 2)
"FAILED test_thing_two_pytest - assert 8 == 3"
ただし、最初のテスト後のトランザクション ロールバックは、フェーズteardown
で作成された SAVEPOINT に復元されないため、これは失敗します。setup
sqlalchemy の内部構造に関する私の知識はそれほど多くないため、テスト スイートのパフォーマンスにとって非常に重要であるため、これをセットアップする際に助けを求めています。
async
イベントリスナーの欠落と定義だけrestart_savepoint
では十分でAsyncSession.sync_session
はなく、1.4 API の安定したリリースを待つ必要があるのでしょうか?
ありがとう!