diff --git a/app/main.py b/app/main.py index a84a4df..c3f0c40 100755 --- a/app/main.py +++ b/app/main.py @@ -15,7 +15,7 @@ import aiohttp import pytz import redis from dotenv import load_dotenv -from pydantic import BaseModel +from pydantic import BaseModel, Field from benzinga import financial_data # Database related imports @@ -24,13 +24,17 @@ from contextlib import contextmanager from pocketbase import PocketBase # FastAPI and related imports -from fastapi import FastAPI, Depends, HTTPException, Security, status +from fastapi import FastAPI, Depends, HTTPException, Security, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.docs import get_swagger_ui_html from fastapi.openapi.utils import get_openapi from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security.api_key import APIKeyHeader -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse + +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded # DB constants & context manager @@ -123,7 +127,10 @@ api_key = os.getenv('FMP_API_KEY') benzinga_key = os.getenv('BENZINGA_API_KEY') fin = financial_data.Benzinga(benzinga_key) + app = FastAPI(docs_url=None, redoc_url=None, openapi_url = None) +limiter = Limiter(key_func=get_remote_address) +app.state.limiter = limiter origins = ["http://www.stocknear.com","https://www.stocknear.com","http://stocknear.com","https://stocknear.com","http://localhost:5173","http://localhost:4173"] @@ -159,6 +166,13 @@ STOCKNEAR_API_KEY = os.getenv('STOCKNEAR_API_KEY') api_key_header = APIKeyHeader(name="X-API-KEY") +@app.exception_handler(RateLimitExceeded) +async def rate_limit_handler(request: Request, exc: RateLimitExceeded): + return JSONResponse( + status_code=429, + content={"detail": "Rate limit exceeded"} + ) + async def get_api_key(api_key: str = Security(api_key_header)): if api_key != STOCKNEAR_API_KEY: raise HTTPException(status_code=403, detail="Could not validate credentials") @@ -177,6 +191,12 @@ async def openapi(username: str = Depends(get_current_username), api_key: str = class TickerData(BaseModel): ticker: str +class OptionsFlowData(BaseModel): + ticker: str = '' + start_date: str = '' + end_date: str = '' + pagesize: int = Field(default=1000) + page: int = Field(default=0) class HistoricalPrice(BaseModel): ticker: str @@ -2473,9 +2493,14 @@ async def get_options_plot_ticker(data:TickerData, api_key: str = Security(get_a #api endpoint not for website but for user @app.post("/raw-options-flow-ticker") -async def get_options_flow_ticker(data:TickerData, api_key: str = Security(get_api_key)): +@limiter.limit("100/minute") +async def get_raw_options_flow_ticker(data:OptionsFlowData, request: Request, api_key: str = Security(get_api_key)): ticker = data.ticker.upper() - cache_key = f"options-flow-{ticker}" + start_date = data.start_date + end_date = data.end_date + pagesize = data.pagesize + page = data.page + cache_key = f"raw-options-flow-{ticker}-{start_date}-{end_date}-{pagesize}-{page}" cached_result = redis_client.get(cache_key) if cached_result: @@ -2484,7 +2509,7 @@ async def get_options_flow_ticker(data:TickerData, api_key: str = Security(get_a media_type="application/json", headers={"Content-Encoding": "gzip"}) try: - data = fin.options_activity(company_tickers=ticker, pagesize=500) + data = fin.options_activity(company_tickers=ticker, date_from=start_date, date_to = end_date, page=page, pagesize=pagesize) data = orjson.loads(fin.output(data))['option_activity'] except: data = []