clean code

commit 7100791302
parent 266016e46d
app/cron_correlation_etf.py (new file, 84 lines)
@@ -0,0 +1,84 @@
import pandas as pd
import sqlite3
from datetime import datetime, timedelta
import concurrent.futures
import json
from tqdm import tqdm
import warnings
import numpy as np
import os

warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide")


def get_stock_prices(ticker, cursor, start_date, end_date):
    query = f"""
        SELECT date, close, volume
        FROM "{ticker}"
        WHERE date BETWEEN ? AND ?
    """
    cursor.execute(query, (start_date, end_date))
    return pd.DataFrame(cursor.fetchall(), columns=['date', 'close', 'volume'])


def process_symbol(op_symbol, symbols, start_date, end_date, query_fundamental):
    with sqlite3.connect('etf.db') as con:
        con.execute("PRAGMA journal_mode = WAL")
        cursor = con.cursor()

        op_df = get_stock_prices(op_symbol, cursor, start_date, end_date)
        # Only compare against symbols trading at least half this symbol's average volume.
        avg_volume = op_df['volume'].mean() * 0.5
        correlations = {}

        for symbol in symbols:
            if symbol != op_symbol:
                try:
                    stock_df = get_stock_prices(symbol, cursor, start_date, end_date)
                    if stock_df['volume'].mean() > avg_volume:
                        correlation = np.corrcoef(op_df['close'], stock_df['close'])[0, 1]
                        correlations[symbol] = correlation
                except Exception:
                    pass

        # Keep the five most and five least correlated symbols.
        sorted_correlations = sorted(correlations.items(), key=lambda x: x[1], reverse=True)
        most_least_list = sorted_correlations[:5] + sorted_correlations[-5:]

        res_list = []
        for symbol, correlation in most_least_list:
            cursor.execute(query_fundamental, (symbol,))
            fundamental_data = cursor.fetchone()
            # Skip NaN correlations and symbols with no metadata row.
            if fundamental_data and correlation is not None and not np.isnan(correlation):
                res_list.append({
                    'symbol': symbol,
                    'name': fundamental_data[0],
                    'marketCap': int(fundamental_data[1]),
                    'value': round(correlation, 3)
                })

        # Sort by correlation and de-duplicate on symbol.
        sorted_res = sorted(res_list, key=lambda x: x['value'], reverse=True)
        res_list = list({d['symbol']: d for d in sorted_res}.values())

        if res_list:
            os.makedirs("json/correlation/companies", exist_ok=True)
            with open(f"json/correlation/companies/{op_symbol}.json", 'w') as file:
                json.dump(res_list, file)


def main():
    query_fundamental = "SELECT name, marketCap FROM etfs WHERE symbol = ?"

    with sqlite3.connect('etf.db') as con:
        con.execute("PRAGMA journal_mode = WAL")
        cursor = con.cursor()
        cursor.execute("SELECT DISTINCT symbol FROM etfs")
        symbols = [row[0] for row in cursor.fetchall()]

    end_date = datetime.today()
    start_date = end_date - timedelta(days=365)  # 12 months

    num_processes = 14

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor:
        futures = [executor.submit(process_symbol, symbol, symbols, start_date, end_date, query_fundamental) for symbol in symbols]
        for _ in tqdm(concurrent.futures.as_completed(futures), total=len(symbols), desc="Processing"):
            pass


if __name__ == "__main__":
    main()
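One caveat with the script above: np.corrcoef requires both close-price arrays to be the same length, so any pair of symbols with different trading histories raises inside the loop and is silently dropped by the except clause. Below is a minimal sketch of a date-aligned alternative; aligned_correlation is a hypothetical helper, not part of this commit, and it assumes the DataFrames returned by get_stock_prices above.

# Hypothetical helper: inner-join the two price series on date so the
# correlation is computed over the same trading days, instead of relying
# on both arrays happening to have equal length.
def aligned_correlation(op_df, stock_df):
    merged = op_df.merge(stock_df, on='date', suffixes=('_op', '_other'))
    if len(merged) < 2:
        return float('nan')  # not enough overlapping days to correlate
    return np.corrcoef(merged['close_op'], merged['close_other'])[0, 1]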
app/cron_correlation_stock.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import pandas as pd
import sqlite3
from datetime import datetime, timedelta
import concurrent.futures
import json
from tqdm import tqdm
import warnings
import numpy as np
import os

warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide")


def get_stock_prices(ticker, cursor, start_date, end_date):
    query = f"""
        SELECT date, close, volume
        FROM "{ticker}"
        WHERE date BETWEEN ? AND ?
    """
    cursor.execute(query, (start_date, end_date))
    return pd.DataFrame(cursor.fetchall(), columns=['date', 'close', 'volume'])


def process_symbol(op_symbol, symbols, start_date, end_date, query_fundamental):
    with sqlite3.connect('stocks.db') as con:
        con.execute("PRAGMA journal_mode = WAL")
        cursor = con.cursor()

        op_df = get_stock_prices(op_symbol, cursor, start_date, end_date)
        # Only compare against symbols trading at least half this symbol's average volume.
        avg_volume = op_df['volume'].mean() * 0.5
        correlations = {}

        for symbol in symbols:
            if symbol != op_symbol:
                try:
                    stock_df = get_stock_prices(symbol, cursor, start_date, end_date)
                    if stock_df['volume'].mean() > avg_volume:
                        correlation = np.corrcoef(op_df['close'], stock_df['close'])[0, 1]
                        correlations[symbol] = correlation
                except Exception:
                    pass

        # Keep the five most and five least correlated symbols.
        sorted_correlations = sorted(correlations.items(), key=lambda x: x[1], reverse=True)
        most_least_list = sorted_correlations[:5] + sorted_correlations[-5:]

        res_list = []
        for symbol, correlation in most_least_list:
            cursor.execute(query_fundamental, (symbol,))
            fundamental_data = cursor.fetchone()
            # Skip NaN correlations and symbols with no metadata row.
            if fundamental_data and correlation is not None and not np.isnan(correlation):
                res_list.append({
                    'symbol': symbol,
                    'name': fundamental_data[0],
                    'marketCap': int(fundamental_data[1]),
                    'value': round(correlation, 3)
                })

        # Sort by correlation and de-duplicate on symbol.
        sorted_res = sorted(res_list, key=lambda x: x['value'], reverse=True)
        res_list = list({d['symbol']: d for d in sorted_res}.values())

        if res_list:
            # Write to the same directory the ETF script uses, which is the
            # path app/main.py reads from.
            os.makedirs("json/correlation/companies", exist_ok=True)
            with open(f"json/correlation/companies/{op_symbol}.json", 'w') as file:
                json.dump(res_list, file)


def main():
    query_fundamental = "SELECT name, marketCap FROM stocks WHERE symbol = ?"

    with sqlite3.connect('stocks.db') as con:
        con.execute("PRAGMA journal_mode = WAL")
        cursor = con.cursor()
        cursor.execute("SELECT DISTINCT symbol FROM stocks")
        symbols = [row[0] for row in cursor.fetchall()]

    end_date = datetime.today()
    start_date = end_date - timedelta(days=365)

    num_processes = 4

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor:
        futures = [executor.submit(process_symbol, symbol, symbols, start_date, end_date, query_fundamental) for symbol in symbols]
        for _ in tqdm(concurrent.futures.as_completed(futures), total=len(symbols), desc="Processing"):
            pass


if __name__ == "__main__":
    main()
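The two cron scripts differ only in the database file, the metadata table, and the worker count. One possible consolidation is sketched below; CONFIGS and run are hypothetical names, not part of this commit, and the sketch assumes process_symbol is also parameterized by database path.

# Hypothetical shared driver for both cron jobs (not part of this commit).
CONFIGS = {
    'etf':   {'db': 'etf.db',    'table': 'etfs',   'workers': 14},
    'stock': {'db': 'stocks.db', 'table': 'stocks', 'workers': 4},
}

def run(kind):
    cfg = CONFIGS[kind]
    query_fundamental = f"SELECT name, marketCap FROM {cfg['table']} WHERE symbol = ?"
    with sqlite3.connect(cfg['db']) as con:
        cursor = con.cursor()
        cursor.execute(f"SELECT DISTINCT symbol FROM {cfg['table']}")
        symbols = [row[0] for row in cursor.fetchall()]
    end_date = datetime.today()
    start_date = end_date - timedelta(days=365)
    with concurrent.futures.ProcessPoolExecutor(max_workers=cfg['workers']) as executor:
        futures = [executor.submit(process_symbol, s, symbols, start_date, end_date, query_fundamental)
                   for s in symbols]
        for _ in tqdm(concurrent.futures.as_completed(futures), total=len(symbols), desc="Processing"):
            pass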
app/main.py (23 lines changed)
@@ -315,37 +315,26 @@ async def hello_world(api_key: str = Security(get_api_key)):



-@app.post("/stock-correlation")
+@app.post("/correlation-ticker")
 async def rating_stock(data: TickerData, api_key: str = Security(get_api_key)):
     data = data.dict()
     ticker = data['ticker'].upper()

-    cache_key = f"stock-correlation-{ticker}"
+    cache_key = f"correlation-{ticker}"
     cached_result = redis_client.get(cache_key)
     if cached_result:
         return orjson.loads(cached_result)

-    if ticker in etf_symbols:
-        path_name = 'etf'
-    else:
-        path_name = 'stock'
-
     try:
-        with open(f"json/correlation/{path_name}/{ticker}.json", 'rb') as file:
-            output = orjson.loads(file.read())
-
-        sorted_data = sorted(output, key=lambda x: x['value'], reverse=True)
-        # Remove duplicates based on 'symbol'
-        res = list({d['symbol']: d for d in sorted_data}.values())
+        with open(f"json/correlation/companies/{ticker}.json", 'rb') as file:
+            res = orjson.loads(file.read())
     except:
         res = []

-    final_res = {'correlation': res, 'type': 'etf' if path_name == 'etf' else 'stocks'}
-
-    redis_client.set(cache_key, orjson.dumps(final_res))
+    redis_client.set(cache_key, orjson.dumps(res))
     redis_client.expire(cache_key, 3600*24)  # set cache expiration to 24 hours

-    return final_res
+    return res



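With this change the endpoint returns the pre-computed, already de-duplicated list as-is instead of re-sorting on every request. A hypothetical client call is shown below; the API-key header name is an assumption, since it is defined by get_api_key, which is outside this diff.

# Hypothetical client for the renamed endpoint (header name assumed).
import requests

resp = requests.post(
    "http://localhost:8000/correlation-ticker",   # assumed local dev URL
    headers={"X-API-KEY": "<your-key>"},          # actual header set by get_api_key
    json={"ticker": "AAPL"},
)
print(resp.json())  # list of {'symbol', 'name', 'marketCap', 'value'} entries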