增强回测demo输出与分区加载

This commit is contained in:
zsb
2026-04-07 21:25:41 -07:00
parent ec425999b0
commit a26049ff15
9 changed files with 211 additions and 63 deletions

View File

@@ -87,6 +87,8 @@ pub struct DailyFactorSnapshot {
pub market_cap_bn: f64,
pub free_float_cap_bn: f64,
pub pe_ttm: f64,
pub turnover_ratio: Option<f64>,
pub effective_turnover_ratio: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -303,6 +305,14 @@ impl DataSet {
.collect()
}
pub fn market_closes_up_to(&self, date: NaiveDate, symbol: &str, lookback: usize) -> Vec<f64> {
self.calendar
.trailing_days(date, lookback)
.into_iter()
.filter_map(|day| self.market(day, symbol).map(|row| row.close))
.collect()
}
pub fn require_market(
&self,
date: NaiveDate,
@@ -347,6 +357,8 @@ fn read_market(path: &Path) -> Result<Vec<DailyMarketSnapshot>, DataSetError> {
let mut snapshots = Vec::new();
for row in rows {
let prev_close = row.parse_f64(6)?;
let derived_upper_limit = round2(prev_close * 1.10);
let derived_lower_limit = round2(prev_close * 0.90);
snapshots.push(DailyMarketSnapshot {
date: row.parse_date(0)?,
symbol: row.get(1)?.to_string(),
@@ -357,8 +369,8 @@ fn read_market(path: &Path) -> Result<Vec<DailyMarketSnapshot>, DataSetError> {
prev_close,
volume: row.parse_u64(7)?,
paused: row.parse_bool(8)?,
upper_limit: round2(prev_close * 1.10),
lower_limit: round2(prev_close * 0.90),
upper_limit: row.parse_optional_f64(9).unwrap_or(derived_upper_limit),
lower_limit: row.parse_optional_f64(10).unwrap_or(derived_lower_limit),
});
}
Ok(snapshots)
@@ -374,6 +386,8 @@ fn read_factors(path: &Path) -> Result<Vec<DailyFactorSnapshot>, DataSetError> {
market_cap_bn: row.parse_f64(2)?,
free_float_cap_bn: row.parse_f64(3)?,
pe_ttm: row.parse_f64(4)?,
turnover_ratio: row.parse_optional_f64(5),
effective_turnover_ratio: row.parse_optional_f64(6),
});
}
Ok(snapshots)
@@ -457,6 +471,17 @@ impl CsvRow {
})
}
fn parse_optional_f64(&self, index: usize) -> Option<f64> {
self.fields.get(index).and_then(|value| {
let trimmed = value.trim();
if trimmed.is_empty() {
None
} else {
trimmed.parse::<f64>().ok()
}
})
}
fn parse_bool(&self, index: usize) -> Result<bool, DataSetError> {
self.get(index)?
.parse::<bool>()
@@ -478,26 +503,35 @@ fn read_partitioned_dir<T, F>(dir: &Path, mut loader: F) -> Result<Vec<T>, DataS
where
F: FnMut(&Path) -> Result<Vec<T>, DataSetError>,
{
let mut files = fs::read_dir(dir)
.map_err(|source| DataSetError::Io {
path: dir.display().to_string(),
source,
})?
.collect::<Result<Vec<_>, _>>()
.map_err(|source| DataSetError::Io {
path: dir.display().to_string(),
source,
})?;
files.sort_by_key(|entry| entry.path());
let mut rows = Vec::new();
for entry in files {
let path = entry.path();
if path.extension().and_then(|x| x.to_str()) != Some("csv") {
continue;
let mut stack = vec![dir.to_path_buf()];
while let Some(current_dir) = stack.pop() {
let mut entries = fs::read_dir(&current_dir)
.map_err(|source| DataSetError::Io {
path: current_dir.display().to_string(),
source,
})?
.collect::<Result<Vec<_>, _>>()
.map_err(|source| DataSetError::Io {
path: current_dir.display().to_string(),
source,
})?;
entries.sort_by_key(|entry| entry.path());
for entry in entries.into_iter().rev() {
let path = entry.path();
if path.is_dir() {
stack.push(path);
continue;
}
if path.extension().and_then(|x| x.to_str()) != Some("csv") {
continue;
}
rows.extend(loader(&path)?);
}
rows.extend(loader(&path)?);
}
Ok(rows)
}

View File

@@ -43,6 +43,7 @@ pub struct CnSmallCapRotationConfig {
pub trade_rate: f64,
pub stop_loss_pct: f64,
pub take_profit_pct: f64,
pub signal_symbol: Option<String>,
}
impl CnSmallCapRotationConfig {
@@ -60,6 +61,7 @@ impl CnSmallCapRotationConfig {
trade_rate: 0.5,
stop_loss_pct: 0.08,
take_profit_pct: 0.10,
signal_symbol: None,
}
}
}
@@ -157,10 +159,20 @@ impl Strategy for CnSmallCapRotationStrategy {
.ok_or(BacktestError::MissingBenchmark {
date: ctx.decision_date,
})?;
let benchmark_closes = ctx
.data
.benchmark_closes_up_to(ctx.decision_date, self.config.long_ma_days);
let gross_exposure = self.gross_exposure(&benchmark_closes);
let signal_symbol = self.config.signal_symbol.as_deref();
let signal_closes = if let Some(symbol) = signal_symbol {
ctx.data.market_closes_up_to(ctx.decision_date, symbol, self.config.long_ma_days)
} else {
ctx.data.benchmark_closes_up_to(ctx.decision_date, self.config.long_ma_days)
};
let signal_level = if let Some(symbol) = signal_symbol {
ctx.data
.price(ctx.decision_date, symbol, PriceField::Close)
.unwrap_or(benchmark.close)
} else {
benchmark.close
};
let gross_exposure = self.gross_exposure(&signal_closes);
let periodic_rebalance = ctx.decision_index % self.config.refresh_rate == 0;
let exposure_changed = self
.last_gross_exposure
@@ -175,8 +187,10 @@ impl Strategy for CnSmallCapRotationStrategy {
ctx.decision_date, ctx.execution_date, gross_exposure
)];
let mut diagnostics = vec![format!(
"benchmark_close={:.2} refresh_rate={} stocknum={} short_ma_days={} long_ma_days={}",
"benchmark_close={:.2} signal_level={:.2} signal_symbol={} refresh_rate={} stocknum={} short_ma_days={} long_ma_days={}",
benchmark.close,
signal_level,
signal_symbol.unwrap_or(benchmark.benchmark.as_str()),
self.config.refresh_rate,
self.config.stocknum,
self.config.short_ma_days,
@@ -187,6 +201,7 @@ impl Strategy for CnSmallCapRotationStrategy {
let selected = self.selector.select(&SelectionContext {
decision_date: ctx.decision_date,
benchmark,
reference_level: signal_level,
data: ctx.data,
});

View File

@@ -21,6 +21,7 @@ pub struct UniverseCandidate {
pub struct SelectionContext<'a> {
pub decision_date: NaiveDate,
pub benchmark: &'a BenchmarkSnapshot,
pub reference_level: f64,
pub data: &'a DataSet,
}
@@ -77,8 +78,8 @@ impl DynamicMarketCapBandSelector {
impl UniverseSelector for DynamicMarketCapBandSelector {
fn select(&self, ctx: &SelectionContext<'_>) -> Vec<UniverseCandidate> {
let _regime = self.regime(ctx.benchmark.close);
let (min_cap, max_cap) = self.band_for_level(ctx.benchmark.close);
let _regime = self.regime(ctx.reference_level);
let (min_cap, max_cap) = self.band_for_level(ctx.reference_level);
let mut selected = ctx
.data

View File

@@ -16,10 +16,10 @@ fn temp_dir() -> PathBuf {
#[test]
fn can_load_partitioned_snapshot_dir() {
let dir = temp_dir();
fs::create_dir_all(dir.join("benchmark")).unwrap();
fs::create_dir_all(dir.join("market")).unwrap();
fs::create_dir_all(dir.join("factors")).unwrap();
fs::create_dir_all(dir.join("candidates")).unwrap();
fs::create_dir_all(dir.join("benchmark/2024/01")).unwrap();
fs::create_dir_all(dir.join("market/2024/01")).unwrap();
fs::create_dir_all(dir.join("factors/2024/01")).unwrap();
fs::create_dir_all(dir.join("candidates/2024/01")).unwrap();
fs::write(
dir.join("instruments.csv"),
@@ -27,22 +27,22 @@ fn can_load_partitioned_snapshot_dir() {
)
.unwrap();
fs::write(
dir.join("benchmark/2024-01-02.csv"),
dir.join("benchmark/2024/01/2024-01-02.csv"),
"date,benchmark,open,close,prev_close,volume\n2024-01-02,CSI300.DEMO,2990,3000,2980,100000000\n",
)
.unwrap();
fs::write(
dir.join("market/2024-01-02.csv"),
dir.join("market/2024/01/2024-01-02.csv"),
"date,symbol,open,high,low,close,prev_close,volume,paused,upper_limit,lower_limit\n2024-01-02,000001.SZ,10,10.5,9.9,10.2,10,100000,false,11,9\n",
)
.unwrap();
fs::write(
dir.join("factors/2024-01-02.csv"),
"date,symbol,market_cap_bn,free_float_cap_bn,pe_ttm\n2024-01-02,000001.SZ,40,35,12\n",
dir.join("factors/2024/01/2024-01-02.csv"),
"date,symbol,market_cap_bn,free_float_cap_bn,pe_ttm,turnover_ratio,effective_turnover_ratio\n2024-01-02,000001.SZ,40,35,12,3.2,2.1\n",
)
.unwrap();
fs::write(
dir.join("candidates/2024-01-02.csv"),
dir.join("candidates/2024/01/2024-01-02.csv"),
"date,symbol,is_st,is_new_listing,is_paused,allow_buy,allow_sell,is_kcb,is_one_yuan\n2024-01-02,000001.SZ,false,false,false,true,true,false,false\n",
)
.unwrap();

View File

@@ -31,4 +31,8 @@ fn strategy_emits_target_weights_and_diagnostics() {
.diagnostics
.iter()
.any(|line| line.contains("selected=")));
assert!(decision
.diagnostics
.iter()
.any(|line| line.contains("signal_symbol=")));
}