Handling "InterfaceError: another operation is in progress" with Async SQLAlchemy and FastAPI

48 Views Asked by At

I'm developing a FastAPI application where I use SQLModel with the asyncpg driver for asynchronous database operations. Despite following the asynchronous patterns and ensuring proper await usage on database calls, I encounter the following error during my pytest tests:

InterfaceError: cannot perform operation: another operation is in progress

This error arises when executing database operations, seemingly due to concurrent access or overlapping database transactions. I've tried ensuring that each test and request uses its own AsyncSession and that all sessions and transactions are properly closed and committed.

import random
import string

import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport

from main import app  # Make sure this import points to your FastAPI app instance


@pytest_asyncio.fixture
async def client():
    """Yield an httpx ``AsyncClient`` that calls the FastAPI ``app`` in-process.

    Requests go through ``ASGITransport`` straight into the ASGI app; the
    ``http://test`` base URL is only a placeholder host.

    NOTE(review): httpx's ASGITransport is not documented to run the app's
    lifespan startup/shutdown, so ``init_db()`` presumably never runs under
    this fixture — confirm the tables already exist before the tests run.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http_client:
        yield http_client


@pytest.fixture
def generate_random_phone_number():
    """Provide a factory that builds random digit-only phone numbers.

    This is a plain synchronous fixture, so it uses ``pytest.fixture``;
    ``pytest_asyncio.fixture`` is intended for coroutine / async-generator
    fixtures.

    Returns:
        ``_generate(length=10)``, which returns a string of ``length``
        random decimal digits.
    """
    def _generate(length=10):
        return "".join(random.choices(string.digits, k=length))

    return _generate


@pytest.fixture
def generate_random_phone_prefix():
    """Provide a factory that builds random international dialing prefixes.

    This is a plain synchronous fixture, so it uses ``pytest.fixture``
    rather than ``pytest_asyncio.fixture``.

    Returns:
        ``_generate()``, which returns ``"+"`` followed by 1 to 3 random
        decimal digits.
    """
    def _generate():
        prefix_length = random.randint(1, 3)
        digits = "".join(random.choices(string.digits, k=prefix_length))
        return f"+{digits}"

    return _generate


@pytest.mark.asyncio
async def test_create_user(client: AsyncClient, generate_random_phone_number, generate_random_phone_prefix):
    """A new user is created with 201 and echoed back with the submitted phone fields."""
    payload = dict(
        phone_number=generate_random_phone_number(),
        phone_prefix=generate_random_phone_prefix(),
    )

    response = await client.post("/api/user/", json=payload)
    assert response.status_code == 201

    created = response.json()
    # Every submitted field must come back unchanged.
    for field_name, expected_value in payload.items():
        assert created[field_name] == expected_value


@pytest.mark.asyncio
async def test_duplicate_user(client: AsyncClient, generate_random_phone_number, generate_random_phone_prefix):
    """Posting the same phone number and prefix twice must fail the second time with 400."""
    user_data = {
        "phone_number": generate_random_phone_number(),
        "phone_prefix": generate_random_phone_prefix(),
    }

    first_response = await client.post("/api/user/", json=user_data)
    # The duplicate check below only means something if the first insert succeeded;
    # otherwise a 400 (or any other failure) could come from an unrelated error.
    assert first_response.status_code == 201

    response = await client.post("/api/user/", json=user_data)

    assert response.status_code == 400
    data = response.json()
    assert data["detail"] == "A user with the given phone number and prefix already exists."


@pytest.mark.asyncio
async def test_create_interest(client: AsyncClient, generate_random_phone_number, generate_random_phone_prefix):
    """An interest posted on behalf of an existing user is created and echoed back."""
    # Arrange: register the user who will own the interest.
    owner_payload = {
        "phone_number": generate_random_phone_number(),
        "phone_prefix": generate_random_phone_prefix(),
    }
    owner_response = await client.post("/api/user/", json=owner_payload)
    assert owner_response.status_code == 201
    owner_id = owner_response.json()["id"]

    interest_payload = {"topic": "Sample Topic", "found": 1, "search": True}

    # Act: create the interest, identifying the owner through the User-ID header.
    response = await client.post(
        "/api/interest/",
        json=interest_payload,
        headers={"User-ID": str(owner_id)},
    )

    # Assert: created, and the submitted fields round-trip unchanged.
    assert response.status_code == 201
    created = response.json()
    for key in ("topic", "found", "search"):
        assert created[key] == interest_payload[key]


@pytest.mark.asyncio
async def test_get_interest(client: AsyncClient, generate_random_phone_number, generate_random_phone_prefix):
    """Both interests created for a user are returned by GET /api/interest/."""
    user_data = {
        "phone_number": generate_random_phone_number(),
        "phone_prefix": generate_random_phone_prefix(),
    }
    user_response = await client.post("/api/user/", json=user_data)
    assert user_response.status_code == 201
    user = user_response.json()
    headers = {"User-ID": str(user["id"])}

    # Listed in ascending topic order so they line up with the sorted response below.
    expected_interests = [
        {"topic": "Sample Topic1", "found": 1, "search": True},
        {"topic": "Sample Topic2", "found": 0, "search": True},
    ]
    for interest_data in expected_interests:
        response = await client.post("/api/interest/", json=interest_data, headers=headers)
        assert response.status_code == 201

    response = await client.get("/api/interest/", headers=headers)
    assert response.status_code == 200
    # Compare independently of the order the rows come back in.
    interests = sorted(response.json(), key=lambda item: item["topic"])
    assert len(interests) == len(expected_interests)

    for actual, expected in zip(interests, expected_interests):
        assert actual["topic"] == expected["topic"]
        assert actual["found"] == expected["found"]
        assert actual["search"] is expected["search"]
        assert "created_at" in actual
        assert "updated_at" in actual
        # Both are ISO-8601 strings in the same format, so string order is time order.
        assert actual["created_at"] <= actual["updated_at"]

I've also ensured that my AsyncClient for testing is properly set up and that each test function is marked with @pytest.mark.asyncio to run in an async context.

I'm looking for insights or solutions to properly handle this error and ensure that my asynchronous database operations don't conflict with each other.

Update: here is all of the code. Endpoints:

@asynccontextmanager
async def lifespan(_app: "FastAPI"):
    """Run application startup work for FastAPI.

    Starlette invokes ``lifespan(app)`` and enters the result as an async
    context manager. The function must therefore accept the application
    argument and ``yield`` exactly once: code before the ``yield`` runs at
    startup, and code after it runs at shutdown. The previous version took no
    argument and never yielded, so entering it failed and the app could not
    start.

    Startup: create the database tables via ``init_db()``.
    """
    await init_db()
    # Hand control back to the server; requests are served until shutdown.
    yield


allowed_origins = [
    "http://127.0.0.1:5173",
]

app = FastAPI(lifespan=lifespan)

# CORS: only the origins listed above may call this API from a browser, with
# credentials allowed and no restriction on HTTP method or request headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def user_id_from_header(user_id: str = Header(...)) -> uuid.UUID:
    """Read the caller's user id from the ``User-ID`` request header.

    The endpoints that depend on this annotate the result as ``uuid.UUID``,
    so the header value is parsed here. A malformed id is rejected with a
    400 instead of being passed to the database as an arbitrary string.

    Returns:
        The header value parsed as a ``uuid.UUID``.

    Raises:
        HTTPException: 400 if the header is empty or is not a valid UUID.
    """
    if not user_id:
        raise HTTPException(status_code=400, detail="User-ID is missing")
    try:
        return uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="User-ID is not a valid UUID") from None


# Health
@app.get("/")
async def health():
    """Liveness probe that always reports the service as healthy."""
    payload = {"health": "ok"}
    return payload


# User
@app.get("/api/user/{user_id}", status_code=status.HTTP_200_OK, response_model=UserRead)
async def get_user(*, session: AsyncSession = Depends(get_session), user_id: uuid.UUID):
    """Fetch one user by primary key.

    Raises:
        HTTPException: 404 when no user has the given id.
    """
    found = await session.get(User, user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="User not found")


@app.post("/api/user/", response_model=UserRead, status_code=status.HTTP_201_CREATED)
async def create_user(*, user_create: UserCreate, session: AsyncSession = Depends(get_session)):
    """Create a user, rejecting duplicate (phone_number, phone_prefix) pairs.

    Returns:
        The persisted user, serialized through ``UserRead`` with status 201.

    Raises:
        HTTPException: 400 if a user with the same phone number and prefix
            already exists, whether it is found up front or the insert hits
            the table's unique constraint.
    """
    # Local import: only needed to recognise the constraint violation below.
    from sqlalchemy.exc import IntegrityError

    duplicate_detail = "A user with the given phone number and prefix already exists."

    existing_user = await session.exec(
        select(User).where(
            User.phone_number == user_create.phone_number,
            User.phone_prefix == user_create.phone_prefix
        )
    )
    # Fast path: reject an already-existing pair without attempting an INSERT.
    if existing_user.first():
        raise HTTPException(status_code=400, detail=duplicate_detail)

    db_user = User.model_validate(user_create)
    session.add(db_user)
    try:
        await session.commit()
    except IntegrityError:
        # A concurrent request can insert the same pair between the SELECT
        # above and this COMMIT; the unique constraint on User then rejects
        # it. Report that as the same 400 instead of an unhandled 500.
        await session.rollback()
        raise HTTPException(status_code=400, detail=duplicate_detail) from None
    await session.refresh(db_user)
    return db_user


# Interest

@app.post("/api/interest/", response_model=InterestRead, status_code=status.HTTP_201_CREATED)
async def create_interest(*, interest_create: InterestCreate, session: AsyncSession = Depends(get_session),
                          user_id: uuid.UUID = Depends(user_id_from_header)):
    """Create an interest owned by the user named in the ``User-ID`` header.

    Returns:
        The persisted interest (refreshed after commit), serialized through
        ``InterestRead`` with status 201.

    Raises:
        HTTPException: 404 if the owning user does not exist.
    """
    owner = await session.get(User, user_id)
    if not owner:
        raise HTTPException(status_code=404, detail="User not found")

    fields = interest_create.model_dump()
    new_interest = Interest(**fields, user=owner)
    session.add(new_interest)
    await session.commit()
    # Reload so database-generated values (e.g. the integer id) are populated.
    await session.refresh(new_interest)
    return new_interest


@app.get("/api/interest/", response_model=List[InterestRead])
async def read_interests(*, user_id: uuid.UUID = Depends(user_id_from_header),
                         session: AsyncSession = Depends(get_session)):
    """Return every interest owned by the user named in the ``User-ID`` header.

    Rows are ordered by the interest's primary key, so clients receive them
    in a stable order.

    Returns:
        A list of ``Interest`` rows (possibly empty), serialized through
        ``InterestRead``.
    """
    statement = (
        select(Interest)
        .where(Interest.user_id == user_id)
        # Without ORDER BY the database may return rows in any order.
        .order_by(Interest.id)
    )
    results = await session.exec(statement)
    # Return a concrete list rather than the result object from session.exec,
    # so serialization does not depend on iterating a live result.
    return results.all()

DB connector:

# Module-level async engine; every session created in get_session binds to it
# and therefore draws connections from its pool.
# NOTE(review): asyncpg connections belong to the event loop that opened them.
# If each pytest-asyncio test runs on its own loop, connections pooled during an
# earlier test may be reused on a new loop — a likely source of "cannot perform
# operation: another operation is in progress". Consider disposing this engine
# between tests or running all tests on one loop — confirm.
# NOTE(review): the credentials in this URL look redacted/garbled; load the DSN
# from configuration instead of hard-coding it.
engine = create_async_engine(
    "postgresql+asyncpg://xxxxxxxxx:[email protected]:5432/dobotsvc",
    echo=True,  # log every SQL statement the engine emits
    future=True,
)


async def init_db():
    """Create every table registered on ``SQLModel.metadata``.

    Runs inside ``engine.begin()``, so the DDL executes in a single
    transaction that is committed when the block exits. To rebuild the
    schema from scratch, ``SQLModel.metadata.drop_all`` can be run the same
    way before ``create_all``.
    """
    async with engine.begin() as connection:
        # The metadata API is synchronous; run_sync drives it over the async connection.
        await connection.run_sync(SQLModel.metadata.create_all)


async def get_session() -> AsyncSession:
    """FastAPI dependency that yields one ``AsyncSession`` per request.

    ``expire_on_commit=False`` keeps loaded attribute values after
    ``commit()``. The session is closed when the ``async with`` block exits
    after the request is done with it.

    NOTE(review): this is an async generator, so the accurate return
    annotation would be ``AsyncIterator[AsyncSession]``.
    """
    session_factory = sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session

And the models:

class UserBase(SQLModel):
    # Fields shared by the User table model and the UserRead/UserCreate schemas.
    # NOTE(review): UserCreate inherits ``id`` (with a uuid4 default_factory) and
    # create_user passes it through User.model_validate, so clients can choose
    # their own primary key on create — confirm that is intended.
    id: UUID = Field(default_factory=uuid4, primary_key=True)
    phone_number: str = Field(max_length=255)
    phone_prefix: str = Field(max_length=10)


class User(UserBase, table=True):
    # Persisted user row; a (phone_number, phone_prefix) pair may appear only once.
    __table_args__ = (
        # NOTE(review): "numbe" looks like a typo for "number"; renaming the
        # constraint would need a migration for existing databases.
        UniqueConstraint("phone_number", "phone_prefix", name="phone_numbe_phone_prefix_constraint"),
    )
    # Timezone-aware timestamp defaulting to the current UTC time at construction.
    registered_at: datetime = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False),
                                    default_factory=lambda: datetime.now(timezone.utc))
    # One-to-many link to this user's interests; Interest.user is the other side.
    interests: List["Interest"] = Relationship(back_populates="user")


class UserRead(UserBase):
    # Response schema for user endpoints: exactly the UserBase fields.
    pass


class UserCreate(UserBase):
    # Request body for POST /api/user/: the UserBase fields, including ``id``.
    pass


class InterestBase(SQLModel):
    # Fields shared by the Interest table model and its create/read schemas.
    id: Optional[int] = Field(default=None, primary_key=True)
    topic: str = Field(max_length=100)
    found: int = 0
    search: bool = Field(default=False)
    # Both timestamps are timezone-aware and default to the current UTC time
    # when the object is constructed.
    # NOTE(review): nothing here updates updated_at when a row is modified (no
    # onupdate is configured); it only changes if code sets it explicitly.
    # NOTE(review): these sa_column Column objects are declared on a shared base
    # class; a Column can belong to only one table, which holds while Interest is
    # the only table=True subclass — confirm before adding another.
    created_at: datetime = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False),
                                 default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False),
                                 default_factory=lambda: datetime.now(timezone.utc))


class Interest(InterestBase, table=True):
    # Persisted interest row, owned by exactly one user.
    user_id: UUID = Field(foreign_key="user.id")
    # Many-to-one side of User.interests.
    user: User = Relationship(back_populates="interests")

    # One-to-many link to the proposals made for this interest (Proposal.interest).
    proposals: List["Proposal"] = Relationship(back_populates="interest")


class InterestCreate(InterestBase):
    # Request body for POST /api/interest/: the InterestBase fields.
    # NOTE(review): this includes id/created_at/updated_at, and create_interest
    # forwards model_dump() into Interest, so clients can set them — confirm intent.
    pass


class InterestRead(InterestBase):
    # Response schema for interest endpoints: exactly the InterestBase fields.
    pass


class ProposalBase(SQLModel):
    # Fields shared by the Proposal table model.
    id: Optional[int] = Field(default=None, primary_key=True)
    interest_id: int = Field(foreign_key="interest.id")
    # Timezone-aware column, consistent with the other timestamps in these models.
    # The default_factory yields aware UTC datetimes, which asyncpg refuses to
    # bind to a naive (TIMESTAMP WITHOUT TIME ZONE) column.
    # NOTE(review): an existing "proposal" table keeps its old column type;
    # create_all does not alter tables, so a migration is needed there.
    created_at: datetime = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False),
                                 default_factory=lambda: datetime.now(timezone.utc))
    text: str


class Proposal(ProposalBase, table=True):
    # Persisted proposal row; many-to-one side of Interest.proposals.
    interest: Interest = Relationship(back_populates="proposals")
0

There are 0 best solutions below