背景
正在开发的 Rust 项目要求在 no_std 环境下运行,因此经常会遇到依赖库中的函数无法使用的情况。本文记录开发过程中针对这类问题所做的修改,使整个项目能够在 no_std 环境中运行。
rand库简介
rand 库是实现随机数生成等功能的重要库,但其中很多功能不支持在 no_std 环境下使用(部分功能依赖 std 或 alloc 特性),详见 rand 官方文档中关于 crate features 的描述。
在no_std环境下实现rand::distributions::WeightedIndex
在 no_std 环境中(未启用 rand 的 alloc 特性时),编译使用了 rand::distributions::WeightedIndex 的代码会报如下错误:
unresolved import `rand::distributions::WeightedIndex`
no `WeightedIndex` in `distributions`
首先,尝试寻找支持 no_std 并且实现了该分布(distribution)的 Rust 库,但没有找到。
因此参照rand::distributions::weighted_index的源码手动实现,源码链接如下:https://docs.rs/rand/0.8.3/src/rand/distributions/weighted_index.rs.html#81-85
实现方法非常简单:将源码中除 test 部分以外的内容复制到需要使用 WeightedIndex 的语句之前,并注释掉原来的引用语句:
use rand::distributions::WeightedIndex;
即可正常使用。
实现源代码
源代码如下:
// Local re-implementation of rand::distributions::weighted::WeightedIndex for no_std builds.
// Adapted from: https://docs.rs/rand/0.8.3/src/rand/distributions/weighted_index.rs.html#81-85
use rand::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
use rand::distributions::Distribution;
use rand::Rng;
// use core::cmp::PartialOrd;
use core::fmt;
// use alloc::vec::Vec;
#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
/// A distribution that yields `usize` indices chosen according to a list of
/// weights.
///
/// Sampling draws a value from `weight_distribution` and looks up where that
/// value falls among the stored cumulative weights.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub struct WeightedIndex<X>
where
    X: SampleUniform + PartialOrd,
{
    /// Running sums of the weights. The last running sum (the grand total) is
    /// not stored here; it lives in `total_weight`, so this vector holds one
    /// entry fewer than the number of weights.
    cumulative_weights: Vec<X>,
    /// Sum of all weights.
    total_weight: X,
    /// Sampler constructed from `X::default()` and `total_weight`.
    weight_distribution: X::Sampler,
}
impl<X> WeightedIndex<X>
where
    X: SampleUniform + PartialOrd,
{
    /// Builds a `WeightedIndex` from the values yielded by `weights`.
    ///
    /// `X::default()` is used as the zero value that every weight is checked
    /// against. Any type `X` with a `SampleUniform` implementation can be
    /// used as the weight type.
    ///
    /// # Errors
    ///
    /// * `WeightedError::NoItem` if `weights` yields nothing.
    /// * `WeightedError::InvalidWeight` if any weight fails the `>= zero`
    ///   test. Values that are not comparable (for example NaN) fail it too.
    /// * `WeightedError::AllWeightsZero` if the sum of the weights equals zero.
    pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
    where
        I: IntoIterator,
        I::Item: SampleBorrow<X>,
        X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
    {
        let mut remaining = weights.into_iter();

        // The first weight seeds the running total; it is never pushed into
        // the cumulative vector on its own.
        let mut running_total: X = match remaining.next() {
            Some(first) => first.borrow().clone(),
            None => return Err(WeightedError::NoItem),
        };

        let zero = X::default();
        // `!(a >= b)` is deliberately used instead of `a < b`: for partially
        // ordered types the two differ when the values are not comparable.
        if !(running_total >= zero) {
            return Err(WeightedError::InvalidWeight);
        }

        let mut prefix_sums: Vec<X> = Vec::with_capacity(remaining.size_hint().0);
        for item in remaining {
            let weight = item.borrow();
            if !(weight >= &zero) {
                return Err(WeightedError::InvalidWeight);
            }
            // Record the total *before* adding this weight, so the grand
            // total itself never ends up in `prefix_sums`.
            prefix_sums.push(running_total.clone());
            running_total += weight;
        }

        if running_total == zero {
            return Err(WeightedError::AllWeightsZero);
        }

        let weight_distribution = X::Sampler::new(zero, running_total.clone());
        Ok(WeightedIndex {
            cumulative_weights: prefix_sums,
            total_weight: running_total,
            weight_distribution,
        })
    }

    /// Replaces a subset of the weights while keeping the number of weights
    /// unchanged.
    ///
    /// `new_weights` is a list of `(index, weight)` pairs and must be sorted
    /// by strictly increasing index. An empty list is accepted and does
    /// nothing.
    ///
    /// This can be cheaper than calling `new` again when only a few weights
    /// change. The method itself allocates no collection; clones of `X` are
    /// made, so a weight type that allocates internally will still allocate.
    ///
    /// All validation happens before any field is written, so `self` is left
    /// untouched when an error is returned.
    ///
    /// # Errors
    ///
    /// * `WeightedError::InvalidWeight` if the indices are not strictly
    ///   increasing or a weight fails the `>= zero` test.
    /// * `WeightedError::TooMany` if an index is larger than the number of
    ///   stored cumulative weights (i.e. past the last weight).
    /// * `WeightedError::AllWeightsZero` if the resulting total would be
    ///   `<= zero`.
    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
    where
        X: for<'a> ::core::ops::AddAssign<&'a X>
            + for<'a> ::core::ops::SubAssign<&'a X>
            + Clone
            + Default,
    {
        // Nothing to update: succeed without touching anything.
        let first_index = match new_weights.first() {
            Some(&(index, _)) => index,
            None => return Ok(()),
        };

        let zero = X::default();
        let stored = self.cumulative_weights.len();

        // Pass 1: validate every update and compute the total that would
        // result, without modifying `self`.
        let mut new_total = self.total_weight.clone();
        let mut last_index: Option<usize> = None;
        for &(index, weight) in new_weights {
            let out_of_order = matches!(last_index, Some(seen) if seen >= index);
            if out_of_order || !(*weight >= zero) {
                return Err(WeightedError::InvalidWeight);
            }
            if index > stored {
                return Err(WeightedError::TooMany);
            }

            // Recover the weight currently at `index`: its cumulative value
            // (or the grand total for the last weight, which has no stored
            // entry) minus the cumulative value just before it.
            let mut replaced: X = self
                .cumulative_weights
                .get(index)
                .unwrap_or(&self.total_weight)
                .clone();
            if let Some(previous) = index.checked_sub(1) {
                replaced -= &self.cumulative_weights[previous];
            }

            new_total -= &replaced;
            new_total += weight;
            last_index = Some(index);
        }
        if new_total <= zero {
            return Err(WeightedError::AllWeightsZero);
        }

        // Pass 2: rewrite the cumulative weights from the first updated index
        // onwards. The preconditions were all checked above.
        let mut updates = new_weights.iter();
        let mut upcoming = updates.next();
        let mut running = match first_index.checked_sub(1) {
            Some(previous) => self.cumulative_weights[previous].clone(),
            None => zero.clone(),
        };
        // Holds the *old* cumulative value of the preceding slot, needed to
        // recover the individual weight of slots that are not being updated.
        let mut old_previous = zero.clone();
        for (offset, slot) in self.cumulative_weights[first_index..].iter_mut().enumerate() {
            let index = first_index + offset;
            match upcoming {
                Some(&(target, weight)) if target == index => {
                    running += weight;
                    upcoming = updates.next();
                }
                _ => {
                    // Old cumulative values are non-decreasing, so this
                    // difference is the slot's unchanged individual weight.
                    let mut unchanged: X = slot.clone();
                    unchanged -= &old_previous;
                    running += &unchanged;
                }
            }
            // Store the new cumulative value and keep the old one for the
            // next iteration.
            old_previous = core::mem::replace(slot, running.clone());
        }

        self.total_weight = new_total;
        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
        Ok(())
    }
}
impl<X: SampleUniform + PartialOrd> Distribution<usize> for WeightedIndex<X> {
    /// Draws a value from `weight_distribution` and returns the position of
    /// the first cumulative weight that is strictly greater than it. If no
    /// stored cumulative weight is greater, the returned position equals the
    /// length of the cumulative vector, i.e. the index of the last weight.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        use ::core::cmp::Ordering;

