summaryrefslogtreecommitdiff
path: root/bindings/rs-dablooms/src/dablooms.rs
diff options
context:
space:
mode:
Diffstat (limited to 'bindings/rs-dablooms/src/dablooms.rs')
-rw-r--r--bindings/rs-dablooms/src/dablooms.rs321
1 files changed, 321 insertions, 0 deletions
diff --git a/bindings/rs-dablooms/src/dablooms.rs b/bindings/rs-dablooms/src/dablooms.rs
new file mode 100644
index 0000000..cf74d38
--- /dev/null
+++ b/bindings/rs-dablooms/src/dablooms.rs
@@ -0,0 +1,321 @@
+use libc::time_t;
+
+use crate::dablooms_bind::*;
+use std::ffi::{CStr, CString};
+use std::os::raw::{c_int, c_uint};
+use std::ptr::NonNull;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+// String form rust not contain '\0' in the end, that may cause some problem
+//
+
+pub struct CountingBloomFilter {
+ bloom: NonNull<counting_bloom_t>,
+}
+
+impl CountingBloomFilter {
+ pub fn new(capacity: u32, error_rate: f64) -> Result<Self, &'static str> {
+ let bloom_ptr = unsafe { new_counting_bloom(capacity, error_rate) };
+ let bloom =
+ NonNull::new(bloom_ptr).ok_or("Failed to create CountingBloomFilter,return null")?;
+ Ok(CountingBloomFilter { bloom })
+ }
+
+ // get raw pointer
+ fn get_raw(&self) -> *mut counting_bloom_t {
+ self.bloom.as_ptr()
+ }
+
+ pub fn add<T>(&mut self, key: T) -> Result<(), ()>
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ let res =
+ unsafe { counting_bloom_add(self.get_raw(), key.as_ptr() as *const i8, key.len()) };
+ if res == 0 {
+ Ok(())
+ } else {
+ Err(())
+ }
+ }
+
+ pub fn remove<T>(&mut self, key: T) -> Result<(), ()>
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ let res =
+ unsafe { counting_bloom_remove(self.get_raw(), key.as_ptr() as *const i8, key.len()) };
+ if res == 0 {
+ Ok(())
+ } else {
+ Err(())
+ }
+ }
+
+ pub fn check<T>(&self, key: T) -> bool
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ unsafe { counting_bloom_check(self.get_raw(), key.as_ptr() as *const i8, key.len()) > 0 }
+ }
+}
+
+impl Drop for CountingBloomFilter {
+ fn drop(&mut self) {
+ unsafe { free_counting_bloom(self.get_raw()) };
+ }
+}
+
+pub struct ScalingBloomFilter {
+ bloom: NonNull<scaling_bloom_t>,
+}
+
+impl ScalingBloomFilter {
+ pub fn new(capacity: u32, error_rate: f64) -> Result<Self, &'static str> {
+ let bloom_ptr = unsafe { new_scaling_bloom(capacity as c_uint, error_rate) };
+ let bloom =
+ NonNull::new(bloom_ptr).ok_or("Failed to create ScalingBloomFilter,return null")?;
+ Ok(ScalingBloomFilter { bloom })
+ }
+
+ // get raw pointer
+ fn get_raw(&self) -> *mut scaling_bloom_t {
+ self.bloom.as_ptr()
+ }
+
+ pub fn add<T>(&mut self, key: T, id: u64) -> Result<(), ()>
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ let res =
+ unsafe { scaling_bloom_add(self.get_raw(), key.as_ptr() as *const i8, key.len(), id) };
+ // scaling_bloom_add returns 1 if the key is ok...
+ // no idea why it's not 0...
+ if res == 1 {
+ Ok(())
+ } else {
+ Err(())
+ }
+ }
+
+ pub fn remove<T>(&mut self, key: T, id: u64) -> Result<(), ()>
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ let res = unsafe {
+ scaling_bloom_remove(self.get_raw(), key.as_ptr() as *const i8, key.len(), id)
+ };
+ // scaling_bloom_remove returns 1 if remove success...
+ // no idea why it's not 0...
+ if res == 1 {
+ Ok(())
+ } else {
+ Err(())
+ }
+ }
+
+ pub fn check<T>(&self, key: T) -> bool
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ unsafe { scaling_bloom_check(self.get_raw(), key.as_ptr() as *const i8, key.len()) > 0 }
+ }
+
+ pub fn flush(&mut self) -> Result<(), ()> {
+ let res = unsafe { scaling_bloom_flush(self.get_raw()) };
+ if res == 0 {
+ Ok(())
+ } else {
+ Err(())
+ }
+ }
+
+ pub fn mem_seqnum(&self) -> u64 {
+ unsafe { scaling_bloom_mem_seqnum(self.get_raw()) }
+ }
+
+ pub fn disk_seqnum(&self) -> u64 {
+ unsafe { scaling_bloom_disk_seqnum(self.get_raw()) }
+ }
+}
+
+impl Drop for ScalingBloomFilter {
+ fn drop(&mut self) {
+ unsafe { free_scaling_bloom(self.get_raw()) };
+ }
+}
+
+fn get_current_time() -> Result<time_t, c_int> {
+ let current_time = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .map_err(|_| -1)?
+ .as_secs() as time_t;
+ Ok(current_time)
+}
+
+pub struct ExpiryBloomFilter {
+ bloom: NonNull<expiry_dablooms_handle>,
+}
+
+impl ExpiryBloomFilter {
+ pub fn new(capacity: u32, error_rate: f64, expiry_time: i32) -> Result<Self, &'static str> {
+ let cur_time = get_current_time();
+ let bloom_ptr = unsafe {
+ expiry_dablooms_init(
+ capacity as c_uint,
+ error_rate,
+ cur_time.unwrap(),
+ expiry_time as c_int,
+ )
+ };
+ let bloom =
+ NonNull::new(bloom_ptr).ok_or("Failed to create ExpiryBloomFilter,return null")?;
+ Ok(ExpiryBloomFilter { bloom })
+ }
+
+ // get raw pointer
+ fn get_raw(&self) -> *mut expiry_dablooms_handle {
+ self.bloom.as_ptr()
+ }
+
+ pub fn add<T>(&mut self, key: T) -> Result<(), String>
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ let cur_time = get_current_time();
+ let res = unsafe {
+ expiry_dablooms_add(
+ self.get_raw(),
+ key.as_ptr() as *const i8,
+ key.len(),
+ cur_time.unwrap(),
+ )
+ };
+ if res == 0 {
+ Ok(())
+ } else {
+ let err_cstr = unsafe { expiry_dablooms_errno_trans(res) };
+ let err_msg = unsafe { CStr::from_ptr(err_cstr).to_string_lossy().into_owned() };
+ Err(err_msg)
+ }
+ }
+
+ pub fn check<T>(&self, key: T) -> bool
+ where
+ T: AsRef<[u8]>,
+ {
+ let key = key.as_ref();
+ let cur_time = get_current_time();
+ unsafe {
+ expiry_dablooms_search(
+ self.get_raw(),
+ key.as_ptr() as *const i8,
+ key.len(),
+ cur_time.unwrap(),
+ ) > 0
+ }
+ }
+
+ pub fn count(&self) -> i32 {
+ let mut count: u64 = 0;
+ unsafe {
+ expiry_dablooms_element_count_get(self.get_raw(), &mut count);
+ }
+ return count as i32;
+ }
+}
+
+impl Drop for ExpiryBloomFilter {
+ fn drop(&mut self) {
+ unsafe { expiry_dablooms_destroy(self.get_raw()) };
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use libc::sleep;
+
+ use crate::dablooms::*;
+
+ #[test]
+ fn test_counting_bloom_filter() {
+ let mut bloom = CountingBloomFilter::new(1000, 0.01).unwrap();
+ let key1 = "hello";
+ let key2 = "world";
+ let key3 = "rust";
+
+ assert!(bloom.check(key1) == false);
+ assert!(bloom.check(key2) == false);
+ assert!(bloom.check(key3) == false);
+ bloom.add(key1).unwrap();
+ bloom.add(key2).unwrap();
+
+ drop(key1);
+ let key = "hello";
+
+ assert!(bloom.check(key) == true);
+ assert!(bloom.check(key2) == true);
+ assert!(bloom.check(key3) == false);
+ bloom.remove(key).unwrap();
+ assert!(bloom.check(key1) == false);
+ assert!(bloom.check(key2) == true);
+ assert!(bloom.check(key3) == false);
+ }
+
+ #[test]
+ fn test_scaling_bloom_filter() {
+ let mut bloom = ScalingBloomFilter::new(100, 0.05).unwrap();
+ let key1 = "aaa";
+ let key2 = "bbb";
+ let id1 = 1;
+ let id2 = 2;
+ assert!(bloom.check(key1) == false);
+ assert!(bloom.check(key2) == false);
+
+ bloom.add(key1, id1).unwrap();
+ bloom.add(key2, id2).unwrap();
+ assert!(bloom.check(key1) == true);
+ assert!(bloom.check(key2) == true);
+ bloom.remove(key1, id1).unwrap();
+ assert!(bloom.check(key1) == false);
+ assert!(bloom.check(key2) == true);
+ }
+
+ #[test]
+ fn test_expiry_dablooms_filter() {
+ let capacity: u32 = 100;
+ let error_rate: f64 = 0.05;
+ let expiry_secs: i32 = 3; // expiry time 1s
+ let key1 = "aaa";
+ let key2 = "bbb";
+
+ let mut bloom = ExpiryBloomFilter::new(capacity, error_rate, expiry_secs).unwrap();
+
+ for _i in 1..7 {
+ bloom.add(key1).unwrap();
+ assert!(bloom.count() == _i);
+ }
+ assert!(bloom.check(key1));
+
+ unsafe {
+ sleep(2); // sleep 2 sec | all key1's value not expired
+ }
+ assert!(bloom.check(key1));
+ unsafe {
+ sleep(2); // sleep 4 sec | all key1's value expired
+ }
+ assert!(!bloom.check(key1));
+
+ for _i in 1..7 {
+ bloom.add(key2).unwrap();
+ assert!(bloom.count() == _i);
+ }
+ }
+}