add limiter to endpoint

This commit is contained in:
MuslemRahimi 2024-07-19 09:21:56 +02:00
parent abbe1b57d0
commit ece28efd52

View File

@ -15,7 +15,7 @@ import aiohttp
import pytz import pytz
import redis import redis
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel from pydantic import BaseModel, Field
from benzinga import financial_data from benzinga import financial_data
# Database related imports # Database related imports
@ -24,13 +24,17 @@ from contextlib import contextmanager
from pocketbase import PocketBase from pocketbase import PocketBase
# FastAPI and related imports # FastAPI and related imports
from fastapi import FastAPI, Depends, HTTPException, Security, status from fastapi import FastAPI, Depends, HTTPException, Security, Request, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.security.api_key import APIKeyHeader from fastapi.security.api_key import APIKeyHeader
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
# DB constants & context manager # DB constants & context manager
@ -123,7 +127,10 @@ api_key = os.getenv('FMP_API_KEY')
benzinga_key = os.getenv('BENZINGA_API_KEY') benzinga_key = os.getenv('BENZINGA_API_KEY')
fin = financial_data.Benzinga(benzinga_key) fin = financial_data.Benzinga(benzinga_key)
app = FastAPI(docs_url=None, redoc_url=None, openapi_url = None) app = FastAPI(docs_url=None, redoc_url=None, openapi_url = None)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
origins = ["http://www.stocknear.com","https://www.stocknear.com","http://stocknear.com","https://stocknear.com","http://localhost:5173","http://localhost:4173"] origins = ["http://www.stocknear.com","https://www.stocknear.com","http://stocknear.com","https://stocknear.com","http://localhost:5173","http://localhost:4173"]
@ -159,6 +166,13 @@ STOCKNEAR_API_KEY = os.getenv('STOCKNEAR_API_KEY')
api_key_header = APIKeyHeader(name="X-API-KEY") api_key_header = APIKeyHeader(name="X-API-KEY")
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
return JSONResponse(
status_code=429,
content={"detail": "Rate limit exceeded"}
)
async def get_api_key(api_key: str = Security(api_key_header)): async def get_api_key(api_key: str = Security(api_key_header)):
if api_key != STOCKNEAR_API_KEY: if api_key != STOCKNEAR_API_KEY:
raise HTTPException(status_code=403, detail="Could not validate credentials") raise HTTPException(status_code=403, detail="Could not validate credentials")
@ -177,6 +191,12 @@ async def openapi(username: str = Depends(get_current_username), api_key: str =
class TickerData(BaseModel): class TickerData(BaseModel):
ticker: str ticker: str
class OptionsFlowData(BaseModel):
ticker: str = ''
start_date: str = ''
end_date: str = ''
pagesize: int = Field(default=1000)
page: int = Field(default=0)
class HistoricalPrice(BaseModel): class HistoricalPrice(BaseModel):
ticker: str ticker: str
@ -2473,9 +2493,14 @@ async def get_options_plot_ticker(data:TickerData, api_key: str = Security(get_a
#api endpoint not for website but for user #api endpoint not for website but for user
@app.post("/raw-options-flow-ticker") @app.post("/raw-options-flow-ticker")
async def get_options_flow_ticker(data:TickerData, api_key: str = Security(get_api_key)): @limiter.limit("100/minute")
async def get_raw_options_flow_ticker(data:OptionsFlowData, request: Request, api_key: str = Security(get_api_key)):
ticker = data.ticker.upper() ticker = data.ticker.upper()
cache_key = f"options-flow-{ticker}" start_date = data.start_date
end_date = data.end_date
pagesize = data.pagesize
page = data.page
cache_key = f"raw-options-flow-{ticker}-{start_date}-{end_date}-{pagesize}-{page}"
cached_result = redis_client.get(cache_key) cached_result = redis_client.get(cache_key)
if cached_result: if cached_result:
@ -2484,7 +2509,7 @@ async def get_options_flow_ticker(data:TickerData, api_key: str = Security(get_a
media_type="application/json", media_type="application/json",
headers={"Content-Encoding": "gzip"}) headers={"Content-Encoding": "gzip"})
try: try:
data = fin.options_activity(company_tickers=ticker, pagesize=500) data = fin.options_activity(company_tickers=ticker, date_from=start_date, date_to = end_date, page=page, pagesize=pagesize)
data = orjson.loads(fin.output(data))['option_activity'] data = orjson.loads(fin.output(data))['option_activity']
except: except:
data = [] data = []