add random_state to model

This commit is contained in:
MuslemRahimi 2024-10-11 15:12:54 +02:00
parent c936d71e10
commit f00b4288e9
2 changed files with 50 additions and 10 deletions

View File

@ -299,6 +299,10 @@ class HeatMapData(BaseModel):
class StockScreenerData(BaseModel): class StockScreenerData(BaseModel):
ruleOfList: List[str] ruleOfList: List[str]
class IndicatorListData(BaseModel):
ruleOfList: List[str]
tickerList: List[str]
class TransactionId(BaseModel): class TransactionId(BaseModel):
transactionId: str transactionId: str
@ -1097,14 +1101,38 @@ async def get_analyst_ticke_history(data: TickerData, api_key: str = Security(ge
) )
@app.post("/indicator-data")
async def get_indicator_data(data: IndicatorListData, api_key: str = Security(get_api_key)):
rule_of_list = data.ruleOfList
ticker_list = data.tickerList
always_include = ['symbol','name','price','changesPercentage']
try:
filtered_data = [
{key: item.get(key) for key in set(always_include+rule_of_list) if key in item}
for item in stock_screener_data
if item.get('symbol') in ticker_list # Filter for specific tickers
]
except Exception as e:
print(e)
filtered_data = []
# Compress the JSON data
res = orjson.dumps(filtered_data)
compressed_data = gzip.compress(res)
return StreamingResponse(
io.BytesIO(compressed_data),
media_type="application/json",
headers={"Content-Encoding": "gzip"}
)
@app.post("/get-watchlist") @app.post("/get-watchlist")
async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key)): async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key)):
data = data.dict() data = data.dict()
watchlist_id = data['watchListId'] watchlist_id = data['watchListId']
result = pb.collection("watchlist").get_one(watchlist_id) result = pb.collection("watchlist").get_one(watchlist_id)
ticker_list = result.ticker ticker_list = result.ticker
rule_of_list = result.rule_of_list or []
print(rule_of_list)
combined_results = [] # List to store the combined results combined_results = [] # List to store the combined results
combined_news = [] combined_news = []
@ -1115,6 +1143,8 @@ async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key)
except FileNotFoundError: except FileNotFoundError:
return None return None
quote_keys_to_include = ['volume', 'marketCap', 'changesPercentage', 'price', 'symbol', 'name']
# Categorize tickers and fetch data # Categorize tickers and fetch data
for ticker in map(str.upper, ticker_list): for ticker in map(str.upper, ticker_list):
ticker_type = 'stock' ticker_type = 'stock'
@ -1123,11 +1153,12 @@ async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key)
elif ticker in crypto_symbols: elif ticker in crypto_symbols:
ticker_type = 'crypto' ticker_type = 'crypto'
# Load quote data # Load quote data and filter to include only selected keys
quote_dict = load_json(f"json/quote/{ticker}.json") quote_dict = load_json(f"json/quote/{ticker}.json")
if quote_dict: if quote_dict:
quote_dict['type'] = ticker_type filtered_quote = {key: quote_dict.get(key) for key in quote_keys_to_include}
combined_results.append(quote_dict) filtered_quote['type'] = ticker_type # Include ticker type
combined_results.append(filtered_quote)
# Load news data # Load news data
news_dict = load_json(f"json/market-news/companies/{ticker}.json") news_dict = load_json(f"json/market-news/companies/{ticker}.json")
@ -1135,12 +1166,12 @@ async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key)
combined_news.append(news_dict[0]) combined_news.append(news_dict[0])
# Keys to always include in the combined results # Keys to always include in the combined results
always_include = ['symbol', 'name', 'price', 'changesPercentage'] always_include = ['symbol', 'name', 'price', 'changesPercentage', 'marketCap','eps','pe','score','revenue','netIncome','freeCashFlow']
try: try:
# Create a mapping of stock_screener_data based on symbol for fast lookup # Create a mapping of stock_screener_data based on symbol for fast lookup
screener_dict = { screener_dict = {
item['symbol']: {key: item.get(key) for key in set(always_include + rule_of_list) if key in item} item['symbol']: {key: item.get(key) for key in set(always_include) if key in item}
for item in stock_screener_data for item in stock_screener_data
} }
@ -1152,8 +1183,16 @@ async def get_watchlist(data: GetWatchList, api_key: str = Security(get_api_key)
except Exception as e: except Exception as e:
print(f"An error occurred while merging data: {e}") print(f"An error occurred while merging data: {e}")
print(combined_results)
return [combined_results, combined_news] res = {'data': combined_results, 'news': combined_news}
res = orjson.dumps(res)
compressed_data = gzip.compress(res)
return StreamingResponse(
io.BytesIO(compressed_data),
media_type="application/json",
headers={"Content-Encoding": "gzip"}
)
def process_option_activity(item): def process_option_activity(item):

View File

@ -24,7 +24,8 @@ class ScorePredictor:
learning_rate=0.001, learning_rate=0.001,
max_depth=10, max_depth=10,
num_leaves=2**10-1, num_leaves=2**10-1,
n_jobs=10 n_jobs=10,
random_state=42
) )
self.warm_start_model_path = 'ml_models/weights/ai-score/stacking_weights.pkl' self.warm_start_model_path = 'ml_models/weights/ai-score/stacking_weights.pkl'
#self.pca = PCA(n_components=3) #self.pca = PCA(n_components=3)