Coming from a Ruby on Rails background, I really appreciate solutions that have become community standards for solving common problems. One of them is the DatabaseCleaner gem, which ensures that your tests run in isolation and that no data leaks between them.
Recently I was looking for a similar solution for Python, but, to my surprise, I didn’t find one. However, it’s easy to leverage what SQLAlchemy and Pytest offer to wrap each test in its own database transaction. Let me show you a neat take on the problem that I hope you will find convenient to use.
Setting up a DB connection
Before we can use database transactions in our tests, we need a separate database instance used exclusively for testing, and our test suite needs to connect to it. Here’s an example of a session-scoped Pytest fixture (for instance, in your conftest.py) that connects to a MySQL database using credentials from environment variables:
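```python
import os

import pytest
from sqlalchemy import create_engine


@pytest.fixture(scope="session")
def connection():
    engine = create_engine(
        "mysql+mysqldb://{}:{}@{}:{}/{}".format(
            os.environ.get('TEST_DB_USER'),
            os.environ.get('TEST_DB_PASSWORD'),
            os.environ.get('TEST_DB_HOST'),
            os.environ.get('TEST_DB_PORT'),
            os.environ.get('TEST_DB_NAME'),
        )
    )
    connection = engine.connect()
    yield connection
    # Close the connection and dispose of the pool after the whole test run.
    connection.close()
    engine.dispose()
```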
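One thing to watch for is how missing variables are handled. os.environ.get() returns None for a variable that isn’t set, and .format() turns that into the literal string "None". The mistake then surfaces later as a confusing connection error. A small helper like the one below can fail fast instead. It is a hypothetical sketch (require_env is not part of Pytest or SQLAlchemy):

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable's value, or fail with a clear message.

    Hypothetical helper: os.environ.get() would return None for a missing
    variable, which str.format() silently renders as the string "None".
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} must be set for the test database")
    return value
```

In the fixture, you would replace each os.environ.get('TEST_DB_USER') call (and so on) with require_env('TEST_DB_USER').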
If you use another database engine, head to the SQLAlchemy documentation for information on how to build different connection strings.
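An alternative to string formatting is sqlalchemy.engine.URL.create() (SQLAlchemy 1.4 and newer). It builds the URL from separate parts, so special characters such as @ or : in the password need no manual escaping.

The sketch below makes a few assumptions:

- It targets PostgreSQL with the psycopg2 driver.
- It reuses the same TEST_DB_* variables.
- build_test_db_url is a hypothetical helper name.

```python
import os

from sqlalchemy.engine import URL


def build_test_db_url(drivername: str = "postgresql+psycopg2") -> URL:
    """Build a test database URL from TEST_DB_* variables (hypothetical helper)."""
    port = os.environ.get("TEST_DB_PORT")
    return URL.create(
        drivername,
        username=os.environ.get("TEST_DB_USER"),
        # Stored as-is; SQLAlchemy escapes it when rendering the URL.
        password=os.environ.get("TEST_DB_PASSWORD"),
        host=os.environ.get("TEST_DB_HOST"),
        port=int(port) if port else None,  # URL.create expects an int port
        database=os.environ.get("TEST_DB_NAME"),
    )
```

You would then call create_engine(build_test_db_url()). Building the URL object itself does not load the database driver.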
Table creation and DB seeding
Now we need to recreate our database structure. Let’s assume that all models in your app are declared in the models.py file:
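```python
from sqlalchemy import INTEGER, VARCHAR, Column
from sqlalchemy.orm import declarative_base  # SQLAlchemy 1.4+; older: sqlalchemy.ext.declarative

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(INTEGER, primary_key=True)
    name = Column(VARCHAR(64))
    # ...
```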
SQLAlchemy’s MetaData object offers two methods to create and drop all tables declared in the schema: create_all and drop_all. We will use them at the beginning of the test suite execution to ensure that all tables are in place. After a full test run, we will drop all tables so that the next execution can start with a clean slate.
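Here is a standalone sketch of what create_all and drop_all do, checked with SQLAlchemy’s inspector. It swaps MySQL for an in-memory SQLite database and uses StaticPool so every step shares one connection, and therefore one database. Both substitutions are assumptions for the demo only.

```python
from sqlalchemy import Column, Integer, String, create_engine, inspect
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import StaticPool

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String(64))


# In-memory SQLite stands in for the test database.
engine = create_engine("sqlite://", poolclass=StaticPool)

Base.metadata.create_all(bind=engine)  # emits CREATE TABLE users
tables_after_create = inspect(engine).get_table_names()

Base.metadata.drop_all(bind=engine)  # emits DROP TABLE users
tables_after_drop = inspect(engine).get_table_names()

print(tables_after_create, tables_after_drop)  # ['users'] []
```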
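```python
import models
import pytest
from models import User
from sqlalchemy.orm import Session


@pytest.fixture(scope="session")
def setup_database(connection):
    # Run DDL and seeding through the engine so they are committed right away;
    # the shared connection stays free for the per-test transactions.
    engine = connection.engine
    models.Base.metadata.create_all(bind=engine)
    with Session(bind=engine) as session:
        seed_database(session)
    yield
    models.Base.metadata.drop_all(bind=engine)
```
If you need the database to be pre-configured with some data, you can run a function that seeds it, passing in the session created by setup_database. A simple example would be as follows:
```python
def seed_database(session):
    users = [
        {
            "id": 1,
            "name": "John Doe",
        },
        # ...
    ]
    for user in users:
        db_user = User(**user)
        session.add(db_user)
    session.commit()
```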
Wrapping tests in transactions
As a final step, we need a way to use transactions in our test suite. We will build a fixture that begins a new transaction before each test and rolls it back when the test finishes.
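```python
from sqlalchemy.orm import scoped_session, sessionmaker


@pytest.fixture
def db_session(setup_database, connection):
    # Begin an outer transaction; the session joins it, so session.commit()
    # inside a test does not commit this transaction.
    transaction = connection.begin()
    session = scoped_session(sessionmaker(autoflush=False, bind=connection))
    yield session
    # Close the session, then discard everything the test wrote.
    session.remove()
    transaction.rollback()
```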
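To see why the final rollback discards even data that a test has committed, here is a standalone sketch that replays the fixture’s steps by hand. It has a few assumptions:

- It uses in-memory SQLite instead of MySQL, with StaticPool so all steps share one connection.
- It targets SQLAlchemy 1.4 or 2.0.
- count_users is a hypothetical helper.

```python
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.pool import StaticPool

Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String(64))


def count_users(executor):
    # Works with both a Session and a Connection.
    return executor.scalar(select(func.count()).select_from(User))


engine = create_engine("sqlite://", poolclass=StaticPool)
Base.metadata.create_all(bind=engine)

connection = engine.connect()
transaction = connection.begin()  # what db_session does before the test
session = Session(bind=connection)

session.add(User(name="Jane Doe"))  # the body of a test
session.commit()  # does not commit the outer transaction
inside = count_users(session)

session.close()
transaction.rollback()  # what db_session does after the test
after = count_users(connection)
connection.close()
```

The first count runs inside the still-open outer transaction and sees the new row. The second runs after the rollback and finds an empty table.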
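You can then inject the fixture into your test cases. The session joins a transaction that was begun on the connection outside of it, so calling commit() in a test does not commit that outer transaction. The rollback at the end of each test therefore wipes out all data the test created, keeping test cases isolated.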
```python
from models import User


def test_user_created(db_session):
    # ...
    db_session.add(User(name="Jane Doe"))
    db_session.commit()
    # ...
```
Summary
It doesn’t take much time to set up working transactions with Pytest and SQLAlchemy once you know how to do it. I hope that you will find this solution neat and easy to use. Of course, if you have any ideas to share or improvements to propose, let me know in the comments — I’m always happy to learn something new.
For future reference, below is the full code that you can reuse in your projects (it assumes SQLAlchemy 1.4 or newer). Good luck, and have fun writing tests with transactions!
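```python
import os

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, scoped_session, sessionmaker

import models
from models import User


@pytest.fixture(scope="session")
def connection():
    engine = create_engine(
        "mysql+mysqldb://{}:{}@{}:{}/{}".format(
            os.environ.get('TEST_DB_USER'),
            os.environ.get('TEST_DB_PASSWORD'),
            os.environ.get('TEST_DB_HOST'),
            os.environ.get('TEST_DB_PORT'),
            os.environ.get('TEST_DB_NAME'),
        )
    )
    connection = engine.connect()
    yield connection
    connection.close()
    engine.dispose()


def seed_database(session):
    users = [
        {
            "id": 1,
            "name": "John Doe",
        },
        # ...
    ]
    for user in users:
        db_user = User(**user)
        session.add(db_user)
    session.commit()


@pytest.fixture(scope="session")
def setup_database(connection):
    # Commit DDL and seed data through the engine; the shared connection
    # stays free for the per-test transactions below.
    engine = connection.engine
    models.Base.metadata.create_all(bind=engine)
    with Session(bind=engine) as session:
        seed_database(session)
    yield
    models.Base.metadata.drop_all(bind=engine)


@pytest.fixture
def db_session(setup_database, connection):
    # Outer transaction: session.commit() inside a test does not commit it.
    transaction = connection.begin()
    session = scoped_session(sessionmaker(autoflush=False, bind=connection))
    yield session
    session.remove()
    # Undo everything the test wrote.
    transaction.rollback()
```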