        let drawn = self.weight_distribution.sample(rng);

        // The comparator never reports `Equal`, so the search always ends in
        // `Err(insertion_point)`, which is exactly the index we want.
        let side_of_draw = |cumulative: &X| -> Ordering {
            if *cumulative <= drawn {
                Ordering::Less
            } else {
                Ordering::Greater
            }
        };

        self.cumulative_weights
            .binary_search_by(side_of_draw)
            .unwrap_err()
    }
}
/// Errors reported by `WeightedIndex::new` and
/// `WeightedIndex::update_weights`.
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WeightedError {
    /// No weights were supplied at all.
    NoItem,

    /// A weight is invalid: below zero, above the supported maximum, NaN, or
    /// otherwise unusable. `WeightedIndex::update_weights` also returns this
    /// when the supplied indices are not strictly increasing.
    InvalidWeight,

    /// Every supplied weight is zero.
    AllWeightsZero,

    /// Too many weights were supplied (described upstream as a length greater
    /// than `u32::MAX`). `WeightedIndex::update_weights` returns this when an
    /// index lies past the last existing weight.
    TooMany,
}
// Marks `WeightedError` as a standard error type. The impl is compiled only
// when the `std` feature is enabled, since it names the trait through `::std`.
#[cfg(feature = "std")]
impl ::std::error::Error for WeightedError {}
impl fmt::Display for WeightedError {
    /// Writes a fixed, human-readable sentence for each error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            WeightedError::NoItem => "No weights provided.",
            WeightedError::InvalidWeight => "A weight is invalid.",
            WeightedError::AllWeightsZero => "All weights are zero.",
            WeightedError::TooMany => "Too many weights (hit u32::MAX)",
        };
        f.write_str(message)
    }
}