diff --git a/app/cron_industry.py b/app/cron_industry.py index 7e908fa..c0de08a 100644 --- a/app/cron_industry.py +++ b/app/cron_industry.py @@ -7,7 +7,7 @@ from tqdm import tqdm import orjson from GetStartEndDate import GetStartEndDate from collections import defaultdict - +import re import os from dotenv import load_dotenv load_dotenv() @@ -27,6 +27,16 @@ def save_as_json(data, filename): with open(f"json/industry/{filename}.json", 'w') as file: ujson.dump(data, file) +def format_filename(industry_name): + # Replace spaces with hyphens + formatted_name = industry_name.replace(' ', '-') + # Replace "&" with "and" + formatted_name = formatted_name.replace('&', 'and') + # Remove any extra hyphens (e.g., from consecutive spaces) + formatted_name = re.sub(r'-+', '-', formatted_name) + # Convert to lowercase for consistency + return formatted_name.lower() + # Function to fetch data from the API async def get_data(session, class_type='sector'): @@ -38,8 +48,34 @@ async def get_data(session, class_type='sector'): data = await response.json() return data +def get_each_industry_data(): + industry_data = defaultdict(list) # Dictionary to store industries and their corresponding stock data + for stock in stock_screener_data: + industry = stock.get('industry') + if industry: # Make sure the stock has an industry defined + # Extract relevant fields + stock_data = { + 'symbol': stock.get('symbol'), + 'name': stock.get('name'), + 'changesPercentage': stock.get('changesPercentage'), + 'marketCap': stock.get('marketCap'), + 'revenue': stock.get('revenue'), + 'volume': stock.get('volume') + } + # Append stock data to the corresponding industry list + industry_data[industry].append(stock_data) + + return dict(industry_data) async def run(): + + full_industry_list = get_each_industry_data() + for industry, stocks in full_industry_list.items(): + filename = 'industries/'+format_filename(industry) + stocks = sorted(stocks, key= lambda x: x['marketCap'], reverse=True) + save_as_json(stocks, filename) + + # Initialize a dictionary to store stock count, market cap, and other totals for each industry sector_industry_data = defaultdict(lambda: defaultdict(lambda: { 'numStocks': 0, diff --git a/app/main.py b/app/main.py index 2c22780..aa8bca0 100755 --- a/app/main.py +++ b/app/main.py @@ -3663,6 +3663,37 @@ async def get_sector_overview(api_key: str = Security(get_api_key)): headers={"Content-Encoding": "gzip"} ) + +@app.post("/industry-stocks") +async def get_sector_overview(data: TickerData, api_key: str = Security(get_api_key)): + ticker = data.ticker + cache_key = f"industry-stocks-{ticker}" + cached_result = redis_client.get(cache_key) + if cached_result: + return StreamingResponse( + io.BytesIO(cached_result), + media_type="application/json", + headers={"Content-Encoding": "gzip"} + ) + try: + with open(f"json/industry/industries/{ticker}.json", 'rb') as file: + res = orjson.loads(file.read()) + except: + res = [] + + data = orjson.dumps(res) + compressed_data = gzip.compress(data) + + redis_client.set(cache_key, compressed_data) + redis_client.expire(cache_key,3600*3600) + + return StreamingResponse( + io.BytesIO(compressed_data), + media_type="application/json", + headers={"Content-Encoding": "gzip"} + ) + + @app.get("/industry-overview") async def get_industry_overview(api_key: str = Security(get_api_key)): cache_key = f"industry-overview"