add limiter to endpoint

This commit is contained in:
MuslemRahimi 2024-07-19 09:21:56 +02:00
parent abbe1b57d0
commit ece28efd52

View File

@ -15,7 +15,7 @@ import aiohttp
import pytz
import redis
from dotenv import load_dotenv
from pydantic import BaseModel
from pydantic import BaseModel, Field
from benzinga import financial_data
# Database related imports
@ -24,13 +24,17 @@ from contextlib import contextmanager
from pocketbase import PocketBase
# FastAPI and related imports
from fastapi import FastAPI, Depends, HTTPException, Security, status
from fastapi import FastAPI, Depends, HTTPException, Security, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.security.api_key import APIKeyHeader
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
# DB constants & context manager
@ -123,7 +127,10 @@ api_key = os.getenv('FMP_API_KEY')
benzinga_key = os.getenv('BENZINGA_API_KEY')
fin = financial_data.Benzinga(benzinga_key)
app = FastAPI(docs_url=None, redoc_url=None, openapi_url = None)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
origins = ["http://www.stocknear.com","https://www.stocknear.com","http://stocknear.com","https://stocknear.com","http://localhost:5173","http://localhost:4173"]
@ -159,6 +166,13 @@ STOCKNEAR_API_KEY = os.getenv('STOCKNEAR_API_KEY')
api_key_header = APIKeyHeader(name="X-API-KEY")
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"}
)
async def get_api_key(api_key: str = Security(api_key_header)):
if api_key != STOCKNEAR_API_KEY:
raise HTTPException(status_code=403, detail="Could not validate credentials")
@ -177,6 +191,12 @@ async def openapi(username: str = Depends(get_current_username), api_key: str =
class TickerData(BaseModel):
ticker: str
class OptionsFlowData(BaseModel):
ticker: str = ''
start_date: str = ''
end_date: str = ''
pagesize: int = Field(default=1000)
page: int = Field(default=0)
class HistoricalPrice(BaseModel):
ticker: str
@ -2473,9 +2493,14 @@ async def get_options_plot_ticker(data:TickerData, api_key: str = Security(get_a
#api endpoint not for website but for user
@app.post("/raw-options-flow-ticker")
async def get_options_flow_ticker(data:TickerData, api_key: str = Security(get_api_key)):
@limiter.limit("100/minute")
async def get_raw_options_flow_ticker(data:OptionsFlowData, request: Request, api_key: str = Security(get_api_key)):
ticker = data.ticker.upper()
cache_key = f"options-flow-{ticker}"
start_date = data.start_date
end_date = data.end_date
pagesize = data.pagesize
page = data.page
cache_key = f"raw-options-flow-{ticker}-{start_date}-{end_date}-{pagesize}-{page}"
cached_result = redis_client.get(cache_key)
if cached_result:
@ -2484,7 +2509,7 @@ async def get_options_flow_ticker(data:TickerData, api_key: str = Security(get_a
media_type="application/json",
headers={"Content-Encoding": "gzip"})
try:
data = fin.options_activity(company_tickers=ticker, pagesize=500)
data = fin.options_activity(company_tickers=ticker, date_from=start_date, date_to = end_date, page=page, pagesize=pagesize)
data = orjson.loads(fin.output(data))['option_activity']
except:
data = []