import pickle
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from config import Config


class QlibDataset(Dataset):
    """
    A PyTorch Dataset for handling Qlib financial time series data.

    This dataset pre-computes all possible start indices for sliding windows
    and then randomly samples from them during training/validation.

    Args:
        data_type (str): The type of dataset to load, either 'train' or 'val'.

    Raises:
        ValueError: If `data_type` is not 'train' or 'val'.
    """

    def __init__(self, data_type: str = 'train'):
        self.config = Config()
        if data_type not in ['train', 'val']:
            raise ValueError("data_type must be 'train' or 'val'")
        self.data_type = data_type

        # Use a dedicated random number generator for sampling to avoid
        # interfering with other random processes (e.g., in model initialization).
        self.py_rng = random.Random(self.config.seed)

        # Set paths and number of samples based on the data type.
        if data_type == 'train':
            self.data_path = f"{self.config.dataset_path}/train_data.pkl"
            self.n_samples = self.config.n_train_iter
        else:
            self.data_path = f"{self.config.dataset_path}/val_data.pkl"
            self.n_samples = self.config.n_val_iter

        with open(self.data_path, 'rb') as f:
            self.data = pickle.load(f)

        self.window = self.config.lookback_window + self.config.predict_window + 1
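        # The extra "+ 1" row beyond lookback + predict presumably leaves room
        # for a one-step-shifted target in the consuming training code; treat
        # that reading as an assumption, not something this file guarantees.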
        self.symbols = list(self.data.keys())
        self.feature_list = self.config.feature_list
        self.time_feature_list = self.config.time_feature_list

        # Pre-compute all possible (symbol, start_index) pairs.
        self.indices = []
        print(f"[{data_type.upper()}] Pre-computing sample indices...")
        for symbol in self.symbols:
            df = self.data[symbol].reset_index()
            series_len = len(df)
            num_samples = series_len - self.window + 1
            if num_samples > 0:
                # Generate time features and store them directly in the dataframe.
                df['minute'] = df['datetime'].dt.minute
                df['hour'] = df['datetime'].dt.hour
                df['weekday'] = df['datetime'].dt.weekday
                df['day'] = df['datetime'].dt.day
                df['month'] = df['datetime'].dt.month
                # Keep only necessary columns to save memory.
                self.data[symbol] = df[self.feature_list + self.time_feature_list]
                # Add all valid starting indices for this symbol to the global list.
                for i in range(num_samples):
                    self.indices.append((symbol, i))

        # The effective dataset size is the minimum of the configured iterations
        # and the total number of available samples.
        self.n_samples = min(self.n_samples, len(self.indices))
        print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.")

    def set_epoch_seed(self, epoch: int):
        """
        Sets a new seed for the random sampler for each epoch. This is crucial
        for reproducibility in distributed training.

        Args:
            epoch (int): The current epoch number.
        """
        epoch_seed = self.config.seed + epoch
        self.py_rng.seed(epoch_seed)

    def __len__(self) -> int:
        """Returns the number of samples per epoch."""
        return self.n_samples

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves a random sample from the dataset.

        Note: The `idx` argument is ignored. Instead, a random index is drawn
        from the pre-computed `self.indices` list using `self.py_rng`. This
        ensures random sampling over the entire dataset for each call.

        Args:
            idx (int): Ignored.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - x_tensor (torch.Tensor): The normalized feature tensor.
                - x_stamp_tensor (torch.Tensor): The time feature tensor.
        """
        # Select a random sample from the entire pool of indices.
        random_idx = self.py_rng.randint(0, len(self.indices) - 1)
        symbol, start_idx = self.indices[random_idx]

        # Extract the sliding window from the dataframe.
        df = self.data[symbol]
        end_idx = start_idx + self.window
        win_df = df.iloc[start_idx:end_idx]

        # Separate main features and time features.
        x = win_df[self.feature_list].values.astype(np.float32)
        x_stamp = win_df[self.time_feature_list].values.astype(np.float32)

        # Perform instance-level normalization.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.config.clip, self.config.clip)
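        # Note: the mean/std above are computed per feature over the entire
        # window (lookback and predict rows alike), so normalization is
        # instance-level rather than based on global training statistics.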

        # Convert to PyTorch tensors.
        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)

        return x_tensor, x_stamp_tensor


if __name__ == '__main__':
    # Example usage and verification.
    print("Creating training dataset instance...")
    train_dataset = QlibDataset(data_type='train')
    print(f"Dataset length: {len(train_dataset)}")

    if len(train_dataset) > 0:
        try_x, try_x_stamp = train_dataset[100]  # Index 100 is ignored.
        print(f"Sample feature shape: {try_x.shape}")
        print(f"Sample time feature shape: {try_x_stamp.shape}")
    else:
        print("Dataset is empty.")
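
    # --- Optional: minimal DataLoader sketch (illustrative only) ---
    # The batch size below is an assumption, not a value taken from Config, and
    # shuffle is left off because __getitem__ already draws a random window on
    # every call. Adjust to match your own training loop.
    if len(train_dataset) > 0:
        from torch.utils.data import DataLoader

        loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
        for epoch in range(2):
            # Re-seed the per-epoch sampler so draws are reproducible.
            train_dataset.set_epoch_seed(epoch)
            batch_x, batch_x_stamp = next(iter(loader))
            print(f"Epoch {epoch}: batch features {tuple(batch_x.shape)}, "
                  f"time features {tuple(batch_x_stamp.shape)}")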