You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

130 lines
5.4 KiB

1 month ago
  1. import os
  2. import pickle
  3. import numpy as np
  4. import pandas as pd
  5. import qlib
  6. from qlib.config import REG_CN
  7. from qlib.data import D
  8. from qlib.data.dataset.loader import QlibDataLoader
  9. from tqdm import trange
  10. from config import Config
  11. class QlibDataPreprocessor:
  12. """
  13. A class to handle the loading, processing, and splitting of Qlib financial data.
  14. """
  15. def __init__(self):
  16. """Initializes the preprocessor with configuration and data fields."""
  17. self.config = Config()
  18. self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap']
  19. self.data = {} # A dictionary to store processed data for each symbol.
  20. def initialize_qlib(self):
  21. """Initializes the Qlib environment."""
  22. print("Initializing Qlib...")
  23. qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
  24. def load_qlib_data(self):
  25. """
  26. Loads raw data from Qlib, processes it symbol by symbol, and stores
  27. it in the `self.data` attribute.
  28. """
  29. print("Loading and processing data from Qlib...")
  30. data_fields_qlib = ['$' + f for f in self.data_fields]
  31. cal: np.ndarray = D.calendar()
  32. # Determine the actual start and end times to load, including buffer for lookback and predict windows.
  33. start_index = cal.searchsorted(pd.Timestamp(self.config.dataset_begin_time))
  34. end_index = cal.searchsorted(pd.Timestamp(self.config.dataset_end_time))
  35. # Check if start_index lookbackw_window will cause negative index
  36. adjusted_start_index = max(start_index - self.config.lookback_window, 0)
  37. real_start_time = cal[adjusted_start_index]
  38. # Check if end_index exceeds the range of the array
  39. if end_index >= len(cal):
  40. end_index = len(cal) - 1
  41. elif cal[end_index] != pd.Timestamp(self.config.dataset_end_time):
  42. end_index -= 1
  43. # Check if end_index+predictw_window will exceed the range of the array
  44. adjusted_end_index = min(end_index + self.config.predict_window, len(cal) - 1)
  45. real_end_time = cal[adjusted_end_index]
  46. # Load data using Qlib's data loader.
  47. data_df = QlibDataLoader(config=data_fields_qlib).load(
  48. self.config.instrument, real_start_time, real_end_time
  49. )
  50. data_df = data_df.stack().unstack(level=1) # Reshape for easier access.
  51. symbol_list = list(data_df.columns)
  52. for i in trange(len(symbol_list), desc="Processing Symbols"):
  53. symbol = symbol_list[i]
  54. symbol_df = data_df[symbol]
  55. # Pivot the table to have features as columns and datetime as index.
  56. symbol_df = symbol_df.reset_index().rename(columns={'level_1': 'field'})
  57. symbol_df = pd.pivot(symbol_df, index='datetime', columns='field', values=symbol)
  58. symbol_df = symbol_df.rename(columns={f'${field}': field for field in self.data_fields})
  59. # Calculate amount and select final features.
  60. symbol_df['vol'] = symbol_df['volume']
  61. symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol']
  62. symbol_df = symbol_df[self.config.feature_list]
  63. # Filter out symbols with insufficient data.
  64. symbol_df = symbol_df.dropna()
  65. if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1:
  66. continue
  67. self.data[symbol] = symbol_df
  68. def prepare_dataset(self):
  69. """
  70. Splits the loaded data into train, validation, and test sets and saves them to disk.
  71. """
  72. print("Splitting data into train, validation, and test sets...")
  73. train_data, val_data, test_data = {}, {}, {}
  74. symbol_list = list(self.data.keys())
  75. for i in trange(len(symbol_list), desc="Preparing Datasets"):
  76. symbol = symbol_list[i]
  77. symbol_df = self.data[symbol]
  78. # Define time ranges from config.
  79. train_start, train_end = self.config.train_time_range
  80. val_start, val_end = self.config.val_time_range
  81. test_start, test_end = self.config.test_time_range
  82. # Create boolean masks for each dataset split.
  83. train_mask = (symbol_df.index >= train_start) & (symbol_df.index <= train_end)
  84. val_mask = (symbol_df.index >= val_start) & (symbol_df.index <= val_end)
  85. test_mask = (symbol_df.index >= test_start) & (symbol_df.index <= test_end)
  86. # Apply masks to create the final datasets.
  87. train_data[symbol] = symbol_df[train_mask]
  88. val_data[symbol] = symbol_df[val_mask]
  89. test_data[symbol] = symbol_df[test_mask]
  90. # Save the datasets using pickle.
  91. os.makedirs(self.config.dataset_path, exist_ok=True)
  92. with open(f"{self.config.dataset_path}/train_data.pkl", 'wb') as f:
  93. pickle.dump(train_data, f)
  94. with open(f"{self.config.dataset_path}/val_data.pkl", 'wb') as f:
  95. pickle.dump(val_data, f)
  96. with open(f"{self.config.dataset_path}/test_data.pkl", 'wb') as f:
  97. pickle.dump(test_data, f)
  98. print("Datasets prepared and saved successfully.")
  99. if __name__ == '__main__':
  100. # This block allows the script to be run directly to perform data preprocessing.
  101. preprocessor = QlibDataPreprocessor()
  102. preprocessor.initialize_qlib()
  103. preprocessor.load_qlib_data()
  104. preprocessor.prepare_dataset()