BuckLakeAI / us_stock.py
parkerjj's picture
Daily Update, First Release for model 1012
62f31c8
import logging
import re
import akshare as ak
import pandas as pd
from datetime import datetime, timedelta
import time # 导入标准库的 time 模块
import os
import requests
import threading
import asyncio
logging.basicConfig(level=logging.INFO)
# 获取当前文件的目录
base_dir = os.path.dirname(os.path.abspath(__file__))
# 构建CSV文件的绝对路径
nasdaq_100_path = os.path.join(base_dir, './model/nasdaq100.csv')
dow_jones_path = os.path.join(base_dir, './model/dji.csv')
sp500_path = os.path.join(base_dir, './model/sp500.csv')
nasdaq_composite_path = os.path.join(base_dir, './model/nasdaq_all.csv')
# 从CSV文件加载成分股数据
nasdaq_100_stocks = pd.read_csv(nasdaq_100_path)
dow_jones_stocks = pd.read_csv(dow_jones_path)
sp500_stocks = pd.read_csv(sp500_path)
nasdaq_composite_stocks = pd.read_csv(nasdaq_composite_path)
def fetch_stock_us_spot_data_with_retries():
# 定义重试间隔时间序列(秒)
retry_intervals = [10, 20, 60, 300, 600]
retry_index = 0 # 初始重试序号
while True:
try:
# 尝试获取API数据
symbols = ak.stock_us_spot_em()
return symbols # 成功获取数据后返回
except Exception as e:
print(f"Error fetching data: {e}")
# 获取当前重试等待时间
wait_time = retry_intervals[retry_index]
print(f"Retrying in {wait_time} seconds...")
time.sleep(wait_time) # 等待指定的秒数
# 更新重试索引,但不要超出重试时间列表的范围
retry_index = min(retry_index + 1, len(retry_intervals) - 1)
async def fetch_stock_us_spot_data_with_retries_async():
retry_intervals = [10, 20, 60, 300, 600]
retry_index = 0
while True:
try:
symbols = await asyncio.to_thread(ak.stock_us_spot_em)
return symbols
except Exception as e:
print(f"Error fetching data: {e}")
wait_time = retry_intervals[retry_index]
print(f"Retrying in {wait_time} seconds...")
await asyncio.sleep(wait_time)
retry_index = min(retry_index + 1, len(retry_intervals) - 1)
symbols = None
async def fetch_symbols():
global symbols
# 异步获取数据
symbols = await fetch_stock_us_spot_data_with_retries_async()
print("Symbols initialized:", symbols)
# 全局变量
index_us_stock_index_INX = None
index_us_stock_index_DJI = None
index_us_stock_index_IXIC = None
index_us_stock_index_NDX = None
def update_stock_indices():
global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX
try:
index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX")
index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI")
index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC")
index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX")
print("Stock indices updated")
except Exception as e:
print(f"Error updating stock indices: {e}")
# 设置定时器,每隔12小时更新一次
threading.Timer(12 * 60 * 60, update_stock_indices).start()
# 程序开始时立即更新一次
update_stock_indices()
# 创建列名转换的字典
column_mapping = {
'日期': 'date',
'开盘': 'open',
'收盘': 'close',
'最高': 'high',
'最低': 'low',
'成交量': 'volume',
'成交额': 'amount',
'振幅': 'amplitude',
'涨跌幅': 'price_change_percentage',
'涨跌额': 'price_change_amount',
'换手率': 'turnover_rate'
}
# 定义一个标准的列顺序
standard_columns = ['date', 'open', 'close', 'high', 'low', 'volume', 'amount']
# 定义查找函数
def find_stock_entry(stock_code):
# 使用 str.endswith 来匹配股票代码
matching_row = symbols[symbols['代码'].str.endswith(stock_code)]
# print(symbols)
if not matching_row.empty:
# print(f"股票代码 {stock_code} 找到, 代码为 {matching_row['代码'].values[0]}")
return matching_row['代码'].values[0]
else:
return ""
'''
# 示例调用
# 测试函数
result = find_stock_entry('AAPL')
if isinstance(result, pd.DataFrame) and not result.empty:
# 如果找到的结果不为空,获取代码列的值
code_value = result['代码'].values[0]
print(code_value)
else:
print(result)
'''
def reduce_columns(df, columns_to_keep):
return df[columns_to_keep]
# 返回个股历史数据
def get_stock_history(symbol, news_date, retries=10):
# 定义重试间隔时间序列(秒)
retry_intervals = [10, 20, 60, 300, 600]
retry_count = 0
# 如果传入的symbol不包含数字前缀,则通过 find_stock_entry 获取完整的symbol
if not any(char.isdigit() for char in symbol):
full_symbol = find_stock_entry(symbol)
if len(symbol) != 0 and full_symbol:
symbol = full_symbol
else:
symbol = ""
# 将news_date转换为datetime对象
news_date_dt = datetime.strptime(news_date, "%Y%m%d")
# 计算start_date和end_date
start_date = (news_date_dt - timedelta(weeks=2)).strftime("%Y%m%d")
end_date = (news_date_dt + timedelta(weeks=2)).strftime("%Y%m%d")
stock_hist_df = None
retry_index = 0 # 初始化重试索引
while retry_count <= retries and len(symbol) != 0: # 无限循环重试
try:
# 尝试获取API数据
stock_hist_df = ak.stock_us_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust="")
if stock_hist_df.empty: # 检查是否为空数据
# print(f"No data for {symbol} on {news_date}.")
stock_hist_df = None # 将 DataFrame 设置为 None
break
except (requests.exceptions.Timeout, ConnectionError) as e:
print(f"Request timed out: {e}. Retrying...")
retry_count += 1 # 增加重试次数
continue
except (TypeError, ValueError, BaseException) as e:
print(f"Error {e} scraping data for {symbol} on {news_date}. Break...")
# 可能是没数据,直接Break
break
# 如果发生异常,等待一段时间再重试
wait_time = retry_intervals[retry_index]
print(f"Waiting for {wait_time} seconds before retrying...")
time.sleep(wait_time)
retry_index = (retry_index + 1) if retry_index < len(retry_intervals) - 1 else retry_index # 更新重试索引,不超过列表长度
# 如果获取失败或数据为空,返回填充为0的 DataFrame
if stock_hist_df is None or stock_hist_df.empty:
# 构建一个空的 DataFrame,包含指定日期范围的空数据
date_range = pd.date_range(start=start_date, end=end_date)
stock_hist_df = pd.DataFrame({
'date': date_range,
'开盘': 0,
'收盘': 0,
'最高': 0,
'最低': 0,
'成交量': 0,
'成交额': 0,
'振幅': 0,
'涨跌幅': 0,
'涨跌额': 0,
'换手率': 0
})
# 使用rename方法转换列名
stock_hist_df = stock_hist_df.rename(columns=column_mapping)
stock_hist_df = stock_hist_df.reindex(columns=standard_columns)
# 处理个股数据,保留所需列
stock_hist_df = reduce_columns(stock_hist_df, standard_columns)
return stock_hist_df
# 统一列名
stock_hist_df = stock_hist_df.rename(columns=column_mapping)
stock_hist_df = stock_hist_df.reindex(columns=standard_columns)
# 处理个股数据,保留所需列
stock_hist_df = reduce_columns(stock_hist_df, standard_columns)
return stock_hist_df
'''
# 示例调用
result = get_stock_history('AAPL', '20240214')
print(result)
'''
# result = get_stock_history('ATMU', '20231218')
# print(result)
# 返回个股所属指数历史数据
def get_stock_index_history(symbol, news_date, force_index=0):
# 检查股票所属的指数
if symbol in nasdaq_100_stocks['Symbol'].values or force_index == 1:
index_code = ".NDX"
index_data = index_us_stock_index_NDX
elif symbol in dow_jones_stocks['Symbol'].values or force_index == 2:
index_code = ".DJI"
index_data = index_us_stock_index_DJI
elif symbol in sp500_stocks['Symbol'].values or force_index == 3:
index_code = ".INX"
index_data = index_us_stock_index_INX
elif symbol in nasdaq_composite_stocks["Symbol"].values or symbol is None or symbol == "" or force_index == 4:
index_code = ".IXIC"
index_data = index_us_stock_index_IXIC
else:
# print(f"股票代码 {symbol} 不属于纳斯达克100、道琼斯工业、标准普尔500或纳斯达克综合指数。")
index_code = ".IXIC"
index_data = index_us_stock_index_IXIC
# 将 news_date 转换为 datetime 对象
news_date_dt = datetime.strptime(news_date, "%Y%m%d")
# 计算 start_date 和 end_date
start_date = (news_date_dt - timedelta(weeks=8)).strftime("%Y-%m-%d")
end_date = (news_date_dt + timedelta(weeks=2)).strftime("%Y-%m-%d")
# 确保 index_data['date'] 是 datetime 类型
index_data['date'] = pd.to_datetime(index_data['date'])
# 从指数历史数据中提取指定日期范围的数据
index_hist_df = index_data[(index_data['date'] >= start_date) & (index_data['date'] <= end_date)]
# 统一列名
index_hist_df = index_hist_df.rename(columns=column_mapping)
index_hist_df = index_hist_df.reindex(columns=standard_columns)
# 处理个股数据,保留所需列
index_hist_df = reduce_columns(index_hist_df, standard_columns)
return index_hist_df
'''
# 示例调用
result = get_stock_index_history('AAPL', '20240214')
print(result)
'''
def find_stock_codes_or_names(entities):
"""
从给定的实体列表中检索股票代码或公司名称。
:param entities: 命名实体识别结果列表,格式为 [('实体名称', '实体类型'), ...]
:return: 相关的股票代码列表
"""
stock_codes = set()
# 合并所有股票字典并清理数据,确保都是字符串
all_symbols = pd.concat([nasdaq_100_stocks['Symbol'],
dow_jones_stocks['Symbol'],
sp500_stocks['Symbol'],
nasdaq_composite_stocks['Symbol']]).dropna().astype(str).unique().tolist()
all_names = pd.concat([nasdaq_100_stocks['Name'],
nasdaq_composite_stocks['Name'],
sp500_stocks['Security'],
dow_jones_stocks['Company']]).dropna().astype(str).unique().tolist()
# 创建一个 Name 到 Symbol 的映射
name_to_symbol = {}
for idx, name in enumerate(all_names):
if idx < len(all_symbols):
symbol = all_symbols[idx]
name_to_symbol[name.lower()] = symbol
# 查找实体映射到的股票代码
for entity, entity_type in entities:
entity_lower = entity.lower()
entity_upper = entity.upper()
# 检查 Symbol 列
if entity_upper in all_symbols:
stock_codes.add(entity_upper)
print(f"Matched symbol: {entity_upper}")
# 检查 Name 列,确保完整匹配而不是部分匹配
for name, symbol in name_to_symbol.items():
# 使用正则表达式进行严格匹配
pattern = rf'\b{re.escape(entity_lower)}\b'
if re.search(pattern, name):
stock_codes.add(symbol.upper())
print(f"Matched name/company: '{entity_lower}' in '{name}' -> {symbol.upper()}")
print(f"Stock codes found: {stock_codes}")
return list(stock_codes)