From deef3bc819ffe18f022d81cafd0e18e955735128 Mon Sep 17 00:00:00 2001 From: Timo Schneider Date: Mon, 12 May 2025 11:06:57 +0200 Subject: [PATCH] optimized m_estimator even more --- src/pipeline/m_estimator.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/pipeline/m_estimator.rs b/src/pipeline/m_estimator.rs index 08f8e82..efabb34 100644 --- a/src/pipeline/m_estimator.rs +++ b/src/pipeline/m_estimator.rs @@ -35,15 +35,16 @@ pub fn create_mestimator_thread(lanes_rx: Receiver) -> (Recei let H_t = H.t(); for _ in 0..3 { - let res = lanes - .iter() - .map(|(point, label)| point[0] - (if *label == 0 {0.5} else {-0.5} * z[0] - z[1] - point[1] * z[2] + 0.5 * z[3] * point[1].powi(2))) - .map(|r| 1.0/(1.0 + (r/c).powi(2))) - .collect::>(); + let mut w = Array2::zeros((lanes.len(), lanes.len())); - let w = Array2::from_diag(&Array1::from_vec(res)); + for (i, (point, label)) in lanes.iter().enumerate() { + let r = point[0] - (if *label == 0 { 0.5 } else { -0.5 } * z[0] - z[1] - point[1] * z[2] + 0.5 * z[3] * point[1].powi(2)); + let diag_value = 1.0 / (1.0 + (r / c).powi(2)); - z = H_t.dot(&w).dot(&H).dot(&H_t).inv().unwrap().dot(&w).dot(&y).to_vec(); + w[[i, i]] = diag_value; + } + + z = H_t.dot(&w).dot(&H).inv().unwrap().dot(&H_t).dot(&w).dot(&y).to_vec(); } let end = Instant::now();