diff --git a/crates/fidc-core/src/data.rs b/crates/fidc-core/src/data.rs index f31911d..58d9370 100644 --- a/crates/fidc-core/src/data.rs +++ b/crates/fidc-core/src/data.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::fs; use std::path::Path; +use std::sync::Arc; use chrono::{NaiveDate, NaiveDateTime}; use serde::{Deserialize, Serialize}; @@ -955,12 +956,12 @@ pub struct DataSet { calendar: TradingCalendar, market_by_date: BTreeMap>, market_index: HashMap<(NaiveDate, String), DailyMarketSnapshot>, - factor_by_date: BTreeMap>, - factor_index: HashMap<(NaiveDate, String), DailyFactorSnapshot>, + factor_by_date: BTreeMap>>, + factor_index: HashMap<(NaiveDate, String), Arc>, factor_text_by_date: BTreeMap>, factor_text_index: HashMap<(NaiveDate, String, String), FactorTextValue>, - candidate_by_date: BTreeMap>, - candidate_index: HashMap<(NaiveDate, String), CandidateEligibility>, + candidate_by_date: BTreeMap>>, + candidate_index: HashMap<(NaiveDate, String), Arc>, corporate_actions_by_date: BTreeMap>, execution_quotes_by_date: HashMap>>, order_book_depth_index: HashMap<(NaiveDate, String), Vec>, @@ -1205,7 +1206,11 @@ impl DataSet { ) -> Result { let benchmark_code = collect_benchmark_code(&benchmarks)?; let calendar = TradingCalendar::new(benchmarks.iter().map(|item| item.date).collect()); - let factors = normalize_factor_snapshots(factors); + let factors = normalize_factor_snapshots(factors) + .into_iter() + .map(Arc::new) + .collect::>(); + let candidates = candidates.into_iter().map(Arc::new).collect::>(); let instruments = instruments .into_iter() @@ -1218,7 +1223,7 @@ impl DataSet { .map(|item| ((item.date, item.symbol.clone()), item)) .collect::>(); - let factor_by_date = group_by_date(factors.clone(), |item| item.date); + let factor_by_date = group_arc_by_date(&factors, |item| item.date); let factor_index = factors .into_iter() .map(|item| ((item.date, item.symbol.clone()), item)) @@ -1240,7 +1245,7 @@ impl DataSet { .map(|item| ((item.date, item.symbol.clone(), item.field.clone()), item)) .collect::>(); - let candidate_by_date = group_by_date(candidates.clone(), |item| item.date); + let candidate_by_date = group_arc_by_date(&candidates, |item| item.date); let candidate_index = candidates .into_iter() .map(|item| ((item.date, item.symbol.clone()), item)) @@ -1329,11 +1334,15 @@ impl DataSet { } pub fn factor(&self, date: NaiveDate, symbol: &str) -> Option<&DailyFactorSnapshot> { - self.factor_index.get(&(date, symbol.to_string())) + self.factor_index + .get(&(date, symbol.to_string())) + .map(Arc::as_ref) } pub fn candidate(&self, date: NaiveDate, symbol: &str) -> Option<&CandidateEligibility> { - self.candidate_index.get(&(date, symbol.to_string())) + self.candidate_index + .get(&(date, symbol.to_string())) + .map(Arc::as_ref) } pub fn benchmark(&self, date: NaiveDate) -> Option<&BenchmarkSnapshot> { @@ -2089,7 +2098,7 @@ impl DataSet { pub fn factor_snapshots_on(&self, date: NaiveDate) -> Vec<&DailyFactorSnapshot> { self.factor_by_date .get(&date) - .map(|rows| rows.iter().collect()) + .map(|rows| rows.iter().map(Arc::as_ref).collect()) .unwrap_or_default() } @@ -2110,7 +2119,7 @@ impl DataSet { pub fn candidate_snapshots_on(&self, date: NaiveDate) -> Vec<&CandidateEligibility> { self.candidate_by_date .get(&date) - .map(|rows| rows.iter().collect()) + .map(|rows| rows.iter().map(Arc::as_ref).collect()) .unwrap_or_default() } @@ -2123,11 +2132,15 @@ impl DataSet { date, benchmark, market: self.market_by_date.get(&date).cloned().unwrap_or_default(), - factors: self.factor_by_date.get(&date).cloned().unwrap_or_default(), + factors: self + .factor_by_date + .get(&date) + .map(|rows| rows.iter().map(|row| row.as_ref().clone()).collect()) + .unwrap_or_default(), candidates: self .candidate_by_date .get(&date) - .cloned() + .map(|rows| rows.iter().map(|row| row.as_ref().clone()).collect()) .unwrap_or_default(), corporate_actions: self .corporate_actions_by_date @@ -2200,7 +2213,9 @@ impl DataSet { .iter() .map(|day| { evaluator( - self.candidate_index.get(&(*day, symbol.to_string())), + self.candidate_index + .get(&(*day, symbol.to_string())) + .map(Arc::as_ref), self.market_index.get(&(*day, symbol.to_string())), ) }) @@ -3375,6 +3390,20 @@ where grouped } +fn group_arc_by_date(rows: &[Arc], mut date_of: F) -> BTreeMap>> +where + F: FnMut(&T) -> NaiveDate, +{ + let mut grouped = BTreeMap::>>::new(); + for row in rows { + grouped + .entry(date_of(row.as_ref())) + .or_default() + .push(Arc::clone(row)); + } + grouped +} + fn collect_benchmark_code(benchmarks: &[BenchmarkSnapshot]) -> Result { let mut codes = benchmarks .iter() @@ -3523,8 +3552,8 @@ fn build_order_book_depth_index( } fn build_eligible_universe( - factor_by_date: &BTreeMap>, - candidate_index: &HashMap<(NaiveDate, String), CandidateEligibility>, + factor_by_date: &BTreeMap>>, + candidate_index: &HashMap<(NaiveDate, String), Arc>, market_index: &HashMap<(NaiveDate, String), DailyMarketSnapshot>, instruments: &HashMap, ) -> BTreeMap> {