import os
import pandas as pd
import numpy as np
import json
import plotly.graph_objects as go
import plotly.utils
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import sys
import warnings
from datetime import datetime
import baostock as bs
import re

warnings.filterwarnings('ignore')
# Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from model import Kronos, KronosTokenizer, KronosPredictor
    MODEL_AVAILABLE = True
except ImportError:
    MODEL_AVAILABLE = False
    print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")

app = Flask(__name__)
CORS(app)

# Global variables to store models
tokenizer = None
model = None
predictor = None

# Path of the webui directory
WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
# Project root directory (the parent of webui)
BASE_DIR = os.path.dirname(WEBUI_DIR)
AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': os.path.join(BASE_DIR, 'models', 'Kronos-mini'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suited for fast prediction'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-small'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balances quality and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-base'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, offers better prediction quality'
    }
}
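
# Illustrative sketch (not called anywhere): how an AVAILABLE_MODELS entry is
# meant to be consumed. Picking 'kronos-small' here is arbitrary; it is just
# one of the keys defined above.
def _example_model_config():
    cfg = AVAILABLE_MODELS['kronos-small']
    # model_id/tokenizer_id are local paths under BASE_DIR/models;
    # context_length caps how many rows the predictor can attend to.
    return cfg['model_id'], cfg['tokenizer_id'], cfg['context_length']
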
def load_data_files():
    """Scan the data directory and return the available data files"""
    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
    data_files = []
    if os.path.exists(data_dir):
        for file in os.listdir(data_dir):
            if file.endswith(('.csv', '.feather')):
                file_path = os.path.join(data_dir, file)
                file_size = os.path.getsize(file_path)
                data_files.append({
                    'name': file,
                    'path': file_path,
                    'size': f"{file_size / 1024:.1f} KB" if file_size < 1024 * 1024 else f"{file_size / (1024 * 1024):.1f} MB"
                })
    return data_files

def load_data_file(file_path):
    """Load a data file"""
    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif file_path.endswith('.feather'):
            df = pd.read_feather(file_path)
        else:
            return None, "Unsupported file format"

        # Check required columns
        required_cols = ['open', 'high', 'low', 'close']
        if not all(col in df.columns for col in required_cols):
            return None, f"Missing required columns: {required_cols}"

        # Process the timestamp column
        if 'timestamps' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamps'])
        elif 'timestamp' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamp'])
        elif 'date' in df.columns:
            # If the column is named 'date', expose it as 'timestamps'
            df['timestamps'] = pd.to_datetime(df['date'])
        else:
            # If no timestamp column exists, create one
            df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')

        # Ensure price columns are numeric
        for col in ['open', 'high', 'low', 'close']:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Process the volume column (optional)
        if 'volume' in df.columns:
            df['volume'] = pd.to_numeric(df['volume'], errors='coerce')

        # Process the amount column (optional, not used for prediction)
        if 'amount' in df.columns:
            df['amount'] = pd.to_numeric(df['amount'], errors='coerce')

        # Remove rows containing NaN values; reset the index so that later
        # positional slicing (df.iloc) stays aligned with the index labels
        df = df.dropna().reset_index(drop=True)
        return df, None
    except Exception as e:
        return None, f"Failed to load file: {str(e)}"
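
# Illustrative sketch (not called anywhere): the minimal schema that
# load_data_file() accepts, written to a temporary CSV. All values are made up.
def _example_load_data_file():
    import tempfile
    sample = pd.DataFrame({
        'timestamps': pd.date_range('2024-01-01', periods=5, freq='1H'),
        'open': [1.00, 1.10, 1.20, 1.30, 1.40],
        'high': [1.10, 1.20, 1.30, 1.40, 1.50],
        'low': [0.90, 1.00, 1.10, 1.20, 1.30],
        'close': [1.05, 1.15, 1.25, 1.35, 1.45],
    })
    with tempfile.NamedTemporaryFile(suffix='.csv', delete=False) as tmp:
        path = tmp.name
    sample.to_csv(path, index=False)
    df, error = load_data_file(path)
    assert error is None and len(df) == 5
    return df
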

def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
    """Save prediction results to file"""
    try:
        # Create the prediction results directory
        results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
        os.makedirs(results_dir, exist_ok=True)

        # Generate the filename (datetime is imported as a class, so call datetime.now() directly)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f'prediction_{timestamp}.json'
        filepath = os.path.join(results_dir, filename)

        # Prepare the data for saving
        save_data = {
            'timestamp': datetime.now().isoformat(),
            'file_path': file_path,
            'prediction_type': prediction_type,
            'prediction_params': prediction_params,
            'input_data_summary': {
                'rows': len(input_data),
                'columns': list(input_data.columns),
                'price_range': {
                    'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
                    'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
                    'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
                    'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
                },
                'last_values': {
                    'open': float(input_data['open'].iloc[-1]),
                    'high': float(input_data['high'].iloc[-1]),
                    'low': float(input_data['low'].iloc[-1]),
                    'close': float(input_data['close'].iloc[-1])
                }
            },
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'analysis': {}
        }

        # If actual data exists, perform a comparison analysis
        if actual_data and len(actual_data) > 0:
            # Continuity analysis: compare the first prediction point with the first actual point
            if len(prediction_results) > 0 and len(actual_data) > 0:
                first_pred = prediction_results[0]  # First prediction point
                first_actual = actual_data[0]  # First actual point
                save_data['analysis']['continuity'] = {
                    'first_prediction': {
                        'open': first_pred['open'],
                        'high': first_pred['high'],
                        'low': first_pred['low'],
                        'close': first_pred['close']
                    },
                    'first_actual': {
                        'open': first_actual['open'],
                        'high': first_actual['high'],
                        'low': first_actual['low'],
                        'close': first_actual['close']
                    },
                    'gaps': {
                        'open_gap': abs(first_pred['open'] - first_actual['open']),
                        'high_gap': abs(first_pred['high'] - first_actual['high']),
                        'low_gap': abs(first_pred['low'] - first_actual['low']),
                        'close_gap': abs(first_pred['close'] - first_actual['close'])
                    },
                    'gap_percentages': {
                        'open_gap_pct': (abs(first_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
                        'high_gap_pct': (abs(first_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
                        'low_gap_pct': (abs(first_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
                        'close_gap_pct': (abs(first_pred['close'] - first_actual['close']) / first_actual['close']) * 100
                    }
                }

        # Save to file
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)
        print(f"Prediction results saved to: {filepath}")
        return filepath
    except Exception as e:
        print(f"Failed to save prediction results: {e}")
        return None
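
# Worked example (illustrative only) of the continuity metric computed above:
# gap_pct = |first_pred - first_actual| / first_actual * 100.
def _example_gap_pct():
    first_pred_close, first_actual_close = 102.0, 100.0
    gap = abs(first_pred_close - first_actual_close)  # 2.0
    gap_pct = gap / first_actual_close * 100          # 2.0 (%)
    return gap, gap_pct
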

def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
    """Create prediction chart"""
    print("🔍 Chart creation debug:")
    print(f"  Historical data: {len(df) if df is not None else 0} rows")
    print(f"  Prediction data: {len(pred_df) if pred_df is not None else 0} rows")
    print(f"  Actual data: {len(actual_df) if actual_df is not None else 0} rows")

    # Make sure the prediction data is not empty
    if pred_df is None or len(pred_df) == 0:
        print("⚠️ Warning: prediction data is empty!")
        # Return an empty chart
        fig = go.Figure()
        fig.update_layout(title='No prediction data available')
        return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

    # Use the specified historical start position, not always the beginning of df
    if historical_start_idx + lookback + pred_len <= len(df):
        # Display lookback historical points + pred_len prediction points from the specified position
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
        prediction_range = range(historical_start_idx + lookback, historical_start_idx + lookback + pred_len)
    else:
        # If the data is insufficient, adjust to the maximum available range
        available_lookback = min(lookback, len(df) - historical_start_idx)
        available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]
        prediction_range = range(historical_start_idx + available_lookback,
                                 historical_start_idx + available_lookback + available_pred_len)

    # Create the chart
    fig = go.Figure()

    # Add historical data (candlestick chart)
    fig.add_trace(go.Candlestick(
        x=historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns else historical_df.index.tolist(),
        open=historical_df['open'].tolist(),
        high=historical_df['high'].tolist(),
        low=historical_df['low'].tolist(),
        close=historical_df['close'].tolist(),
        name=f'Historical Data ({len(historical_df)} data points)',
        increasing_line_color='#26A69A',
        decreasing_line_color='#EF5350'
    ))

    # Add prediction data (candlestick chart)
    if pred_df is not None and len(pred_df) > 0:
        # Compute prediction timestamps so they continue seamlessly from the historical data
        if 'timestamps' in df.columns and len(historical_df) > 0:
            # Start from the last historical timestamp and keep the same time interval
            last_timestamp = historical_df['timestamps'].iloc[-1]
            time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
            pred_timestamps = pd.date_range(
                start=last_timestamp + time_diff,
                periods=len(pred_df),
                freq=time_diff
            )
        else:
            # If there are no timestamps, fall back to integer positions
            pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))
        fig.add_trace(go.Candlestick(
            x=pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps),
            open=pred_df['open'].tolist(),
            high=pred_df['high'].tolist(),
            low=pred_df['low'].tolist(),
            close=pred_df['close'].tolist(),
            name=f'Prediction Data ({len(pred_df)} data points)',
            increasing_line_color='#66BB6A',
            decreasing_line_color='#FF7043'
        ))

    # Add actual data for comparison (if it exists)
    if actual_df is not None and len(actual_df) > 0:
        # The actual data covers the same time period as the prediction data
        if 'timestamps' in df.columns:
            # Reuse the prediction timestamps so the two series stay time-aligned
            if 'pred_timestamps' in locals():
                actual_timestamps = pred_timestamps
            else:
                # Otherwise derive them from the last historical timestamp
                if len(historical_df) > 0:
                    last_timestamp = historical_df['timestamps'].iloc[-1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
                    actual_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=len(actual_df),
                        freq=time_diff
                    )
                else:
                    actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
        else:
            actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
        fig.add_trace(go.Candlestick(
            x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
            open=actual_df['open'].tolist(),
            high=actual_df['high'].tolist(),
            low=actual_df['low'].tolist(),
            close=actual_df['close'].tolist(),
            name=f'Actual Data ({len(actual_df)} data points)',
            increasing_line_color='#FF9800',
            decreasing_line_color='#F44336'
        ))

    # Update the layout
    fig.update_layout(
        title=f'Kronos Financial Prediction Results - {len(historical_df)} Historical Points + {len(pred_df)} Prediction Points',
        xaxis_title='Time',
        yaxis_title='Price',
        template='plotly_white',
        height=600,
        showlegend=True
    )

    # Ensure x-axis time continuity
    if 'timestamps' in historical_df.columns:
        # Collect all timestamps and sort them
        all_timestamps = []
        if len(historical_df) > 0:
            all_timestamps.extend(historical_df['timestamps'].tolist())
        if 'pred_timestamps' in locals():
            all_timestamps.extend(
                pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps))
        if 'actual_timestamps' in locals():
            all_timestamps.extend(
                actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps))
        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    try:
        chart_json = fig.to_json()
        print(f"✅ Chart JSON serialized, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Chart serialization failed: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()
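
# Illustrative sketch (not called anywhere): create_prediction_chart() on a
# tiny synthetic frame standing in for real OHLC data.
def _example_prediction_chart():
    n = 20
    base = np.linspace(100.0, 110.0, n)
    df = pd.DataFrame({
        'timestamps': pd.date_range('2024-01-01', periods=n, freq='1H'),
        'open': base,
        'high': base + 1.0,
        'low': base - 1.0,
        'close': base + 0.5,
    })
    # Treat the last 5 rows as a fake "prediction" just to exercise the plotting path
    pred_df = df.iloc[-5:][['open', 'high', 'low', 'close']].reset_index(drop=True)
    return create_prediction_chart(df, pred_df, lookback=10, pred_len=5)
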

# Compute technical indicators
def calculate_indicators(df):
    indicators = {}

    # Moving averages (MA)
    indicators['ma5'] = df['close'].rolling(window=5).mean()
    indicators['ma10'] = df['close'].rolling(window=10).mean()
    indicators['ma20'] = df['close'].rolling(window=20).mean()

    # MACD
    exp12 = df['close'].ewm(span=12, adjust=False).mean()
    exp26 = df['close'].ewm(span=26, adjust=False).mean()
    indicators['macd'] = exp12 - exp26
    indicators['signal'] = indicators['macd'].ewm(span=9, adjust=False).mean()
    indicators['macd_hist'] = indicators['macd'] - indicators['signal']

    # RSI
    delta = df['close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    indicators['rsi'] = 100 - (100 / (1 + rs))

    # Bollinger Bands
    indicators['bb_mid'] = df['close'].rolling(window=20).mean()
    indicators['bb_std'] = df['close'].rolling(window=20).std()
    indicators['bb_upper'] = indicators['bb_mid'] + 2 * indicators['bb_std']
    indicators['bb_lower'] = indicators['bb_mid'] - 2 * indicators['bb_std']

    # Stochastic oscillator
    low_min = df['low'].rolling(window=14).min()
    high_max = df['high'].rolling(window=14).max()
    indicators['stoch_k'] = 100 * ((df['close'] - low_min) / (high_max - low_min))
    indicators['stoch_d'] = indicators['stoch_k'].rolling(window=3).mean()

    # Rolling-window mean strategy
    indicators['rwms_window'] = 90
    indicators['rwms_mean'] = df['close'].rolling(window=90).mean()
    indicators['rwms_signal'] = (df['close'] > indicators['rwms_mean']).astype(int)

    # Triple exponential average (TRIX) strategy
    ema1 = df['close'].ewm(span=12, adjust=False).mean()
    ema2 = ema1.ewm(span=12, adjust=False).mean()
    ema3 = ema2.ewm(span=12, adjust=False).mean()
    indicators['trix'] = (ema3 - ema3.shift(1)) / ema3.shift(1) * 100
    indicators['trix_signal'] = indicators['trix'].ewm(span=9, adjust=False).mean()

    return indicators
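
# Illustrative sketch (not called anywhere): calculate_indicators() on a
# synthetic random walk, showing which keys the returned dict exposes.
def _example_indicators():
    rng = np.random.default_rng(0)
    close = 100.0 + np.cumsum(rng.normal(0.0, 1.0, 200))
    df = pd.DataFrame({'close': close, 'high': close + 1.0, 'low': close - 1.0})
    ind = calculate_indicators(df)
    # e.g. the latest 20-period MA and 14-period RSI values
    return float(ind['ma20'].iloc[-1]), float(ind['rsi'].iloc[-1])
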

# Render a technical indicator chart
def create_technical_chart(df, pred_df, lookback, pred_len, diagram_type, actual_df=None, historical_start_idx=0):
    print(f"  🔍 Data size: {len(df) if df is not None else 0} rows")
    print(f"  🔍 Chart type: {diagram_type}")

    # Select the data range
    if historical_start_idx + lookback <= len(df):
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
    else:
        available_lookback = min(lookback, len(df) - historical_start_idx)
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]

    # Compute the indicators
    historical_indicators = calculate_indicators(historical_df)
    fig = go.Figure()

    # Shared x values for the historical traces
    historical_x = historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns else historical_df.index.tolist()

    # Volume chart
    if diagram_type == 'Volume Chart (VOL)':
        fig.add_trace(go.Bar(
            x=historical_x,
            y=historical_df['volume'].tolist() if 'volume' in historical_df.columns else [],
            name='Historical Volume',
            marker_color='#42A5F5'
        ))
        if actual_df is not None and len(actual_df) > 0 and 'volume' in actual_df.columns:
            if 'timestamps' in df.columns and len(historical_df) > 0:
                last_timestamp = historical_df['timestamps'].iloc[-1]
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
                actual_timestamps = pd.date_range(start=last_timestamp + time_diff, periods=len(actual_df), freq=time_diff)
            else:
                actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
            fig.add_trace(go.Bar(
                x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
                y=actual_df['volume'].tolist(),
                name='Actual Volume',
                marker_color='#FF9800'
            ))
        fig.update_layout(yaxis_title='Volume')

    # Moving averages
    elif diagram_type == 'Moving Average (MA)':
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['ma5'],
                                 name='MA5', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['ma10'],
                                 name='MA10', line=dict(color='#42A5F5', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['ma20'],
                                 name='MA20', line=dict(color='#7E57C2', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_df['close'],
                                 name='Close Price', line=dict(color='#212121', width=1, dash='dash')))
        fig.update_layout(yaxis_title='Price')

    # MACD
    elif diagram_type == 'MACD Indicator (MACD)':
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['macd'],
                                 name='MACD', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['signal'],
                                 name='Signal', line=dict(color='#EF5350', width=1)))
        fig.add_trace(go.Bar(x=historical_x, y=historical_indicators['macd_hist'],
                             name='MACD Histogram', marker_color='#42A5F5'))
        fig.add_hline(y=0, line_dash="dash", line_color="gray")
        fig.update_layout(yaxis_title='MACD')

    # RSI
    elif diagram_type == 'RSI Indicator (RSI)':
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['rsi'],
                                 name='RSI', line=dict(color='#26A69A', width=1)))
        fig.add_hline(y=70, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=30, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='RSI', yaxis_range=[0, 100])

    # Bollinger Bands
    elif diagram_type == 'Bollinger Bands (BB)':
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['bb_upper'],
                                 name='Upper Band', line=dict(color='#EF5350', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['bb_mid'],
                                 name='Middle Band (MA20)', line=dict(color='#42A5F5', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['bb_lower'],
                                 name='Lower Band', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_df['close'],
                                 name='Close Price', line=dict(color='#212121', width=1)))
        fig.update_layout(yaxis_title='Price')

    # Stochastic oscillator
    elif diagram_type == 'Stochastic Oscillator (STOCH)':
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['stoch_k'],
                                 name='%K', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['stoch_d'],
                                 name='%D', line=dict(color='#EF5350', width=1)))
        fig.add_hline(y=80, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=20, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='Stochastic', yaxis_range=[0, 100])

    # Rolling-window mean strategy
    elif diagram_type == 'Rolling Window Mean Strategy':
        fig.add_trace(go.Scatter(x=historical_x, y=historical_df['close'],
                                 name='Close Price', line=dict(color='#212121', width=1.5)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['rwms_mean'],
                                 name=f'Rolling Mean ({historical_indicators["rwms_window"]} periods)',
                                 line=dict(color='#42A5F5', width=1.5, dash='dash')))
        buy_signals = historical_df[historical_indicators['rwms_signal'] == 1]
        fig.add_trace(go.Scatter(
            x=buy_signals['timestamps'].tolist() if 'timestamps' in buy_signals.columns else buy_signals.index.tolist(),
            y=buy_signals['close'],
            mode='markers',
            name='Buy Signal',
            marker=dict(color='#26A69A', size=8, symbol='triangle-up')
        ))
        sell_signals = historical_df[historical_indicators['rwms_signal'] == 0]
        fig.add_trace(go.Scatter(
            x=sell_signals['timestamps'].tolist() if 'timestamps' in sell_signals.columns else sell_signals.index.tolist(),
            y=sell_signals['close'],
            mode='markers',
            name='Sell Signal',
            marker=dict(color='#EF5350', size=8, symbol='triangle-down')
        ))
        fig.update_layout(
            yaxis_title='Price',
            title=f'Rolling Window Mean Strategy (Window Size: {historical_indicators["rwms_window"]})'
        )

    # TRIX chart
    elif diagram_type == 'TRIX Indicator (TRIX)':
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['trix'],
                                 name='TRIX', line=dict(color='#26A69A', width=1)))
        fig.add_trace(go.Scatter(x=historical_x, y=historical_indicators['trix_signal'],
                                 name='TRIX Signal', line=dict(color='#EF5350', width=1)))
        fig.add_hline(y=0, line_dash="dash", line_color="gray")
        fig.update_layout(
            yaxis_title='TRIX (%)',
            title='Triple Exponential Average (TRIX) Strategy'
        )

    # Layout settings
    fig.update_layout(
        title=f'{diagram_type} - Technical Indicator',
        xaxis_title='Time',
        template='plotly_white',
        height=400,
        showlegend=True,
        margin=dict(t=50, b=30)
    )

    if 'timestamps' in historical_df.columns:
        all_timestamps = historical_df['timestamps'].tolist()
        if actual_df is not None and len(actual_df) > 0 and 'timestamps' in df.columns:
            if 'actual_timestamps' in locals():
                all_timestamps.extend(
                    actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps))
        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    try:
        chart_json = fig.to_json()
        print(f"✅ Technical indicator chart serialized, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Technical indicator chart serialization failed: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()
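
# Illustrative sketch (not called anywhere): rendering one branch of
# create_technical_chart() on synthetic data; the diagram_type string must
# match one of the branch labels above.
def _example_technical_chart():
    n = 120
    rng = np.random.default_rng(1)
    close = 100.0 + np.cumsum(rng.normal(0.0, 1.0, n))
    df = pd.DataFrame({
        'timestamps': pd.date_range('2024-01-01', periods=n, freq='1H'),
        'open': close, 'high': close + 1.0, 'low': close - 1.0, 'close': close,
        'volume': np.full(n, 1000.0),
    })
    return create_technical_chart(df, None, lookback=100, pred_len=0,
                                  diagram_type='RSI Indicator (RSI)')
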

@app.route('/')
def index():
    """Home page"""
    return render_template('index.html')


@app.route('/api/data-files')
def get_data_files():
    """Get the list of available data files"""
    data_files = load_data_files()
    return jsonify(data_files)

@app.route('/api/load-data', methods=['POST'])
def load_data():
    """Load a data file"""
    try:
        data = request.get_json()
        file_path = data.get('file_path')
        if not file_path:
            return jsonify({'error': 'File path cannot be empty'}), 400

        df, error = load_data_file(file_path)
        if error:
            return jsonify({'error': error}), 400

        # Detect the data's time frequency
        def detect_timeframe(df):
            if len(df) < 2:
                return "Unknown"
            time_diffs = []
            for i in range(1, min(10, len(df))):  # Sample the time differences among the first 10 rows
                diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i - 1]
                time_diffs.append(diff)
            if not time_diffs:
                return "Unknown"
            # Calculate the average time difference
            avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
            # Convert to a readable format
            if avg_diff < pd.Timedelta(minutes=1):
                return f"{avg_diff.total_seconds():.0f} seconds"
            elif avg_diff < pd.Timedelta(hours=1):
                return f"{avg_diff.total_seconds() / 60:.0f} minutes"
            elif avg_diff < pd.Timedelta(days=1):
                return f"{avg_diff.total_seconds() / 3600:.0f} hours"
            else:
                return f"{avg_diff.days} days"

        # Return data information
        data_info = {
            'rows': len(df),
            'columns': list(df.columns),
            'start_date': df['timestamps'].min().isoformat() if 'timestamps' in df.columns else 'N/A',
            'end_date': df['timestamps'].max().isoformat() if 'timestamps' in df.columns else 'N/A',
            'price_range': {
                'min': float(df[['open', 'high', 'low', 'close']].min().min()),
                'max': float(df[['open', 'high', 'low', 'close']].max().max())
            },
            'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
            'timeframe': detect_timeframe(df)
        }
        return jsonify({
            'success': True,
            'data_info': data_info,
            'message': f'Successfully loaded data, total {len(df)} rows'
        })
    except Exception as e:
        return jsonify({'error': f'Failed to load data: {str(e)}'}), 500
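
# Illustrative sketch (not called anywhere): exercising the two data endpoints
# through Flask's test client instead of a live server.
def _example_data_endpoints():
    client = app.test_client()
    files = client.get('/api/data-files').get_json()
    if not files:
        return None
    # Load the first file found in the data directory
    resp = client.post('/api/load-data', json={'file_path': files[0]['path']})
    return resp.get_json()
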
  1134. @app.route('/api/predict', methods=['POST'])
  1135. def predict():
  1136. """Perform prediction"""
  1137. try:
  1138. data = request.get_json()
  1139. file_path = data.get('file_path')
  1140. lookback = int(data.get('lookback', 400))
  1141. pred_len = int(data.get('pred_len', 120))
  1142. # Get prediction quality parameters
  1143. temperature = float(data.get('temperature', 1.0))
  1144. top_p = float(data.get('top_p', 0.9))
  1145. sample_count = int(data.get('sample_count', 1))
  1146. if not file_path:
  1147. return jsonify({'error': 'File path cannot be empty'}), 400
  1148. # Load data
  1149. df, error = load_data_file(file_path)
  1150. if error:
  1151. return jsonify({'error': error}), 400
  1152. if len(df) < lookback:
  1153. return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400
  1154. # Perform prediction
  1155. if MODEL_AVAILABLE and predictor is not None:
  1156. try:
  1157. # Use real Kronos model
  1158. # Only use necessary columns: OHLCV + amount
  1159. required_cols = ['open', 'high', 'low', 'close', 'volume', 'amount']
  1160. # Process time period selection
  1161. start_date = data.get('start_date')
  1162. if start_date:
  1163. # Custom time period - fix logic: use data within selected window
  1164. start_dt = pd.to_datetime(start_date)
  1165. # Find data after start time
  1166. mask = df['timestamps'] >= start_dt
  1167. time_range_df = df[mask]
  1168. # Ensure sufficient data: lookback + pred_len
  1169. if len(time_range_df) < lookback + pred_len:
  1170. return jsonify({
  1171. 'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400
  1172. # Use first lookback data points within selected window for prediction
  1173. x_df = time_range_df.iloc[:lookback][required_cols]
  1174. x_timestamp = time_range_df.iloc[:lookback]['timestamps']
  1175. # Use last pred_len data points within selected window as actual values
  1176. y_timestamp = time_range_df.iloc[lookback:lookback + pred_len]['timestamps']
  1177. # Calculate actual time period length
  1178. start_timestamp = time_range_df['timestamps'].iloc[0]
  1179. end_timestamp = time_range_df['timestamps'].iloc[lookback + pred_len - 1]
  1180. time_span = end_timestamp - start_timestamp
  1181. prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, last {pred_len} data points for comparison, time span: {time_span})"
  1182. else:
  1183. # Use latest data
  1184. x_df = df.iloc[:lookback][required_cols]
  1185. x_timestamp = df.iloc[:lookback]['timestamps']
  1186. y_timestamp = df.iloc[lookback:lookback + pred_len]['timestamps']
  1187. prediction_type = "Kronos model prediction (latest data)"
  1188. # Debug information
  1189. print(f"🔍 传递给predictor的数据列: {x_df.columns.tolist()}")
  1190. print(f"🔍 数据形状: {x_df.shape}")
  1191. print(f"🔍 数据样例:")
  1192. print(x_df.head(2))
  1193. # Ensure timestamps are Series format, not DatetimeIndex, to avoid .dt attribute error in Kronos model
  1194. if isinstance(x_timestamp, pd.DatetimeIndex):
  1195. x_timestamp = pd.Series(x_timestamp, name='timestamps')
  1196. if isinstance(y_timestamp, pd.DatetimeIndex):
  1197. y_timestamp = pd.Series(y_timestamp, name='timestamps')
  1198. pred_df = predictor.predict(
  1199. df=x_df,
  1200. x_timestamp=x_timestamp,
  1201. y_timestamp=y_timestamp,
  1202. pred_len=pred_len,
  1203. T=temperature,
  1204. top_p=top_p,
  1205. sample_count=sample_count
  1206. )
  1207. except Exception as e:
  1208. return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
  1209. else:
  1210. return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400
  1211. # Prepare actual data for comparison (if exists)
  1212. actual_data = []
  1213. actual_df = None
  1214. if start_date: # Custom time period
  1215. # Fix logic: use data within selected window
  1216. # Prediction uses first 400 data points within selected window
  1217. # Actual data should be last 120 data points within selected window
  1218. start_dt = pd.to_datetime(start_date)
  1219. # Find data starting from start_date
  1220. mask = df['timestamps'] >= start_dt
  1221. time_range_df = df[mask]
  1222. if len(time_range_df) >= lookback + pred_len:
  1223. # Get last 120 data points within selected window as actual values
  1224. actual_df = time_range_df.iloc[lookback:lookback + pred_len]
  1225. for i, (_, row) in enumerate(actual_df.iterrows()):
  1226. actual_data.append({
  1227. 'timestamp': row['timestamps'].isoformat(),
  1228. 'open': float(row['open']),
  1229. 'high': float(row['high']),
  1230. 'low': float(row['low']),
  1231. 'close': float(row['close']),
  1232. 'volume': float(row['volume']) if 'volume' in row else 0,
  1233. 'amount': float(row['amount']) if 'amount' in row else 0
  1234. })
  1235. else: # Latest data
  1236. # Prediction uses first 400 data points
  1237. # Actual data should be 120 data points after first 400 data points
  1238. if len(df) >= lookback + pred_len:
  1239. actual_df = df.iloc[lookback:lookback + pred_len]
  1240. for i, (_, row) in enumerate(actual_df.iterrows()):
  1241. actual_data.append({
  1242. 'timestamp': row['timestamps'].isoformat(),
  1243. 'open': float(row['open']),
  1244. 'high': float(row['high']),
  1245. 'low': float(row['low']),
  1246. 'close': float(row['close']),
  1247. 'volume': float(row['volume']) if 'volume' in row else 0,
  1248. 'amount': float(row['amount']) if 'amount' in row else 0
  1249. })
  1250. # Create chart - pass historical data start position
  1251. if start_date:
  1252. # Custom time period: find starting position of historical data in original df
  1253. start_dt = pd.to_datetime(start_date)
  1254. mask = df['timestamps'] >= start_dt
  1255. historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
  1256. else:
  1257. # Latest data: start from beginning
  1258. historical_start_idx = 0
  1259. chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)
  1260. # Prepare prediction result data - fix timestamp calculation logic
  1261. if 'timestamps' in df.columns:
  1262. if start_date:
  1263. # Custom time period: use selected window data to calculate timestamps
  1264. start_dt = pd.to_datetime(start_date)
  1265. mask = df['timestamps'] >= start_dt
  1266. time_range_df = df[mask]
  1267. if len(time_range_df) >= lookback:
  1268. # Calculate prediction timestamps starting from last time point of selected window
  1269. last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
  1270. time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
  1271. future_timestamps = pd.date_range(
  1272. start=last_timestamp + time_diff,
  1273. periods=pred_len,
  1274. freq=time_diff
  1275. )
  1276. else:
  1277. future_timestamps = []
  1278. else:
  1279. # Latest data: calculate from last time point of entire data file
  1280. last_timestamp = df['timestamps'].iloc[-1]
  1281. time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
  1282. future_timestamps = pd.date_range(
  1283. start=last_timestamp + time_diff,
  1284. periods=pred_len,
  1285. freq=time_diff
  1286. )
  1287. else:
  1288. future_timestamps = range(len(df), len(df) + pred_len)
        prediction_results = []
        for i, (_, row) in enumerate(pred_df.iterrows()):
            # future_timestamps may hold datetimes or plain ints (fallback case)
            if i < len(future_timestamps):
                ts = future_timestamps[i]
                timestamp = ts.isoformat() if hasattr(ts, 'isoformat') else str(ts)
            else:
                timestamp = f"T{i}"
            prediction_results.append({
                'timestamp': timestamp,
                'open': float(row['open']),
                'high': float(row['high']),
                'low': float(row['low']),
                'close': float(row['close']),
                'volume': float(row['volume']) if 'volume' in row else 0,
                'amount': float(row['amount']) if 'amount' in row else 0
            })
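        # Each prediction row mirrors the OHLCV schema of the input rows so the
        # frontend can render history, prediction, and ground truth uniformly.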
        # Save prediction results to file
        try:
            save_prediction_results(
                file_path=file_path,
                prediction_type=prediction_type,
                prediction_results=prediction_results,
                actual_data=actual_data,
                input_data=x_df,
                prediction_params={
                    'lookback': lookback,
                    'pred_len': pred_len,
                    'temperature': temperature,
                    'top_p': top_p,
                    'sample_count': sample_count,
                    'start_date': start_date if start_date else 'latest'
                }
            )
        except Exception as e:
            print(f"Failed to save prediction results: {e}")
        # Log a summary before returning
        print("✅ Prediction finished, returning data:")
        print("  Success: True")
        print(f"  Prediction type: {prediction_type}")
        print(f"  Chart JSON length: {len(chart_json)}")
        print(f"  Prediction points: {len(prediction_results)}")
        print(f"  Actual data points: {len(actual_data)}")
        print(f"  Has comparison data: {len(actual_data) > 0}")
        return jsonify({
            'success': True,
            'prediction_type': prediction_type,
            'chart': chart_json,
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'has_comparison': len(actual_data) > 0,
            'message': f'Prediction completed, generated {pred_len} prediction points' + (
                f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
        })
    except Exception as e:
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500
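# Example request to the prediction endpoint above (illustrative only; the
# payload values are hypothetical, the field names follow the handler's
# parameters):
#   POST /api/predict
#   {
#     "file_path": "data/Stock_5min_A股.csv",
#     "lookback": 400,
#     "pred_len": 120,
#     "temperature": 1.0,
#     "top_p": 0.9,
#     "sample_count": 1,
#     "start_date": "2024-06-03"   # optional; omit to predict from the latest data
#   }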
@app.route('/api/load-model', methods=['POST'])
def load_model():
    """Load the Kronos model and build a predictor."""
    global tokenizer, model, predictor
    try:
        if not MODEL_AVAILABLE:
            return jsonify({'error': 'Kronos model library not available'}), 400
        data = request.get_json()
        model_key = data.get('model_key', 'kronos-small')
        device = data.get('device', 'cpu')
        if model_key not in AVAILABLE_MODELS:
            return jsonify({'error': f'Unsupported model: {model_key}'}), 400
        model_config = AVAILABLE_MODELS[model_key]
        print(f"🚀 Loading model from: {model_config['model_id']}")
        model_path = model_config['model_id']
        # Check that the model path exists before inspecting it
        if not os.path.exists(model_path):
            return jsonify({'error': f'Model path does not exist: {model_path}'}), 400
        model_files = os.listdir(model_path)
        print(f"📄 Files in model directory: {model_files}")
        try:
            # Load the model directly from the local path
            model = Kronos.from_pretrained(
                model_config['model_id'],
                local_files_only=True
            )
            # Read the model's config file to get the correct hyper-parameters
            config_path = os.path.join(model_config['model_id'], 'config.json')
            if os.path.exists(config_path):
                print(f"Reading config file: {config_path}")
                with open(config_path, 'r') as f:
                    config = json.load(f)
                for key, value in config.items():
                    print(f"  {key}: {value}")
                # Build the tokenizer from the hyper-parameters in the config.
                # Note: unlike KronosTokenizer.from_pretrained, this does not
                # load pretrained tokenizer weights from disk.
                tokenizer = KronosTokenizer(
                    d_in=6,  # open, high, low, close, volume, amount
                    d_model=config['d_model'],  # e.g. 832
                    n_heads=config['n_heads'],  # e.g. 16
                    ff_dim=config['ff_dim'],  # e.g. 2048
                    n_enc_layers=config['n_layers'],  # e.g. 12
                    n_dec_layers=config['n_layers'],  # e.g. 12
                    ffn_dropout_p=config['ffn_dropout_p'],  # e.g. 0.2
                    attn_dropout_p=config['attn_dropout_p'],  # e.g. 0.0
                    resid_dropout_p=config['resid_dropout_p'],  # e.g. 0.2
                    s1_bits=config['s1_bits'],  # e.g. 10
                    s2_bits=config['s2_bits'],  # e.g. 10
                    beta=1.0,
                    gamma0=1.0,
                    gamma=1.0,
                    zeta=1.0,
                    group_size=1
                )
            else:
                return jsonify({'error': f'Config file not found: {config_path}'}), 400
        except Exception as e:
            return jsonify({'error': f'Failed to load model: {str(e)}'}), 500
        # Create the predictor
        predictor = KronosPredictor(
            model,
            tokenizer,
            device=device,
            max_context=model_config['context_length']
        )
        return jsonify({
            'success': True,
            'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
            'model_info': {
                **model_config,
                'model_path': model_config['model_id'],
                'abs_model_path': os.path.abspath(model_config['model_id']),
                'device': device
            }
        })
    except Exception as e:
        import traceback
        print("[API endpoint error]")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        traceback.print_exc()
        print("=" * 60)
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500
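# Example request (illustrative; keys match the handler above):
#   POST /api/load-model  {"model_key": "kronos-small", "device": "cpu"}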
@app.route('/api/available-models')
def get_available_models():
    """Get available model list"""
    return jsonify({
        'models': AVAILABLE_MODELS,
        'model_available': MODEL_AVAILABLE
    })


@app.route('/api/model-status')
def get_model_status():
    """Get model status"""
    if MODEL_AVAILABLE:
        if predictor is not None:
            return jsonify({
                'available': True,
                'loaded': True,
                'message': 'Kronos model loaded and available',
                'current_model': {
                    'name': predictor.model.__class__.__name__,
                    'device': str(next(predictor.model.parameters()).device)
                }
            })
        else:
            return jsonify({
                'available': True,
                'loaded': False,
                'message': 'Kronos model available but not loaded'
            })
    else:
        return jsonify({
            'available': False,
            'loaded': False,
            'message': 'Kronos model library not available, please install related dependencies'
        })
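# /api/model-status distinguishes three states: the Kronos library is missing,
# the library is present but no model is loaded yet, or a model is loaded (in
# which case its class name and device are reported).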
# Stock data fetch endpoint
@app.route('/api/stock-data', methods=['POST'])
def stock_data():
    try:
        data = request.get_json()
        stock_code = data.get('stock_code', '').strip()
        if not stock_code:
            return jsonify({
                'success': False,
                'error': 'Stock code cannot be empty'
            }), 400
        # baostock codes look like "<exchange>.<number>", e.g. "sh.600000"
        if not re.match(r'^[a-z]+\.\d+$', stock_code):
            return jsonify({
                'success': False,
                'error': 'The stock code you entered is invalid'
            }), 400
        # Log in to baostock
        lg = bs.login()
        if lg.error_code != '0':
            return jsonify({
                'success': False,
                'error': f'Login failed: {lg.error_msg}'
            }), 400
        end_date = datetime.now().strftime('%Y-%m-%d')
        rs = bs.query_history_k_data_plus(
            stock_code,
            "time,open,high,low,close,volume,amount",
            start_date='2024-06-01',
            end_date=end_date,
            frequency="5",  # 5-minute bars
            adjustflag="3"  # no price adjustment
        )
        if rs.error_code != '0':
            bs.logout()
            return jsonify({
                'success': False,
                'error': 'Failed to retrieve data, please enter a valid stock code'
            }), 400
        data_list = []
        while (rs.error_code == '0') and rs.next():
            data_list.append(rs.get_row_data())
        # Log out of baostock
        bs.logout()
        columns = rs.fields
        df = pd.DataFrame(data_list, columns=columns)
        df = df.rename(columns={'time': 'timestamps'})
        numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'amount']
        for col in numeric_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # baostock returns time as a digit string, e.g. 'YYYYMMDDHHMMSSsss'
        df['timestamps'] = pd.to_datetime(df['timestamps'].astype(str), format='%Y%m%d%H%M%S%f')
        df = df.dropna()
        data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
        os.makedirs(data_dir, exist_ok=True)
        filename = "Stock_5min_A股.csv"
        file_path = os.path.join(data_dir, filename)
        df.to_csv(
            file_path,
            index=False,
            encoding='utf-8',
            mode='w'
        )
        return jsonify({
            'success': True,
            'message': f'Stock data saved successfully: {filename}',
            'file_name': filename
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Error processing stock data: {str(e)}'
        }), 500
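# Example request (illustrative): POST /api/stock-data {"stock_code": "sh.600000"}
# fetches 5-minute bars via baostock and overwrites data/Stock_5min_A股.csv.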
# Technical indicator chart endpoint
@app.route('/api/generate-chart', methods=['POST'])
def generate_chart():
    try:
        data = request.get_json()
        # Validate parameters
        required_fields = ['file_path', 'lookback', 'diagram_type', 'historical_start_idx']
        for field in required_fields:
            if field not in data:
                return jsonify({'success': False, 'error': f'Missing required field: {field}'}), 400
        file_path = data['file_path']
        lookback = int(data['lookback'])
        diagram_type = data['diagram_type']
        historical_start_idx = int(data['historical_start_idx'])
        # Load the data
        df, error = load_data_file(file_path)
        if error:
            return jsonify({'success': False, 'error': error}), 400
        if len(df) < lookback + historical_start_idx:
            return jsonify({
                'success': False,
                'error': f'Insufficient data length, need at least {lookback + historical_start_idx} rows'
            }), 400
        # No prediction or ground-truth overlay for a pure technical chart
        pred_df = None
        actual_df = None
        # Generate the chart
        chart_json = create_technical_chart(
            df=df,
            pred_df=pred_df,
            lookback=lookback,
            pred_len=0,
            diagram_type=diagram_type,
            actual_df=actual_df,
            historical_start_idx=historical_start_idx
        )
        # Table data: the lookback window that the chart displays
        table_data_start = historical_start_idx
        table_data_end = historical_start_idx + lookback
        table_df = df.iloc[table_data_start:table_data_end]
        table_data = table_df.to_dict('records')
        return jsonify({
            'success': True,
            'chart': json.loads(chart_json),
            'table_data': table_data,
            'message': 'Technical chart generated successfully'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Failed to generate technical chart: {str(e)}'
        }), 500
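# Example request (illustrative; the accepted diagram_type values depend on
# create_technical_chart, "MACD" is a hypothetical choice):
#   POST /api/generate-chart
#   {"file_path": "data/Stock_5min_A股.csv", "lookback": 400,
#    "diagram_type": "MACD", "historical_start_idx": 0}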
if __name__ == '__main__':
    print("Starting Kronos Web UI...")
    print(f"Model availability: {MODEL_AVAILABLE}")
    if MODEL_AVAILABLE:
        print("Tip: You can load a Kronos model through the /api/load-model endpoint")
    else:
        print("Tip: Will use simulated data for demonstration")
    app.run(debug=True, host='0.0.0.0', port=7070)
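# Running this module directly (e.g. `python app.py` from the webui directory)
# serves the UI on all interfaces at port 7070; debug=True enables Flask's
# auto-reloader, so disable it for anything beyond local development.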