[Rust Development] Using rand::distributions::WeightedIndex in a no_std Project

Background

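The Rust project I am working on has to run in a no_std environment, so many library functions are unavailable. This post records the changes I made to work around such problems during development, so that the whole project builds and runs under no_std.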

About the rand crate

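rand is the key crate for random number generation and related functionality, but many of its APIs cannot be used under no_std (see the description on the official site).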

Implementing rand::distributions::WeightedIndex under no_std

Using rand::distributions::WeightedIndex in a no_std environment fails with the following error:

unresolved import `rand::distributions::WeightedIndex`
no `WeightedIndex` in `distributions`

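In rand 0.8, `WeightedIndex` is only compiled when the crate's `alloc` feature is enabled, which is a likely cause of this error. I first searched for a Rust crate that provides a no_std-compatible implementation of this distribution, without success.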
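I therefore reimplemented it by hand, following the source of rand::distributions::weighted_index, available at: https://docs.rs/rand/0.8.3/src/rand/distributions/weighted_index.rs.html#81-85
The approach is remarkably simple: copy everything except the test module to just before the code that uses WeightedIndex, and comment out the original import: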

use rand::distributions::WeightedIndex;

After that, WeightedIndex can be used as usual.
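To make it clearer what the copied code computes, here is a simplified sketch of the construction step in `WeightedIndex::new`. It assumes plain `u32` weights instead of the generic `X: SampleUniform` bound, and the function name `cumulative_weights` is hypothetical. Like the original, it stores the running total *before* adding each weight, so the resulting vector is one element shorter than the input and the last weight only shows up in the total; empty input and an all-zero total are rejected. Unsigned weights cannot be negative, so that check disappears, and `checked_add` is used here (unlike the original) to catch overflow.

```rust
/// Hypothetical helper mirroring the construction in `WeightedIndex::new`,
/// specialised to `u32` weights (the original is generic over `X`).
/// Returns (cumulative_weights, total_weight).
fn cumulative_weights(weights: &[u32]) -> Result<(Vec<u32>, u32), &'static str> {
    let mut iter = weights.iter();
    // The first weight seeds the running total; an empty slice is an error.
    let mut total: u32 = *iter.next().ok_or("no item")?;
    let mut cumulative = Vec::with_capacity(weights.len().saturating_sub(1));
    for &w in iter {
        // Push the total *before* adding w, exactly like the original,
        // so the vector ends up one element shorter than `weights`.
        cumulative.push(total);
        total = total.checked_add(w).ok_or("invalid weight (overflow)")?;
    }
    if total == 0 {
        return Err("all weights zero");
    }
    Ok((cumulative, total))
}

fn main() {
    let (cum, total) = cumulative_weights(&[2, 0, 3, 5]).unwrap();
    println!("{:?} {}", cum, total); // prints "[2, 2, 5] 10"
    assert!(cumulative_weights(&[]).is_err());
    assert!(cumulative_weights(&[0, 0]).is_err());
}
```

Storing only the first n-1 partial sums is enough because any draw at or above the last stored sum must belong to the final item.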

Implementation source code

The source code is as follows:

        // Hand-written copy of rand::distributions::WeightedIndex for no_std
        // Based on: https://docs.rs/rand/0.8.3/src/rand/distributions/weighted_index.rs.html#81-85
        use rand::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
        use rand::distributions::Distribution;
        use rand::Rng;
        use core::fmt;
        // In a no_std crate Vec is not in the prelude: it comes from `alloc`,
        // which requires `extern crate alloc;` at the crate root.
        use alloc::vec::Vec;
        #[cfg(feature = "serde1")]
        use serde::{Serialize, Deserialize};
        #[derive(Debug, Clone)]
        #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
        #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
        pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
            cumulative_weights: Vec<X>,
            total_weight: X,
            weight_distribution: X::Sampler,
        }

        impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
            /// Creates a new `WeightedIndex` [`Distribution`] using the values
            /// in `weights`. The weights can use any type `X` for which an
            /// implementation of [`Uniform<X>`] exists.
            ///
            /// Returns an error if the iterator is empty, if any weight is `< 0`, or
            /// if its total value is 0.
            ///
            /// [`Uniform<X>`]: rand::distributions::uniform::Uniform
            pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
            where
                I: IntoIterator,
                I::Item: SampleBorrow<X>,
                X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
            {
                let mut iter = weights.into_iter();
                let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();

                let zero = <X as Default>::default();
                if !(total_weight >= zero) {
                    return Err(WeightedError::InvalidWeight);
                }

                let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
                for w in iter {
                    // Note that `!(w >= x)` is not equivalent to `w < x` for partially
                    // ordered types due to NaNs which are equal to nothing.
                    if !(w.borrow() >= &zero) {
                        return Err(WeightedError::InvalidWeight);
                    }
                    weights.push(total_weight.clone());
                    total_weight += w.borrow();
                }

                if total_weight == zero {
                    return Err(WeightedError::AllWeightsZero);
                }
                let distr = X::Sampler::new(zero, total_weight.clone());

                Ok(WeightedIndex {
                    cumulative_weights: weights,
                    total_weight,
                    weight_distribution: distr,
                })
            }

            /// Update a subset of weights, without changing the number of weights.
            ///
            /// `new_weights` must be sorted by the index.
            ///
            /// Using this method instead of `new` might be more efficient if only a small number of
            /// weights is modified. No allocations are performed, unless the weight type `X` uses
            /// allocation internally.
            ///
            /// In case of error, `self` is not modified.
            pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
            where X: for<'a> ::core::ops::AddAssign<&'a X>
                    + for<'a> ::core::ops::SubAssign<&'a X>
                    + Clone
                    + Default {
                if new_weights.is_empty() {
                    return Ok(());
                }

                let zero = <X as Default>::default();

                let mut total_weight = self.total_weight.clone();

                // Check for errors first, so we don't modify `self` in case something
                // goes wrong.
                let mut prev_i = None;
                for &(i, w) in new_weights {
                    if let Some(old_i) = prev_i {
                        if old_i >= i {
                            return Err(WeightedError::InvalidWeight);
                        }
                    }
                    if !(*w >= zero) {
                        return Err(WeightedError::InvalidWeight);
                    }
                    if i > self.cumulative_weights.len() {
                        return Err(WeightedError::TooMany);
                    }

                    let mut old_w = if i < self.cumulative_weights.len() {
                        self.cumulative_weights[i].clone()
                    } else {
                        self.total_weight.clone()
                    };
                    if i > 0 {
                        old_w -= &self.cumulative_weights[i - 1];
                    }

                    total_weight -= &old_w;
                    total_weight += w;
                    prev_i = Some(i);
                }
                if total_weight <= zero {
                    return Err(WeightedError::AllWeightsZero);
                }

                // Update the weights. Because we checked all the preconditions in the
                // previous loop, this should never panic.
                let mut iter = new_weights.iter();

                let mut prev_weight = zero.clone();
                let mut next_new_weight = iter.next();
                let &(first_new_index, _) = next_new_weight.unwrap();
                let mut cumulative_weight = if first_new_index > 0 {
                    self.cumulative_weights[first_new_index - 1].clone()
                } else {
                    zero.clone()
                };
                for i in first_new_index..self.cumulative_weights.len() {
                    match next_new_weight {
                        Some(&(j, w)) if i == j => {
                            cumulative_weight += w;
                            next_new_weight = iter.next();
                        }
                        _ => {
                            let mut tmp = self.cumulative_weights[i].clone();
                            tmp -= &prev_weight; // We know this is positive.
                            cumulative_weight += &tmp;
                        }
                    }
                    prev_weight = cumulative_weight.clone();
                    core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
                }

                self.total_weight = total_weight;
                self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());

                Ok(())
            }
        }

        impl<X> Distribution<usize> for WeightedIndex<X>
        where X: SampleUniform + PartialOrd
        {
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
                use ::core::cmp::Ordering;
                let chosen_weight = self.weight_distribution.sample(rng);
                // Find the first item which has a weight *higher* than the chosen weight.
                self.cumulative_weights
                    .binary_search_by(|w| {
                        if *w <= chosen_weight {
                            Ordering::Less
                        } else {
                            Ordering::Greater
                        }
                    })
                    .unwrap_err()
            }
        }
        /// Error type returned from `WeightedIndex::new`.
        #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub enum WeightedError {
            /// The provided weight collection contains no items.
            NoItem,

            /// A weight is either less than zero, greater than the supported maximum,
            /// NaN, or otherwise invalid.
            InvalidWeight,

            /// All items in the provided weight collection are zero.
            AllWeightsZero,

            /// Too many weights are provided (length greater than `u32::MAX`)
            TooMany,
        }

        #[cfg(feature = "std")]
        impl ::std::error::Error for WeightedError {}

        impl fmt::Display for WeightedError {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                match *self {
                    WeightedError::NoItem => write!(f, "No weights provided."),
                    WeightedError::InvalidWeight => write!(f, "A weight is invalid."),
                    WeightedError::AllWeightsZero => write!(f, "All weights are zero."),
                    WeightedError::TooMany => write!(f, "Too many weights (hit u32::MAX)"),
                }
            }
        }
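The `sample` method above relies on a small trick: its comparator never returns `Ordering::Equal`, so `binary_search_by` can never report a match and always returns `Err(i)`, where `i` is the index of the first cumulative weight strictly greater than the drawn value. The sketch below isolates that lookup with `u32` weights and passes the draw in explicitly in place of `self.weight_distribution.sample(rng)`; the name `pick_index` is hypothetical.

```rust
use core::cmp::Ordering;

/// Hypothetical helper: the lookup step of `WeightedIndex::sample`, with the
/// random draw passed in explicitly. `chosen` must lie in `0..total_weight`.
fn pick_index(cumulative_weights: &[u32], chosen: u32) -> usize {
    cumulative_weights
        .binary_search_by(|w| {
            // Never returns Equal, so the search always "fails" and reports
            // the position of the first element greater than `chosen`.
            if *w <= chosen {
                Ordering::Less
            } else {
                Ordering::Greater
            }
        })
        .unwrap_err()
}

fn main() {
    // Cumulative weights for [2, 0, 3, 5] as built by `new` (total = 10).
    let cumulative = [2u32, 2, 5];
    let total = 10u32;
    // Try every possible draw in 0..total once and count the hits per index.
    let mut hits = [0u32; 4];
    for chosen in 0..total {
        hits[pick_index(&cumulative, chosen)] += 1;
    }
    // Each index is hit exactly as often as its weight; the zero-weight one never.
    println!("{:?}", hits); // prints "[2, 0, 3, 5]"
}
```

Because a zero weight produces two equal adjacent cumulative values, no draw can land strictly between them, which is why zero-weight items are never selected.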
