diff --git a/app/main.py b/app/main.py index 881f5e0..45a5297 100755 --- a/app/main.py +++ b/app/main.py @@ -37,6 +37,7 @@ from fastapi.responses import StreamingResponse, JSONResponse from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded +from functools import partial # DB constants & context manager @@ -1255,72 +1256,116 @@ async def get_indicator(data: IndicatorListData, api_key: str = Security(get_api +async def process_watchlist_ticker(ticker, rule_of_list, quote_keys_to_include, screener_dict, etf_symbols, crypto_symbols): + """Process a single ticker concurrently.""" + ticker = ticker.upper() + ticker_type = 'stocks' + if ticker in etf_symbols: + ticker_type = 'etf' + elif ticker in crypto_symbols: + ticker_type = 'crypto' + + # Concurrent loading of quote, news, and earnings data + quote_task = load_json_async(f"json/quote/{ticker}.json") + news_task = load_json_async(f"json/market-news/companies/{ticker}.json") + earnings_task = load_json_async(f"json/earnings/next/{ticker}.json") + + quote_dict, news_dict, earnings_dict = await asyncio.gather(quote_task, news_task, earnings_task) + + result = None + news = [] + earnings = None + + if quote_dict: + # Filter quote data + filtered_quote = { + key: quote_dict.get(key) + for key in rule_of_list if key in quote_dict or key in quote_keys_to_include + } + filtered_quote['type'] = ticker_type + + # Merge with screener data + symbol = filtered_quote.get('symbol') + if symbol and symbol in screener_dict: + filtered_quote.update(screener_dict[symbol]) + + result = filtered_quote + + if news_dict: + # Remove 'image' and 'text' keys from each news item + news = [ + {key: value for key, value in item.items() if key not in ['image', 'text']} + for item in news_dict[:5] + ] + + # Prepare earnings with symbol + if earnings_dict and symbol: + earnings = {**earnings_dict, 'symbol': symbol} + + return result, news, earnings + + @app.post("/get-watchlist") async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key)): - data = data.dict() - watchlist_id = data['watchListId'] - rule_of_list = data['ruleOfList'] # Ensure this is passed as part of the request - result = pb.collection("watchlist").get_one(watchlist_id) - ticker_list = result.ticker - combined_results = [] # List to store the combined results - combined_news = [] - - # Keys that should be read from the quote files if they are in rule_of_list - quote_keys_to_include = ['volume', 'marketCap', 'changesPercentage', 'price', 'symbol', 'name'] - - # Ensure rule_of_list contains valid keys (fall back to defaults if necessary) - if not rule_of_list or not isinstance(rule_of_list, list): - rule_of_list = quote_keys_to_include # Default keys - - # Make sure 'symbol' and 'name' are always included in the rule_of_list - if 'symbol' not in rule_of_list: - rule_of_list.append('symbol') - if 'name' not in rule_of_list: - rule_of_list.append('name') - - # Categorize tickers and fetch data - for ticker in map(str.upper, ticker_list): - ticker_type = 'stocks' - if ticker in etf_symbols: - ticker_type = 'etf' - elif ticker in crypto_symbols: - ticker_type = 'crypto' - - # Load quote data and filter to include only selected keys from rule_of_list - quote_dict = load_json(f"json/quote/{ticker}.json") - if quote_dict: - filtered_quote = {key: quote_dict.get(key) for key in rule_of_list if key in quote_dict} - filtered_quote['type'] = ticker_type # Include ticker type - combined_results.append(filtered_quote) - - # Load news data - news_dict = load_json(f"json/market-news/companies/{ticker}.json") - if news_dict: - combined_news.extend(news_dict[:2]) - + """Optimized watchlist endpoint with concurrent processing and earnings data.""" + data_dict = data.dict() + watchlist_id = data_dict['watchListId'] + rule_of_list = data_dict.get('ruleOfList', []) + # Retrieve watchlist try: - # Filter out the keys that need to be fetched from the screener - screener_keys = [key for key in rule_of_list if key not in quote_keys_to_include] - - # Create a mapping of stock_screener_data based on symbol for fast lookup - screener_dict = { - item['symbol']: {key: item.get(key) for key in screener_keys if key in item} - for item in stock_screener_data - } - - # Merge the filtered stock_screener_data into combined_results for non-quote keys - for result in combined_results: - symbol = result.get('symbol') - if symbol in screener_dict: - result.update(screener_dict[symbol]) - + result = pb.collection("watchlist").get_one(watchlist_id) + ticker_list = result.ticker except Exception as e: - print(f"An error occurred while merging data: {e}") + raise HTTPException(status_code=404, detail="Watchlist not found") - res = {'data': combined_results, 'news': combined_news} - res = orjson.dumps(res) - compressed_data = gzip.compress(res) + # Default configuration + quote_keys_to_include = ['volume', 'marketCap', 'changesPercentage', 'price', 'symbol', 'name'] + + # Normalize rule_of_list + if not rule_of_list or not isinstance(rule_of_list, list): + rule_of_list = quote_keys_to_include + + rule_of_list = list(set(rule_of_list + ['symbol', 'name'])) + + # Prepare screener dictionary for fast lookup + screener_dict = { + item['symbol']: { + key: item.get(key) + for key in rule_of_list if key in item + } + for item in stock_screener_data + } + + # Process tickers concurrently + process_ticker_partial = partial( + process_watchlist_ticker, + rule_of_list=rule_of_list, + quote_keys_to_include=quote_keys_to_include, + screener_dict=screener_dict, + etf_symbols=etf_symbols, + crypto_symbols=crypto_symbols + ) + + # Use asyncio to process tickers in parallel + results_and_extras = await asyncio.gather( + *[process_ticker_partial(ticker) for ticker in ticker_list] + ) + + # Separate results, news, and earnings + combined_results = [result for result, _, _ in results_and_extras if result] + combined_news = [news_item for _, news, _ in results_and_extras for news_item in news] + combined_earnings = [earnings for _, _, earnings in results_and_extras if earnings] + + # Prepare response + res = { + 'data': combined_results, + 'news': combined_news, + 'earnings': combined_earnings + } + + print(combined_earnings) + compressed_data = gzip.compress(orjson.dumps(res)) return StreamingResponse( io.BytesIO(compressed_data), @@ -1328,6 +1373,9 @@ async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key) headers={"Content-Encoding": "gzip"} ) + + + @app.post("/get-price-alert") async def get_price_alert(data: UserId, api_key: str = Security(get_api_key)): user_id = data.dict()['userId']