import os
import pandas as pd
import numpy as np
import json
import plotly.graph_objects as go
import plotly.utils
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import sys
import warnings
from datetime import datetime
import baostock as bs
import re

warnings.filterwarnings('ignore')

# Add project root directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from model import Kronos, KronosTokenizer, KronosPredictor
    MODEL_AVAILABLE = True
except ImportError:
    MODEL_AVAILABLE = False
    print("Warning: Kronos model cannot be imported, will use simulated data for demonstration")


app = Flask(__name__)
CORS(app)

# Global variables to store models
tokenizer = None
model = None
predictor = None

# Path of the webui directory
WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
# Project root directory (parent of webui)
BASE_DIR = os.path.dirname(WEBUI_DIR)

# Available model configurations
AVAILABLE_MODELS = {
    'kronos-mini': {
        'name': 'Kronos-mini',
        'model_id': os.path.join(BASE_DIR, 'models', 'Kronos-mini'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 2048,
        'params': '4.1M',
        'description': 'Lightweight model, suitable for fast prediction'
    },
    'kronos-small': {
        'name': 'Kronos-small',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-small'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '24.7M',
        'description': 'Small model, balancing performance and speed'
    },
    'kronos-base': {
        'name': 'Kronos-base',
        'model_id': os.path.join(BASE_DIR, 'models', 'NeoQuasarKronos-base'),
        'tokenizer_id': os.path.join(BASE_DIR, 'models', 'Kronos-Tokenizer-base'),
        'context_length': 512,
        'params': '102.3M',
        'description': 'Base model, providing better prediction quality'
    }
}


def load_data_files():
    """Scan the data directory and return the available data files"""
    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
    data_files = []
    if os.path.exists(data_dir):
        for file in os.listdir(data_dir):
            if file.endswith(('.csv', '.feather')):
                file_path = os.path.join(data_dir, file)
                file_size = os.path.getsize(file_path)
                data_files.append({
                    'name': file,
                    'path': file_path,
                    'size': f"{file_size / 1024:.1f} KB" if file_size < 1024 * 1024 else f"{file_size / (1024 * 1024):.1f} MB"
                })
    return data_files


def load_data_file(file_path):
    """Load a data file"""
    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif file_path.endswith('.feather'):
            df = pd.read_feather(file_path)
        else:
            return None, "Unsupported file format"

        # Check required columns
        required_cols = ['open', 'high', 'low', 'close']
        if not all(col in df.columns for col in required_cols):
            return None, f"Missing required columns: {required_cols}"

        # Process the timestamp column
        if 'timestamps' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamps'])
        elif 'timestamp' in df.columns:
            df['timestamps'] = pd.to_datetime(df['timestamp'])
        elif 'date' in df.columns:
            # If the column is named 'date', expose it as 'timestamps'
            df['timestamps'] = pd.to_datetime(df['date'])
        else:
            # If no timestamp column exists, create one
            df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='1H')

        # Ensure price columns are numeric
        for col in ['open', 'high', 'low', 'close']:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        # Process volume column (optional)
        if 'volume' in df.columns:
            df['volume'] = pd.to_numeric(df['volume'], errors='coerce')

        # Process amount column (optional, but not used for prediction)
        if 'amount' in df.columns:
            df['amount'] = pd.to_numeric(df['amount'], errors='coerce')

        # Remove rows containing NaN values
        df = df.dropna()
        return df, None
    except Exception as e:
        return None, f"Failed to load file: {str(e)}"


def save_prediction_results(file_path, prediction_type, prediction_results, actual_data, input_data, prediction_params):
    """Save prediction results to a file"""
    try:
        # Create the prediction results directory
        results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prediction_results')
        os.makedirs(results_dir, exist_ok=True)

        # Generate the filename
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f'prediction_{timestamp}.json'
        filepath = os.path.join(results_dir, filename)

        # Prepare the data for saving
        save_data = {
            'timestamp': datetime.now().isoformat(),
            'file_path': file_path,
            'prediction_type': prediction_type,
            'prediction_params': prediction_params,
            'input_data_summary': {
                'rows': len(input_data),
                'columns': list(input_data.columns),
                'price_range': {
                    'open': {'min': float(input_data['open'].min()), 'max': float(input_data['open'].max())},
                    'high': {'min': float(input_data['high'].min()), 'max': float(input_data['high'].max())},
                    'low': {'min': float(input_data['low'].min()), 'max': float(input_data['low'].max())},
                    'close': {'min': float(input_data['close'].min()), 'max': float(input_data['close'].max())}
                },
                'last_values': {
                    'open': float(input_data['open'].iloc[-1]),
                    'high': float(input_data['high'].iloc[-1]),
                    'low': float(input_data['low'].iloc[-1]),
                    'close': float(input_data['close'].iloc[-1])
                }
            },
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'analysis': {}
        }

        # If actual data exists, perform a comparison analysis
        if actual_data and len(actual_data) > 0:
            # Continuity analysis: compare the first prediction point with the first actual point
            if len(prediction_results) > 0 and len(actual_data) > 0:
                last_pred = prediction_results[0]  # First prediction point (compared against the first actual point)
                first_actual = actual_data[0]  # First actual point
                save_data['analysis']['continuity'] = {
                    'last_prediction': {
                        'open': last_pred['open'],
                        'high': last_pred['high'],
                        'low': last_pred['low'],
                        'close': last_pred['close']
                    },
                    'first_actual': {
                        'open': first_actual['open'],
                        'high': first_actual['high'],
                        'low': first_actual['low'],
                        'close': first_actual['close']
                    },
                    'gaps': {
                        'open_gap': abs(last_pred['open'] - first_actual['open']),
                        'high_gap': abs(last_pred['high'] - first_actual['high']),
                        'low_gap': abs(last_pred['low'] - first_actual['low']),
                        'close_gap': abs(last_pred['close'] - first_actual['close'])
                    },
                    'gap_percentages': {
                        'open_gap_pct': (abs(last_pred['open'] - first_actual['open']) / first_actual['open']) * 100,
                        'high_gap_pct': (abs(last_pred['high'] - first_actual['high']) / first_actual['high']) * 100,
                        'low_gap_pct': (abs(last_pred['low'] - first_actual['low']) / first_actual['low']) * 100,
                        'close_gap_pct': (abs(last_pred['close'] - first_actual['close']) / first_actual['close']) * 100
                    }
                }

        # Save to file
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)

        print(f"Prediction results saved to: {filepath}")
        return filepath
    except Exception as e:
        print(f"Failed to save prediction results: {e}")
        return None


def create_prediction_chart(df, pred_df, lookback, pred_len, actual_df=None, historical_start_idx=0):
    """Create the prediction chart"""
    print(f"🔍 Chart creation debug:")
    print(f"  Historical data: {len(df) if df is not None else 0} rows")
    print(f"  Prediction data: {len(pred_df) if pred_df is not None else 0} rows")
    print(f"  Actual data: {len(actual_df) if actual_df is not None else 0} rows")

    # Ensure the data is not empty
    if pred_df is None or len(pred_df) == 0:
        print("⚠️ Warning: prediction data is empty!")
        # Create an empty chart
        fig = go.Figure()
        fig.update_layout(title='No prediction data available')
        return json.dumps(fig, cls=plotly.utils.PlotlyJSONEncoder)

    # Use the specified historical data start position, not always the beginning of df
    if historical_start_idx + lookback + pred_len <= len(df):
        # Display lookback historical points + pred_len prediction points starting from the specified position
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
        prediction_range = range(historical_start_idx + lookback, historical_start_idx + lookback + pred_len)
    else:
        # If data is insufficient, adjust to the maximum available range
        available_lookback = min(lookback, len(df) - historical_start_idx)
        available_pred_len = min(pred_len, max(0, len(df) - historical_start_idx - available_lookback))
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]
        prediction_range = range(historical_start_idx + available_lookback,
                                 historical_start_idx + available_lookback + available_pred_len)

    # Create the chart
    fig = go.Figure()

    # Add historical data (candlestick chart)
    fig.add_trace(go.Candlestick(
        x=historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns else historical_df.index.tolist(),
        open=historical_df['open'].tolist(),
        high=historical_df['high'].tolist(),
        low=historical_df['low'].tolist(),
        close=historical_df['close'].tolist(),
        name='Historical Data (400 data points)',
        increasing_line_color='#26A69A',
        decreasing_line_color='#EF5350'
    ))

    # Add prediction data (candlestick chart)
    if pred_df is not None and len(pred_df) > 0:
        # Calculate prediction timestamps - ensure continuity with the historical data
        if 'timestamps' in df.columns and len(historical_df) > 0:
            # Starting from the last historical timestamp, create prediction timestamps with the same interval
            last_timestamp = historical_df['timestamps'].iloc[-1]
            time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
            pred_timestamps = pd.date_range(
                start=last_timestamp + time_diff,
                periods=len(pred_df),
                freq=time_diff
            )
        else:
            # If there are no timestamps, use the index
            pred_timestamps = range(len(historical_df), len(historical_df) + len(pred_df))

        fig.add_trace(go.Candlestick(
            x=pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps),
            open=pred_df['open'].tolist(),
            high=pred_df['high'].tolist(),
            low=pred_df['low'].tolist(),
            close=pred_df['close'].tolist(),
            name='Prediction Data (120 data points)',
            increasing_line_color='#66BB6A',
            decreasing_line_color='#FF7043'
        ))

    # Add actual data for comparison (if it exists)
    if actual_df is not None and len(actual_df) > 0:
        # Actual data covers the same time period as the prediction data
        if 'timestamps' in df.columns:
            # Actual data should use the same timestamps as the prediction data to ensure time alignment
            if 'pred_timestamps' in locals():
                actual_timestamps = pred_timestamps
            else:
                # If no prediction timestamps, calculate from the last historical timestamp
                if len(historical_df) > 0:
                    last_timestamp = historical_df['timestamps'].iloc[-1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
                    actual_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=len(actual_df),
                        freq=time_diff
                    )
                else:
                    actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
        else:
            actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))

        fig.add_trace(go.Candlestick(
            x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
            open=actual_df['open'].tolist(),
            high=actual_df['high'].tolist(),
            low=actual_df['low'].tolist(),
            close=actual_df['close'].tolist(),
            name='Actual Data (120 data points)',
            increasing_line_color='#FF9800',
            decreasing_line_color='#F44336'
        ))

    # Update the layout
    fig.update_layout(
        title='Kronos Financial Prediction Results - 400 Historical Points + 120 Prediction Points vs 120 Actual Points',
        xaxis_title='Time',
        yaxis_title='Price',
        template='plotly_white',
        height=600,
        showlegend=True
    )

    # Ensure x-axis time continuity
    if 'timestamps' in historical_df.columns:
        # Collect all timestamps and sort them
        all_timestamps = []
        if len(historical_df) > 0:
            all_timestamps.extend(historical_df['timestamps'].tolist())
        if 'pred_timestamps' in locals():
            all_timestamps.extend(pred_timestamps.tolist() if hasattr(pred_timestamps, 'tolist') else list(pred_timestamps))
        if 'actual_timestamps' in locals():
            all_timestamps.extend(actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps))
        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    try:
        chart_json = fig.to_json()
        print(f"✅ Chart JSON serialization complete, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Chart serialization failed: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()


# Calculate technical indicators
def calculate_indicators(df):
    indicators = {}

    # Moving averages (MA)
    indicators['ma5'] = df['close'].rolling(window=5).mean()
    indicators['ma10'] = df['close'].rolling(window=10).mean()
    indicators['ma20'] = df['close'].rolling(window=20).mean()

    # MACD
    exp12 = df['close'].ewm(span=12, adjust=False).mean()
    exp26 = df['close'].ewm(span=26, adjust=False).mean()
    indicators['macd'] = exp12 - exp26
    indicators['signal'] = indicators['macd'].ewm(span=9, adjust=False).mean()
    indicators['macd_hist'] = indicators['macd'] - indicators['signal']

    # RSI
    delta = df['close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    indicators['rsi'] = 100 - (100 / (1 + rs))

    # Bollinger Bands
    indicators['bb_mid'] = df['close'].rolling(window=20).mean()
    indicators['bb_std'] = df['close'].rolling(window=20).std()
    indicators['bb_upper'] = indicators['bb_mid'] + 2 * indicators['bb_std']
    indicators['bb_lower'] = indicators['bb_mid'] - 2 * indicators['bb_std']

    # Stochastic oscillator
    low_min = df['low'].rolling(window=14).min()
    high_max = df['high'].rolling(window=14).max()
    indicators['stoch_k'] = 100 * ((df['close'] - low_min) / (high_max - low_min))
    indicators['stoch_d'] = indicators['stoch_k'].rolling(window=3).mean()

    # Rolling window mean strategy
    indicators['rwms_window'] = 90
    indicators['rwms_mean'] = df['close'].rolling(window=90).mean()
    indicators['rwms_signal'] = (df['close'] > indicators['rwms_mean']).astype(int)

    # Triple Exponential Average (TRIX) strategy
    # EMA of the close price
    ema1 = df['close'].ewm(span=12, adjust=False).mean()
    # EMA of that EMA
    ema2 = ema1.ewm(span=12, adjust=False).mean()
    # EMA of the second EMA
    ema3 = ema2.ewm(span=12, adjust=False).mean()
    # TRIX: one-period rate of change of the triple EMA, in percent
    indicators['trix'] = (ema3 - ema3.shift(1)) / ema3.shift(1) * 100
    # Signal line
    indicators['trix_signal'] = indicators['trix'].ewm(span=9, adjust=False).mean()

    return indicators
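

# Illustrative sketch (not called anywhere): calculate_indicators on a small
# synthetic OHLC series. The random-walk data below is made up purely for
# demonstration; windows shorter than an indicator's period yield NaN.
def _demo_calculate_indicators():
    rng = np.random.default_rng(0)
    close = 100 + np.cumsum(rng.normal(0, 1, 200))
    demo_df = pd.DataFrame({
        'open': close,
        'high': close + 1,
        'low': close - 1,
        'close': close,
    })
    indicators = calculate_indicators(demo_df)
    # e.g. the last RSI value and the final 20-period Bollinger mid-band
    print(indicators['rsi'].iloc[-1], indicators['bb_mid'].iloc[-1])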


# Create a technical indicator chart
def create_technical_chart(df, pred_df, lookback, pred_len, diagram_type, actual_df=None, historical_start_idx=0):
    print(f"🔍 Data content: {len(df) if df is not None else 0} rows")
    print(f"🔍 Chart type: {diagram_type}")

    # Data range
    if historical_start_idx + lookback <= len(df):
        historical_df = df.iloc[historical_start_idx:historical_start_idx + lookback]
    else:
        available_lookback = min(lookback, len(df) - historical_start_idx)
        historical_df = df.iloc[historical_start_idx:historical_start_idx + available_lookback]

    # Calculate indicators
    historical_indicators = calculate_indicators(historical_df)

    fig = go.Figure()

    # Shared x-axis values for all historical traces
    timestamps_x = historical_df['timestamps'].tolist() if 'timestamps' in historical_df.columns else historical_df.index.tolist()

    # Volume chart
    if diagram_type == 'Volume Chart (VOL)':
        fig.add_trace(go.Bar(
            x=timestamps_x,
            y=historical_df['volume'].tolist() if 'volume' in historical_df.columns else [],
            name='Historical Volume',
            marker_color='#42A5F5'
        ))
        if actual_df is not None and len(actual_df) > 0 and 'volume' in actual_df.columns:
            if 'timestamps' in df.columns and len(historical_df) > 0:
                last_timestamp = historical_df['timestamps'].iloc[-1]
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0] if len(df) > 1 else pd.Timedelta(hours=1)
                actual_timestamps = pd.date_range(start=last_timestamp + time_diff, periods=len(actual_df), freq=time_diff)
            else:
                actual_timestamps = range(len(historical_df), len(historical_df) + len(actual_df))
            fig.add_trace(go.Bar(
                x=actual_timestamps.tolist() if hasattr(actual_timestamps, 'tolist') else list(actual_timestamps),
                y=actual_df['volume'].tolist(),
                name='Actual Volume',
                marker_color='#FF9800'
            ))
        fig.update_layout(yaxis_title='Volume')

    # Moving averages
    elif diagram_type == 'Moving Average (MA)':
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['ma5'],
            name='MA5',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['ma10'],
            name='MA10',
            line=dict(color='#42A5F5', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['ma20'],
            name='MA20',
            line=dict(color='#7E57C2', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_df['close'],
            name='Close Price',
            line=dict(color='#212121', width=1, dash='dash')
        ))
        fig.update_layout(yaxis_title='Price')

    # MACD indicator
    elif diagram_type == 'MACD Indicator (MACD)':
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['macd'],
            name='MACD',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['signal'],
            name='Signal',
            line=dict(color='#EF5350', width=1)
        ))
        fig.add_trace(go.Bar(
            x=timestamps_x,
            y=historical_indicators['macd_hist'],
            name='MACD Histogram',
            marker_color='#42A5F5'
        ))
        # Zero line
        fig.add_hline(y=0, line_dash="dash", line_color="gray")
        fig.update_layout(yaxis_title='MACD')

    # RSI indicator
    elif diagram_type == 'RSI Indicator (RSI)':
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['rsi'],
            name='RSI',
            line=dict(color='#26A69A', width=1)
        ))
        # Overbought / oversold lines
        fig.add_hline(y=70, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=30, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='RSI', yaxis_range=[0, 100])

    # Bollinger Bands
    elif diagram_type == 'Bollinger Bands (BB)':
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['bb_upper'],
            name='Upper Band',
            line=dict(color='#EF5350', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['bb_mid'],
            name='Middle Band (MA20)',
            line=dict(color='#42A5F5', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['bb_lower'],
            name='Lower Band',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_df['close'],
            name='Close Price',
            line=dict(color='#212121', width=1)
        ))
        fig.update_layout(yaxis_title='Price')

    # Stochastic oscillator
    elif diagram_type == 'Stochastic Oscillator (STOCH)':
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['stoch_k'],
            name='%K',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['stoch_d'],
            name='%D',
            line=dict(color='#EF5350', width=1)
        ))
        fig.add_hline(y=80, line_dash="dash", line_color="red", name='Overbought')
        fig.add_hline(y=20, line_dash="dash", line_color="green", name='Oversold')
        fig.update_layout(yaxis_title='Stochastic', yaxis_range=[0, 100])

    # Rolling window mean strategy
    elif diagram_type == 'Rolling Window Mean Strategy':
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_df['close'],
            name='Close Price',
            line=dict(color='#212121', width=1.5)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['rwms_mean'],
            name=f'Rolling Mean ({historical_indicators["rwms_window"]} periods)',
            line=dict(color='#42A5F5', width=1.5, dash='dash')
        ))
        buy_signals = historical_df[historical_indicators['rwms_signal'] == 1]
        fig.add_trace(go.Scatter(
            x=buy_signals['timestamps'].tolist() if 'timestamps' in buy_signals.columns else buy_signals.index.tolist(),
            y=buy_signals['close'],
            mode='markers',
            name='Buy Signal',
            marker=dict(color='#26A69A', size=8, symbol='triangle-up')
        ))
        sell_signals = historical_df[historical_indicators['rwms_signal'] == 0]
        fig.add_trace(go.Scatter(
            x=sell_signals['timestamps'].tolist() if 'timestamps' in sell_signals.columns else sell_signals.index.tolist(),
            y=sell_signals['close'],
            mode='markers',
            name='Sell Signal',
            marker=dict(color='#EF5350', size=8, symbol='triangle-down')
        ))
        fig.update_layout(
            yaxis_title='Price',
            title=f'Rolling Window Mean Strategy (Window Size: {historical_indicators["rwms_window"]})'
        )

    # TRIX indicator chart
    elif diagram_type == 'TRIX Indicator (TRIX)':
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['trix'],
            name='TRIX',
            line=dict(color='#26A69A', width=1)
        ))
        fig.add_trace(go.Scatter(
            x=timestamps_x,
            y=historical_indicators['trix_signal'],
            name='TRIX Signal',
            line=dict(color='#EF5350', width=1)
        ))
        fig.add_hline(y=0, line_dash="dash", line_color="gray")
        fig.update_layout(
            yaxis_title='TRIX (%)',
            title='Triple Exponential Average (TRIX) Strategy'
        )

    # Layout settings
    fig.update_layout(
        title=f'{diagram_type} - Technical Indicator (Real Data Only)',
        xaxis_title='Time',
        template='plotly_white',
        height=400,
        showlegend=True,
        margin=dict(t=50, b=30)
    )

    if 'timestamps' in historical_df.columns:
        all_timestamps = historical_df['timestamps'].tolist()
        if actual_df is not None and len(actual_df) > 0 and 'timestamps' in df.columns:
            if 'actual_timestamps' in locals():
                all_timestamps.extend(actual_timestamps.tolist())
        if all_timestamps:
            all_timestamps = sorted(all_timestamps)
            fig.update_xaxes(
                range=[all_timestamps[0], all_timestamps[-1]],
                rangeslider_visible=False,
                type='date'
            )

    try:
        chart_json = fig.to_json()
        print(f"✅ Technical indicator chart serialization complete, length: {len(chart_json)}")
        return chart_json
    except Exception as e:
        print(f"❌ Technical indicator chart serialization failed: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title='Chart Rendering Error')
        return error_fig.to_json()
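

# Illustrative sketch (not called anywhere): render one of the technical charts
# for a loaded data file. 'Stock_5min_A股.csv' matches the filename written by
# the /api/stock-data endpoint below; any CSV accepted by load_data_file works.
def _demo_create_technical_chart():
    df, error = load_data_file(os.path.join(BASE_DIR, 'data', 'Stock_5min_A股.csv'))
    if not error:
        chart_json = create_technical_chart(
            df=df,
            pred_df=None,
            lookback=min(400, len(df)),
            pred_len=0,
            diagram_type='Moving Average (MA)',
            actual_df=None,
            historical_start_idx=0
        )
        print(f"Chart JSON length: {len(chart_json)}")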


@app.route('/')
def index():
    """Home page"""
    return render_template('index.html')


@app.route('/api/data-files')
def get_data_files():
    """Get the list of available data files"""
    data_files = load_data_files()
    return jsonify(data_files)


@app.route('/api/load-data', methods=['POST'])
def load_data():
    """Load a data file"""
    try:
        data = request.get_json()
        file_path = data.get('file_path')
        if not file_path:
            return jsonify({'error': 'File path cannot be empty'}), 400

        df, error = load_data_file(file_path)
        if error:
            return jsonify({'error': error}), 400

        # Detect the data's time frequency
        def detect_timeframe(df):
            if len(df) < 2:
                return "Unknown"
            time_diffs = []
            for i in range(1, min(10, len(df))):  # Check the first 10 time differences
                diff = df['timestamps'].iloc[i] - df['timestamps'].iloc[i - 1]
                time_diffs.append(diff)
            if not time_diffs:
                return "Unknown"
            # Calculate the average time difference
            avg_diff = sum(time_diffs, pd.Timedelta(0)) / len(time_diffs)
            # Convert to a readable format
            if avg_diff < pd.Timedelta(minutes=1):
                return f"{avg_diff.total_seconds():.0f} seconds"
            elif avg_diff < pd.Timedelta(hours=1):
                return f"{avg_diff.total_seconds() / 60:.0f} minutes"
            elif avg_diff < pd.Timedelta(days=1):
                return f"{avg_diff.total_seconds() / 3600:.0f} hours"
            else:
                return f"{avg_diff.days} days"

        # Return data information
        data_info = {
            'rows': len(df),
            'columns': list(df.columns),
            'start_date': df['timestamps'].min().isoformat() if 'timestamps' in df.columns else 'N/A',
            'end_date': df['timestamps'].max().isoformat() if 'timestamps' in df.columns else 'N/A',
            'price_range': {
                'min': float(df[['open', 'high', 'low', 'close']].min().min()),
                'max': float(df[['open', 'high', 'low', 'close']].max().max())
            },
            'prediction_columns': ['open', 'high', 'low', 'close'] + (['volume'] if 'volume' in df.columns else []),
            'timeframe': detect_timeframe(df)
        }
        return jsonify({
            'success': True,
            'data_info': data_info,
            'message': f'Successfully loaded data, total {len(df)} rows'
        })
    except Exception as e:
        return jsonify({'error': f'Failed to load data: {str(e)}'}), 500


@app.route('/api/predict', methods=['POST'])
def predict():
    """Perform prediction"""
    try:
        data = request.get_json()
        file_path = data.get('file_path')
        lookback = int(data.get('lookback', 400))
        pred_len = int(data.get('pred_len', 120))

        # Get prediction quality parameters
        temperature = float(data.get('temperature', 1.0))
        top_p = float(data.get('top_p', 0.9))
        sample_count = int(data.get('sample_count', 1))

        if not file_path:
            return jsonify({'error': 'File path cannot be empty'}), 400

        # Load data
        df, error = load_data_file(file_path)
        if error:
            return jsonify({'error': error}), 400

        if len(df) < lookback:
            return jsonify({'error': f'Insufficient data length, need at least {lookback} rows'}), 400

        # Perform prediction
        if MODEL_AVAILABLE and predictor is not None:
            try:
                # Use the real Kronos model
                # Only use the necessary columns: OHLCV + amount
                required_cols = ['open', 'high', 'low', 'close', 'volume', 'amount']

                # Process time period selection
                start_date = data.get('start_date')
                if start_date:
                    # Custom time period - fixed logic: use data within the selected window
                    start_dt = pd.to_datetime(start_date)

                    # Find data after the start time
                    mask = df['timestamps'] >= start_dt
                    time_range_df = df[mask]

                    # Ensure sufficient data: lookback + pred_len
                    if len(time_range_df) < lookback + pred_len:
                        return jsonify({
                            'error': f'Insufficient data from start time {start_dt.strftime("%Y-%m-%d %H:%M")}, need at least {lookback + pred_len} data points, currently only {len(time_range_df)} available'}), 400

                    # Use the first lookback data points within the selected window for prediction
                    x_df = time_range_df.iloc[:lookback][required_cols]
                    x_timestamp = time_range_df.iloc[:lookback]['timestamps']

                    # Use the last pred_len data points within the selected window as actual values
                    y_timestamp = time_range_df.iloc[lookback:lookback + pred_len]['timestamps']

                    # Calculate the actual time span
                    start_timestamp = time_range_df['timestamps'].iloc[0]
                    end_timestamp = time_range_df['timestamps'].iloc[lookback + pred_len - 1]
                    time_span = end_timestamp - start_timestamp

                    prediction_type = f"Kronos model prediction (within selected window: first {lookback} data points for prediction, last {pred_len} data points for comparison, time span: {time_span})"
                else:
                    # Use the latest data
                    x_df = df.iloc[:lookback][required_cols]
                    x_timestamp = df.iloc[:lookback]['timestamps']
                    y_timestamp = df.iloc[lookback:lookback + pred_len]['timestamps']
                    prediction_type = "Kronos model prediction (latest data)"

                # Debug information
                print(f"🔍 Columns passed to the predictor: {x_df.columns.tolist()}")
                print(f"🔍 Data shape: {x_df.shape}")
                print(f"🔍 Data sample:")
                print(x_df.head(2))

                # Ensure timestamps are Series, not DatetimeIndex, to avoid a .dt attribute error in the Kronos model
                if isinstance(x_timestamp, pd.DatetimeIndex):
                    x_timestamp = pd.Series(x_timestamp, name='timestamps')
                if isinstance(y_timestamp, pd.DatetimeIndex):
                    y_timestamp = pd.Series(y_timestamp, name='timestamps')

                pred_df = predictor.predict(
                    df=x_df,
                    x_timestamp=x_timestamp,
                    y_timestamp=y_timestamp,
                    pred_len=pred_len,
                    T=temperature,
                    top_p=top_p,
                    sample_count=sample_count
                )
            except Exception as e:
                return jsonify({'error': f'Kronos model prediction failed: {str(e)}'}), 500
        else:
            return jsonify({'error': 'Kronos model not loaded, please load model first'}), 400

        # Prepare actual data for comparison (if it exists)
        actual_data = []
        actual_df = None
        if start_date:  # Custom time period
            # Fixed logic: use data within the selected window
            # Prediction uses the first 400 data points within the selected window
            # Actual data should be the last 120 data points within the selected window
            start_dt = pd.to_datetime(start_date)

            # Find data starting from start_date
            mask = df['timestamps'] >= start_dt
            time_range_df = df[mask]

            if len(time_range_df) >= lookback + pred_len:
                # Take the last 120 data points within the selected window as actual values
                actual_df = time_range_df.iloc[lookback:lookback + pred_len]
                for i, (_, row) in enumerate(actual_df.iterrows()):
                    actual_data.append({
                        'timestamp': row['timestamps'].isoformat(),
                        'open': float(row['open']),
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'close': float(row['close']),
                        'volume': float(row['volume']) if 'volume' in row else 0,
                        'amount': float(row['amount']) if 'amount' in row else 0
                    })
        else:  # Latest data
            # Prediction uses the first 400 data points
            # Actual data should be the 120 data points after the first 400
            if len(df) >= lookback + pred_len:
                actual_df = df.iloc[lookback:lookback + pred_len]
                for i, (_, row) in enumerate(actual_df.iterrows()):
                    actual_data.append({
                        'timestamp': row['timestamps'].isoformat(),
                        'open': float(row['open']),
                        'high': float(row['high']),
                        'low': float(row['low']),
                        'close': float(row['close']),
                        'volume': float(row['volume']) if 'volume' in row else 0,
                        'amount': float(row['amount']) if 'amount' in row else 0
                    })

        # Create chart - pass the historical data start position
        if start_date:
            # Custom time period: find the starting position of historical data in the original df
            start_dt = pd.to_datetime(start_date)
            mask = df['timestamps'] >= start_dt
            historical_start_idx = df[mask].index[0] if len(df[mask]) > 0 else 0
        else:
            # Latest data: start from the beginning
            historical_start_idx = 0

        chart_json = create_prediction_chart(df, pred_df, lookback, pred_len, actual_df, historical_start_idx)

        # Prepare prediction result data - fixed timestamp calculation logic
        if 'timestamps' in df.columns:
            if start_date:
                # Custom time period: use the selected window data to calculate timestamps
                start_dt = pd.to_datetime(start_date)
                mask = df['timestamps'] >= start_dt
                time_range_df = df[mask]
                if len(time_range_df) >= lookback:
                    # Prediction timestamps start from the last time point of the selected window
                    last_timestamp = time_range_df['timestamps'].iloc[lookback - 1]
                    time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                    future_timestamps = pd.date_range(
                        start=last_timestamp + time_diff,
                        periods=pred_len,
                        freq=time_diff
                    )
                else:
                    future_timestamps = []
            else:
                # Latest data: calculate from the last time point of the entire data file
                last_timestamp = df['timestamps'].iloc[-1]
                time_diff = df['timestamps'].iloc[1] - df['timestamps'].iloc[0]
                future_timestamps = pd.date_range(
                    start=last_timestamp + time_diff,
                    periods=pred_len,
                    freq=time_diff
                )
        else:
            future_timestamps = range(len(df), len(df) + pred_len)

        prediction_results = []
        for i, (_, row) in enumerate(pred_df.iterrows()):
            prediction_results.append({
                'timestamp': future_timestamps[i].isoformat() if i < len(future_timestamps) else f"T{i}",
                'open': float(row['open']),
                'high': float(row['high']),
                'low': float(row['low']),
                'close': float(row['close']),
                'volume': float(row['volume']) if 'volume' in row else 0,
                'amount': float(row['amount']) if 'amount' in row else 0
            })

        # Save prediction results to a file
        try:
            save_prediction_results(
                file_path=file_path,
                prediction_type=prediction_type,
                prediction_results=prediction_results,
                actual_data=actual_data,
                input_data=x_df,
                prediction_params={
                    'lookback': lookback,
                    'pred_len': pred_len,
                    'temperature': temperature,
                    'top_p': top_p,
                    'sample_count': sample_count,
                    'start_date': start_date if start_date else 'latest'
                }
            )
        except Exception as e:
            print(f"Failed to save prediction results: {e}")

        # Log the response summary before returning
        print(f"✅ Prediction complete, returning data:")
        print(f"  Success: {True}")
        print(f"  Prediction type: {prediction_type}")
        print(f"  Chart data length: {len(chart_json)}")
        print(f"  Number of prediction results: {len(prediction_results)}")
        print(f"  Number of actual data points: {len(actual_data)}")
        print(f"  Has comparison data: {len(actual_data) > 0}")

        return jsonify({
            'success': True,
            'prediction_type': prediction_type,
            'chart': chart_json,
            'prediction_results': prediction_results,
            'actual_data': actual_data,
            'has_comparison': len(actual_data) > 0,
            'message': f'Prediction completed, generated {pred_len} prediction points' + (
                f', including {len(actual_data)} actual data points for comparison' if len(actual_data) > 0 else '')
        })
    except Exception as e:
        return jsonify({'error': f'Prediction failed: {str(e)}'}), 500


@app.route('/api/load-model', methods=['POST'])
def load_model():
    """Load a Kronos model"""
    global tokenizer, model, predictor
    try:
        if not MODEL_AVAILABLE:
            return jsonify({'error': 'Kronos model library not available'}), 400

        data = request.get_json()
        model_key = data.get('model_key', 'kronos-small')
        device = data.get('device', 'cpu')

        if model_key not in AVAILABLE_MODELS:
            return jsonify({'error': f'Unsupported model: {model_key}'}), 400

        model_config = AVAILABLE_MODELS[model_key]
        print(f"🚀 Loading model from: {model_config['model_id']}")

        model_path = model_config['model_id']
        tokenizer_path = model_config['tokenizer_id']  # currently unused; the tokenizer is built from config.json below

        # Check whether the model path exists
        if not os.path.exists(model_path):
            return jsonify({'error': f'Model path does not exist: {model_path}'}), 400

        model_files = os.listdir(model_path)
        print(f"📄 Files in the model directory: {model_files}")

        try:
            # Load the model directly from local files
            model = Kronos.from_pretrained(
                model_config['model_id'],
                local_files_only=True
            )

            # Read the model config file to get the correct parameters
            config_path = os.path.join(model_config['model_id'], 'config.json')
            if os.path.exists(config_path):
                print(f"Reading config file: {config_path}")
                with open(config_path, 'r') as f:
                    config = json.load(f)
                for key, value in config.items():
                    print(f"  {key}: {value}")

                # Create the tokenizer using the parameters from the config
                tokenizer = KronosTokenizer(
                    d_in=6,  # open, high, low, close, volume, amount
                    d_model=config['d_model'],  # 832
                    n_heads=config['n_heads'],  # 16
                    ff_dim=config['ff_dim'],  # 2048
                    n_enc_layers=config['n_layers'],  # 12
                    n_dec_layers=config['n_layers'],  # 12
                    ffn_dropout_p=config['ffn_dropout_p'],  # 0.2
                    attn_dropout_p=config['attn_dropout_p'],  # 0.0
                    resid_dropout_p=config['resid_dropout_p'],  # 0.2
                    s1_bits=config['s1_bits'],  # 10
                    s2_bits=config['s2_bits'],  # 10
                    beta=1.0,
                    gamma0=1.0,
                    gamma=1.0,
                    zeta=1.0,
                    group_size=1
                )
            else:
                return jsonify({'error': f'Config file not found: {config_path}'}), 400
        except Exception as e:
            return jsonify({'error': f'Failed to load model: {str(e)}'}), 500

        # Create the predictor
        predictor = KronosPredictor(
            model,
            tokenizer,
            device=device,
            max_context=model_config['context_length']
        )

        return jsonify({
            'success': True,
            'message': f'Model loaded successfully: {model_config["name"]} ({model_config["params"]}) on {device}',
            'model_info': {
                **model_config,
                'model_path': model_config['model_id'],
                'abs_model_path': os.path.abspath(model_config['model_id']),
                'device': device
            }
        })
    except Exception as e:
        import traceback
        print("[API endpoint error]")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        traceback.print_exc()
        print("=" * 60)
        return jsonify({'error': f'Model loading failed: {str(e)}'}), 500


@app.route('/api/available-models')
def get_available_models():
    """Get the list of available models"""
    return jsonify({
        'models': AVAILABLE_MODELS,
        'model_available': MODEL_AVAILABLE
    })


@app.route('/api/model-status')
def get_model_status():
    """Get the model status"""
    if MODEL_AVAILABLE:
        if predictor is not None:
            return jsonify({
                'available': True,
                'loaded': True,
                'message': 'Kronos model loaded and available',
                'current_model': {
                    'name': predictor.model.__class__.__name__,
                    'device': str(next(predictor.model.parameters()).device)
                }
            })
        else:
            return jsonify({
                'available': True,
                'loaded': False,
                'message': 'Kronos model available but not loaded'
            })
    else:
        return jsonify({
            'available': False,
            'loaded': False,
            'message': 'Kronos model library not available, please install related dependencies'
        })


@app.route('/api/stock-data', methods=['POST'])
def stock_data():
    """Fetch A-share stock data via baostock and save it as a CSV data file"""
    try:
        data = request.get_json()
        stock_code = data.get('stock_code', '').strip()

        # The stock code cannot be empty
        if not stock_code:
            return jsonify({
                'success': False,
                'error': 'Stock code cannot be empty'
            }), 400

        # Validate the stock code format
        if not re.match(r'^[a-z]+\.\d+$', stock_code):
            return jsonify({
                'success': False,
                'error': 'The stock code you entered is invalid'
            }), 400

        # Log in to baostock
        lg = bs.login()
        if lg.error_code != '0':
            return jsonify({
                'success': False,
                'error': f'Login failed: {lg.error_msg}'
            }), 400

        end_date = datetime.now().strftime('%Y-%m-%d')
        rs = bs.query_history_k_data_plus(
            stock_code,
            "time,open,high,low,close,volume,amount",
            start_date='2024-06-01',
            end_date=end_date,
            frequency="5",
            adjustflag="3"
        )

        # Check the query result
        if rs.error_code != '0':
            bs.logout()
            return jsonify({
                'success': False,
                'error': 'Failed to retrieve data, please enter a valid stock code'
            }), 400

        # Extract the data
        data_list = []
        while rs.next():
            data_list.append(rs.get_row_data())

        # Log out
        bs.logout()

        columns = rs.fields
        df = pd.DataFrame(data_list, columns=columns)

        # Convert numeric columns
        df = df.rename(columns={'time': 'timestamps'})
        numeric_columns = ['timestamps', 'open', 'high', 'low', 'close', 'volume', 'amount']
        for col in numeric_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        df['timestamps'] = pd.to_datetime(df['timestamps'].astype(str), format='%Y%m%d%H%M%S%f')

        # Drop invalid rows
        df = df.dropna()

        # Save to the data directory
        data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
        os.makedirs(data_dir, exist_ok=True)
        filename = "Stock_5min_A股.csv"
        file_path = os.path.join(data_dir, filename)
        df.to_csv(
            file_path,
            index=False,
            encoding='utf-8',
            mode='w'
        )

        return jsonify({
            'success': True,
            'message': f'Stock data saved successfully: {filename}',
            'file_name': filename
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Error processing stock data: {str(e)}'
        }), 500
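

# Illustrative sketch (not called anywhere): the stock-code format accepted by
# /api/stock-data. 'sh.600000' and 'sz.000001' are examples of the
# exchange-prefixed codes baostock expects; plain or uppercase codes are rejected.
def _demo_stock_code_validation():
    for code in ['sh.600000', 'sz.000001', '600000', 'SH.600000']:
        ok = bool(re.match(r'^[a-z]+\.\d+$', code))
        print(f"{code}: {'valid' if ok else 'invalid'}")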


@app.route('/api/generate-chart', methods=['POST'])
def generate_chart():
    """Generate a technical indicator chart from historical data only"""
    try:
        data = request.get_json()

        # Validate parameters
        required_fields = ['file_path', 'lookback', 'diagram_type', 'historical_start_idx']
        for field in required_fields:
            if field not in data:
                return jsonify({'success': False, 'error': f'Missing required field: {field}'}), 400

        # Parse parameters
        file_path = data['file_path']
        lookback = int(data['lookback'])
        diagram_type = data['diagram_type']
        historical_start_idx = int(data['historical_start_idx'])

        # Load data
        df, error = load_data_file(file_path)
        if error:
            return jsonify({'success': False, 'error': error}), 400

        if len(df) < lookback + historical_start_idx:
            return jsonify({
                'success': False,
                'error': f'Insufficient data length, need at least {lookback + historical_start_idx} rows'
            }), 400

        pred_df = None
        actual_df = None

        # Generate the chart
        chart_json = create_technical_chart(
            df=df,
            pred_df=pred_df,
            lookback=lookback,
            pred_len=0,
            diagram_type=diagram_type,
            actual_df=actual_df,
            historical_start_idx=historical_start_idx
        )

        # Table data
        table_data_start = historical_start_idx
        table_data_end = historical_start_idx + lookback
        table_df = df.iloc[table_data_start:table_data_end]
        table_data = table_df.to_dict('records')

        return jsonify({
            'success': True,
            'chart': json.loads(chart_json),
            'table_data': table_data,
            'message': 'Technical chart generated successfully'
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': f'Failed to generate technical chart: {str(e)}'
        }), 500
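

# Illustrative client sketch (not called anywhere): how a script might drive the
# API above with the `requests` library. Assumes the server is running locally
# on port 7070 and that a data file named 'example.csv' exists under data/;
# both are placeholders, and `requests` is assumed to be installed separately.
def _demo_api_client():
    import requests
    base_url = 'http://127.0.0.1:7070'
    # 1. Load a model (keys defined in AVAILABLE_MODELS above)
    requests.post(f'{base_url}/api/load-model',
                  json={'model_key': 'kronos-small', 'device': 'cpu'})
    # 2. Run a prediction against a data file
    resp = requests.post(f'{base_url}/api/predict', json={
        'file_path': os.path.join(BASE_DIR, 'data', 'example.csv'),
        'lookback': 400,
        'pred_len': 120,
        'temperature': 1.0,
        'top_p': 0.9,
        'sample_count': 1
    })
    print(resp.json().get('message'))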


if __name__ == '__main__':
    print("Starting Kronos Web UI...")
    print(f"Model availability: {MODEL_AVAILABLE}")
    if MODEL_AVAILABLE:
        print("Tip: You can load the Kronos model through the /api/load-model endpoint")
    else:
        print("Tip: Will use simulated data for demonstration")
    app.run(debug=True, host='0.0.0.0', port=7070)