First, you should not load your model every time a request arrives, but rather load it once at startup (you could use the startup event for this) and store it on the app instance, using the generic app.state attribute (see the implementation of State too), which you can later retrieve, as described here and here. For instance:
from fastapi import Request
import torch

@app.on_event("startup")
async def startup_event():
    app.state.model = torch.load('<model_path>')
Second, if you do not have any async functions inside your endpoint that you have to await, you could define your endpoint with def instead of async def. In this way, FastAPI will process the requests concurrently, as each request will run in a separate thread from an external threadpool; whereas async def endpoints run on the main thread, i.e., the server processes such requests sequentially, as long as there is no await call inside those routes. The await keyword passes function control back to the event loop, thus allowing other tasks/requests in the event loop to run. Please have a look at the answers here and here, as well as all the references included in them, to understand the concept of async/await, as well as the difference between using def and async def. Example with a def endpoint:
@app.post("/")
def your_endpoint(request: Request):
    model = request.app.state.model
    # run your synchronous ask_query() function here
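To see why def endpoints can serve blocking work concurrently, here is a minimal, self-contained sketch (plain asyncio, no FastAPI) comparing blocking calls dispatched to a threadpool, as FastAPI does for def endpoints, against blocking calls made directly on the event loop; blocking_task is a hypothetical stand-in for a model inference call:

```python
import asyncio
import time

def blocking_task():
    # hypothetical stand-in for a blocking call, e.g., model inference
    time.sleep(0.2)

async def run_in_threads(n):
    # what FastAPI does for `def` endpoints: each call runs in a worker thread,
    # so the event loop stays free and the calls overlap
    loop = asyncio.get_running_loop()
    await asyncio.gather(*(loop.run_in_executor(None, blocking_task) for _ in range(n)))

async def run_on_loop(n):
    # a blocking call made directly inside an `async def` coroutine: the loop
    # is blocked, so the calls run one after another
    for _ in range(n):
        blocking_task()

start = time.perf_counter()
asyncio.run(run_in_threads(3))
threaded = time.perf_counter() - start

start = time.perf_counter()
asyncio.run(run_on_loop(3))
sequential = time.perf_counter() - start

print(f"threaded: {threaded:.2f}s, sequential: {sequential:.2f}s")
```

The threaded variant takes roughly the time of a single call, while the on-loop variant takes roughly n times as long.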
Alternatively, as described here, you could, preferably, run your CPU-bound task in a separate process, using ProcessPoolExecutor, and integrate it with asyncio, in order to await it to finish its work and return the result(s); in this case, you would need to define your endpoint with async def, as the await keyword only works within an async function. Beware that it is important to protect the main loop of code to avoid recursive spawning of subprocesses, etc.; that is, your code must be under if __name__ == '__main__'. Example:
from fastapi import FastAPI, Request
from pydantic import BaseModel
import concurrent.futures
import asyncio
import uvicorn
import torch


class Item(BaseModel):
    # request body schema (assumed, as the original snippet
    # referenced item.text and item.topN)
    text: str
    topN: int


class MyAIClass:
    def ask_query(self, model, query, topN):
        # ...
        pass


ai = MyAIClass()
app = FastAPI()


@app.on_event("startup")
async def startup_event():
    app.state.model = torch.load('<model_path>')


@app.post("/")
async def your_endpoint(request: Request, item: Item):
    model = request.app.state.model
    loop = asyncio.get_running_loop()
    with concurrent.futures.ProcessPoolExecutor() as pool:
        res = await loop.run_in_executor(pool, ai.ask_query, model, item.text, item.topN)
    return res


if __name__ == '__main__':
    uvicorn.run(app)
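As a standalone illustration of the same pattern outside FastAPI, the sketch below offloads a CPU-bound function to a ProcessPoolExecutor and awaits it from an async function; cpu_bound here is a hypothetical stand-in for the ask_query() call:

```python
import asyncio
import concurrent.futures

def cpu_bound(n):
    # hypothetical stand-in for ask_query(): keeps a CPU core busy
    return sum(i * i for i in range(n))

async def main():
    loop = asyncio.get_running_loop()
    # offload to a separate process, so the event loop stays responsive
    with concurrent.futures.ProcessPoolExecutor() as pool:
        result = await loop.run_in_executor(pool, cpu_bound, 10_000)
    return result

if __name__ == '__main__':
    print(asyncio.run(main()))
```

Note the if __name__ == '__main__' guard: without it, spawning the worker processes could re-import the module and recursively spawn subprocesses.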
Note that if you plan on having several workers active at the same time, keep in mind that workers do not share the same memory; hence, each worker will load its own instance of the ML model into memory (RAM). If, for instance, you are using four workers for your app, the model will be loaded into RAM four times. Thus, if the model, as well as other variables in your code, consume a large amount of memory, each process/worker will consume an equivalent amount. If you would like to avoid that, you may have a look at how to share objects across multiple workers, and, if you are using Gunicorn as a process manager with Uvicorn workers, you can use Gunicorn's --preload flag. As per the documentation:
Command line:
--preload
Default:
False
Load application code before the worker processes are forked.
By preloading an application you can save some RAM resources as well
as speed up server boot times. Although, if you defer application
loading to each worker process, you can reload your application code
easily by restarting workers.
Example:
gunicorn --workers 4 --preload --worker-class=uvicorn.workers.UvicornWorker app:app
Note that you cannot combine Gunicorn's --preload flag with --reload: when the code is preloaded into the master process, the new worker processes (which are automatically created when your application code changes) will still have the old code in memory, due to how fork() works.