Setting up transactional tests with Pytest and SQLAlchemy

Coming from a Ruby on Rails background, I really appreciate solutions that have become community standards for solving common problems. One such solution is the DatabaseCleaner gem, which ensures that your tests run in isolation and that no data leaks between them.

Recently I was looking for a similar solution for Python but, to my surprise, couldn’t find one. However, it’s easy to leverage what SQLAlchemy and Pytest offer to wrap each test in its own database transaction. Let me show you a neat take on the problem that I hope you will find convenient to use.

Setting up a DB connection

Before we’re able to use a database with transactions in our tests, we need to set up a separate DB instance exclusively for tests. Then, from our test suite, we need to connect to the DB. Here’s an example of how to establish a connection with a MySQL database:

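import os

import pytest
from sqlalchemy import create_engine


@pytest.fixture(scope="session")
def connection():
    engine = create_engine(
        "mysql+mysqldb://{}:{}@{}:{}/{}".format(
            os.environ.get('TEST_DB_USER'),
            os.environ.get('TEST_DB_PASSWORD'),
            os.environ.get('TEST_DB_HOST'),
            os.environ.get('TEST_DB_PORT'),
            os.environ.get('TEST_DB_NAME'),
        )
    )
    conn = engine.connect()
    yield conn
    # Release the connection and the engine's pool when the test session ends.
    conn.close()
    engine.dispose()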

If you use another database engine, head to the SQLAlchemy documentation for information on how to build different connection strings.
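One thing to watch out for when gluing the URL together with format(): a password containing characters such as @ or : produces a string SQLAlchemy cannot parse correctly. A minimal sketch, assuming SQLAlchemy 1.4 or newer (where URL.create was introduced), builds the URL from its components instead and takes the driver name as a parameter, so switching database engines only means changing that string. The helper name build_test_db_url is hypothetical; the TEST_DB_* variables are the ones used in the connection fixture.

```python
import os

from sqlalchemy.engine import URL


def build_test_db_url(drivername="mysql+mysqldb"):
    """Hypothetical helper: build the test database URL from TEST_DB_* variables.

    URL.create keeps each component separate, so a password containing
    characters such as "@" or ":" cannot be misparsed the way it can be
    when the pieces are joined with str.format().
    """
    port = os.environ.get("TEST_DB_PORT")
    return URL.create(
        drivername=drivername,
        username=os.environ.get("TEST_DB_USER"),
        password=os.environ.get("TEST_DB_PASSWORD"),
        host=os.environ.get("TEST_DB_HOST"),
        port=int(port) if port else None,  # URL.create expects an int port
        database=os.environ.get("TEST_DB_NAME"),
    )
```

The returned URL object can be passed straight to create_engine, for example create_engine(build_test_db_url("postgresql+psycopg2")) for a PostgreSQL test database.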

Table creation and DB seeding

Now we need to recreate our database structure. Let’s assume that all models in your app are declared in the models.py file.

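# models.py
from sqlalchemy import INTEGER, VARCHAR, Column
from sqlalchemy.orm import declarative_base  # sqlalchemy.ext.declarative before 1.4

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(INTEGER, primary_key=True)
    name = Column(VARCHAR(64))
    # ...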

SQLAlchemy offers methods to easily create and drop tables declared in the schema: create_all and drop_all. We will use them at the beginning of the test suite execution to ensure that all tables are in place. After a full test run, we will drop all tables so that the next execution can start with a clean slate.

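import pytest

import models


@pytest.fixture(scope="session")
def setup_database(connection):
    # Create all tables declared on the models' metadata and commit the DDL.
    with connection.begin():
        models.Base.metadata.create_all(bind=connection)

    seed_database(connection)

    yield

    # Drop all tables once the whole test session has finished.
    with connection.begin():
        models.Base.metadata.drop_all(bind=connection)

If your tests need the database to be pre-populated with some data, you can call a seeding function from this fixture. A simple example could look like this:

from sqlalchemy.orm import Session

from models import User


def seed_database(connection):
    users = [
        {
            "id": 1,
            "name": "John Doe",
        },
        # ...
    ]

    # Use a dedicated session bound to the test connection. Committing it
    # persists the seed data outside the per-test transactions, so every
    # test can see it.
    session = Session(bind=connection)
    for user in users:
        session.add(User(**user))
    session.commit()
    session.close()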
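To see what create_all and drop_all do on their own, here is a small sketch that runs them against an in-memory SQLite database instead of MySQL (an assumption made only so the example runs anywhere) and lists the existing tables after each call using SQLAlchemy's inspector. It assumes SQLAlchemy 1.4 or newer for the sqlalchemy.orm.declarative_base import.

```python
from sqlalchemy import INTEGER, VARCHAR, Column, create_engine, inspect
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(INTEGER, primary_key=True)
    name = Column(VARCHAR(64))


# In-memory SQLite stands in for the MySQL test database.
engine = create_engine("sqlite://")


def table_names():
    """Return the names of the tables that currently exist in the database."""
    return sorted(inspect(engine).get_table_names())


Base.metadata.create_all(engine)  # issues CREATE TABLE for every model
tables_after_create = table_names()

Base.metadata.drop_all(engine)  # issues DROP TABLE for every model
tables_after_drop = table_names()

print(tables_after_create, tables_after_drop)  # ['users'] []
```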

Wrapping tests in transactions

As a final step, we need a fixture that wraps each test in its own transaction. It begins a transaction on the shared connection, hands the test a session bound to that connection, and rolls the transaction back once the test finishes.

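@pytest.fixture
def db_session(setup_database, connection):
    # Begin an outer transaction that everything the test writes will join.
    transaction = connection.begin()
    session = scoped_session(
        sessionmaker(autocommit=False, autoflush=False, bind=connection)
    )
    yield session
    # Close the session first, then undo all of the test's changes.
    session.remove()
    transaction.rollback()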

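You can then inject the fixture into your test cases. When each test finishes, everything it wrote is rolled back, so tests stay isolated from one another. Even calling db_session.commit() inside a test is safe: the session is bound to a connection that is already inside a transaction, so its commit does not end that outer transaction, and the fixture’s final rollback() still discards the changes.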

from models import User


def test_user_created(db_session):
    # ...
    db_session.add(User(name="Jane Doe"))
    db_session.commit()
    # ...
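The whole mechanism can also be sketched outside Pytest. The example below is a hedged illustration, assuming SQLAlchemy 1.4 or newer and an in-memory SQLite database in place of MySQL; run_in_transaction is a hypothetical helper that mimics the db_session fixture. Seed data is committed up front, one "test" adds and commits a user, and the next "test" no longer sees that user, while the seeded row survives.

```python
from sqlalchemy import INTEGER, VARCHAR, Column, create_engine
from sqlalchemy.orm import Session, declarative_base, scoped_session, sessionmaker

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(INTEGER, primary_key=True)
    name = Column(VARCHAR(64))


# In-memory SQLite stands in for the MySQL test database.
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# Seed data is committed once, outside any per-test transaction.
with Session(engine) as seed_session:
    seed_session.add(User(id=1, name="John Doe"))
    seed_session.commit()

connection = engine.connect()


def run_in_transaction(test_body):
    """Hypothetical helper mimicking db_session: begin, run the test, roll back."""
    transaction = connection.begin()
    session = scoped_session(sessionmaker(autoflush=False, bind=connection))
    try:
        return test_body(session)
    finally:
        session.remove()
        transaction.rollback()


def add_jane(session):
    session.add(User(name="Jane Doe"))
    session.commit()  # joins the outer transaction instead of ending it
    return session.query(User).count()


count_during_test = run_in_transaction(add_jane)  # seed row + Jane
count_in_next_test = run_in_transaction(lambda s: s.query(User).count())  # seed row only
```

The same reasoning applies to the MySQL setup, with one caveat: only tables using a transactional storage engine such as InnoDB can be rolled back, while MyISAM tables keep whatever a test wrote.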

Summary

It doesn’t take much time to set up working transactions with Pytest and SQLAlchemy once you know how to do it. I hope that you will find this solution neat and easy to use. Of course, if you have any ideas to share or improvements to propose, let me know in the comments — I’m always happy to learn something new.

For future reference, below is the full code that you can reuse in your projects (for example, in a conftest.py file). Good luck, and have fun writing tests with transactions!

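import os

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, scoped_session, sessionmaker

import models
from models import User


@pytest.fixture(scope="session")
def connection():
    engine = create_engine(
        "mysql+mysqldb://{}:{}@{}:{}/{}".format(
            os.environ.get('TEST_DB_USER'),
            os.environ.get('TEST_DB_PASSWORD'),
            os.environ.get('TEST_DB_HOST'),
            os.environ.get('TEST_DB_PORT'),
            os.environ.get('TEST_DB_NAME'),
        )
    )
    conn = engine.connect()
    yield conn
    # Release the connection and the engine's pool when the test session ends.
    conn.close()
    engine.dispose()


def seed_database(connection):
    users = [
        {
            "id": 1,
            "name": "John Doe",
        },
        # ...
    ]

    # Commit the seed data through a dedicated session so every test sees it.
    session = Session(bind=connection)
    for user in users:
        session.add(User(**user))
    session.commit()
    session.close()


@pytest.fixture(scope="session")
def setup_database(connection):
    with connection.begin():
        models.Base.metadata.create_all(bind=connection)

    seed_database(connection)

    yield

    with connection.begin():
        models.Base.metadata.drop_all(bind=connection)


@pytest.fixture
def db_session(setup_database, connection):
    # Everything the test writes joins this outer transaction.
    transaction = connection.begin()
    session = scoped_session(
        sessionmaker(autocommit=False, autoflush=False, bind=connection)
    )
    yield session
    # Close the session, then roll back all of the test's changes.
    session.remove()
    transaction.rollback()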

Did you find this article valuable?

Support Damian Kampik by becoming a sponsor. Any amount is appreciated!