diff --git a/app/main.py b/app/main.py index 52c90a3..f4e7e2b 100755 --- a/app/main.py +++ b/app/main.py @@ -63,33 +63,27 @@ def calculate_score(item: Dict, search_query: str) -> int: name_lower = item['name'].lower() symbol_lower = item['symbol'].lower() query_lower = search_query.lower() - - # Exact symbol match + + # Base priority calculations if symbol_lower == query_lower: - return PRIORITY_STRATEGIES['exact_symbol_match'] - - # Symbol prefix match - if symbol_lower.startswith(query_lower): - return PRIORITY_STRATEGIES['symbol_prefix_match'] - - # Exact name match - if name_lower == query_lower: - return PRIORITY_STRATEGIES['exact_name_match'] - - # Name prefix match - if name_lower.startswith(query_lower): - return PRIORITY_STRATEGIES['name_prefix_match'] - - # Symbol contains query - if query_lower in symbol_lower: - return PRIORITY_STRATEGIES['symbol_contains'] - - # Name contains query - if query_lower in name_lower: - return PRIORITY_STRATEGIES['name_contains'] - - # Fallback - return len(PRIORITY_STRATEGIES) + base_score = PRIORITY_STRATEGIES['exact_symbol_match'] + elif symbol_lower.startswith(query_lower): + base_score = PRIORITY_STRATEGIES['symbol_prefix_match'] + elif name_lower == query_lower: + base_score = PRIORITY_STRATEGIES['exact_name_match'] + elif name_lower.startswith(query_lower): + base_score = PRIORITY_STRATEGIES['name_prefix_match'] + elif query_lower in symbol_lower: + base_score = PRIORITY_STRATEGIES['symbol_contains'] + elif query_lower in name_lower: + base_score = PRIORITY_STRATEGIES['name_contains'] + else: + base_score = len(PRIORITY_STRATEGIES) + + # Apply penalty if the symbol contains a dot + dot_penalty = 1 if '.' in symbol_lower else 0 + + return base_score + dot_penalty @contextmanager @@ -116,13 +110,14 @@ with db_connection(STOCK_DB) as cursor: cursor.execute("SELECT DISTINCT symbol FROM stocks") symbols = [row[0] for row in cursor.fetchall()] - cursor.execute("SELECT symbol, name, type FROM stocks") + cursor.execute("SELECT symbol, name, type, marketCap FROM stocks") raw_data = cursor.fetchall() stock_list_data = [{ 'symbol': row[0], 'name': row[1], 'type': row[2].capitalize(), - } for row in raw_data] + 'marketCap': row[3], + } for row in raw_data if row[3] is not None] #------End Stocks DB------------# #------Start ETF DB------------# @@ -1804,17 +1799,24 @@ async def get_stock( # Precompile case-insensitive regex for faster matching search_pattern = re.compile(re.escape(query.lower()), re.IGNORECASE) - - # Optimized filtering and sorting - results = sorted( - ( - item for item in searchbar_data - if search_pattern.search(item['name']) or search_pattern.search(item['symbol']) - ), - key=lambda item: calculate_score(item, query) - )[:5] - return JSONResponse(content=orjson.loads(orjson.dumps(results))) + # Filter items based on the search pattern (ignore items where neither name nor symbol match) + filtered_data = [ + item for item in searchbar_data + if search_pattern.search(item['name']) or search_pattern.search(item['symbol']) + ] + + # Sort by the calculated score, giving exact symbol matches the highest priority, + # and then by descending marketCap for other matches + results = sorted( + filtered_data, + key=lambda item: ( + calculate_score(item, query), + 0 if item.get('marketCap') is None else -item['marketCap'] + ) + )[:5] + + return JSONResponse(content=orjson.loads(orjson.dumps(results)))