diff options
Diffstat (limited to 'bindings/rs-dablooms/src/dablooms.rs')
| -rw-r--r-- | bindings/rs-dablooms/src/dablooms.rs | 321 |
1 files changed, 321 insertions, 0 deletions
diff --git a/bindings/rs-dablooms/src/dablooms.rs b/bindings/rs-dablooms/src/dablooms.rs new file mode 100644 index 0000000..cf74d38 --- /dev/null +++ b/bindings/rs-dablooms/src/dablooms.rs @@ -0,0 +1,321 @@ +use libc::time_t; + +use crate::dablooms_bind::*; +use std::ffi::{CStr, CString}; +use std::os::raw::{c_int, c_uint}; +use std::ptr::NonNull; +use std::time::{SystemTime, UNIX_EPOCH}; + +// String form rust not contain '\0' in the end, that may cause some problem +// + +pub struct CountingBloomFilter { + bloom: NonNull<counting_bloom_t>, +} + +impl CountingBloomFilter { + pub fn new(capacity: u32, error_rate: f64) -> Result<Self, &'static str> { + let bloom_ptr = unsafe { new_counting_bloom(capacity, error_rate) }; + let bloom = + NonNull::new(bloom_ptr).ok_or("Failed to create CountingBloomFilter,return null")?; + Ok(CountingBloomFilter { bloom }) + } + + // get raw pointer + fn get_raw(&self) -> *mut counting_bloom_t { + self.bloom.as_ptr() + } + + pub fn add<T>(&mut self, key: T) -> Result<(), ()> + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + let res = + unsafe { counting_bloom_add(self.get_raw(), key.as_ptr() as *const i8, key.len()) }; + if res == 0 { + Ok(()) + } else { + Err(()) + } + } + + pub fn remove<T>(&mut self, key: T) -> Result<(), ()> + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + let res = + unsafe { counting_bloom_remove(self.get_raw(), key.as_ptr() as *const i8, key.len()) }; + if res == 0 { + Ok(()) + } else { + Err(()) + } + } + + pub fn check<T>(&self, key: T) -> bool + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + unsafe { counting_bloom_check(self.get_raw(), key.as_ptr() as *const i8, key.len()) > 0 } + } +} + +impl Drop for CountingBloomFilter { + fn drop(&mut self) { + unsafe { free_counting_bloom(self.get_raw()) }; + } +} + +pub struct ScalingBloomFilter { + bloom: NonNull<scaling_bloom_t>, +} + +impl ScalingBloomFilter { + pub fn new(capacity: u32, error_rate: f64) -> Result<Self, &'static str> { + let bloom_ptr = unsafe { new_scaling_bloom(capacity as c_uint, error_rate) }; + let bloom = + NonNull::new(bloom_ptr).ok_or("Failed to create ScalingBloomFilter,return null")?; + Ok(ScalingBloomFilter { bloom }) + } + + // get raw pointer + fn get_raw(&self) -> *mut scaling_bloom_t { + self.bloom.as_ptr() + } + + pub fn add<T>(&mut self, key: T, id: u64) -> Result<(), ()> + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + let res = + unsafe { scaling_bloom_add(self.get_raw(), key.as_ptr() as *const i8, key.len(), id) }; + // scaling_bloom_add returns 1 if the key is ok... + // no idea why it's not 0... + if res == 1 { + Ok(()) + } else { + Err(()) + } + } + + pub fn remove<T>(&mut self, key: T, id: u64) -> Result<(), ()> + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + let res = unsafe { + scaling_bloom_remove(self.get_raw(), key.as_ptr() as *const i8, key.len(), id) + }; + // scaling_bloom_remove returns 1 if remove success... + // no idea why it's not 0... + if res == 1 { + Ok(()) + } else { + Err(()) + } + } + + pub fn check<T>(&self, key: T) -> bool + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + unsafe { scaling_bloom_check(self.get_raw(), key.as_ptr() as *const i8, key.len()) > 0 } + } + + pub fn flush(&mut self) -> Result<(), ()> { + let res = unsafe { scaling_bloom_flush(self.get_raw()) }; + if res == 0 { + Ok(()) + } else { + Err(()) + } + } + + pub fn mem_seqnum(&self) -> u64 { + unsafe { scaling_bloom_mem_seqnum(self.get_raw()) } + } + + pub fn disk_seqnum(&self) -> u64 { + unsafe { scaling_bloom_disk_seqnum(self.get_raw()) } + } +} + +impl Drop for ScalingBloomFilter { + fn drop(&mut self) { + unsafe { free_scaling_bloom(self.get_raw()) }; + } +} + +fn get_current_time() -> Result<time_t, c_int> { + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|_| -1)? + .as_secs() as time_t; + Ok(current_time) +} + +pub struct ExpiryBloomFilter { + bloom: NonNull<expiry_dablooms_handle>, +} + +impl ExpiryBloomFilter { + pub fn new(capacity: u32, error_rate: f64, expiry_time: i32) -> Result<Self, &'static str> { + let cur_time = get_current_time(); + let bloom_ptr = unsafe { + expiry_dablooms_init( + capacity as c_uint, + error_rate, + cur_time.unwrap(), + expiry_time as c_int, + ) + }; + let bloom = + NonNull::new(bloom_ptr).ok_or("Failed to create ExpiryBloomFilter,return null")?; + Ok(ExpiryBloomFilter { bloom }) + } + + // get raw pointer + fn get_raw(&self) -> *mut expiry_dablooms_handle { + self.bloom.as_ptr() + } + + pub fn add<T>(&mut self, key: T) -> Result<(), String> + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + let cur_time = get_current_time(); + let res = unsafe { + expiry_dablooms_add( + self.get_raw(), + key.as_ptr() as *const i8, + key.len(), + cur_time.unwrap(), + ) + }; + if res == 0 { + Ok(()) + } else { + let err_cstr = unsafe { expiry_dablooms_errno_trans(res) }; + let err_msg = unsafe { CStr::from_ptr(err_cstr).to_string_lossy().into_owned() }; + Err(err_msg) + } + } + + pub fn check<T>(&self, key: T) -> bool + where + T: AsRef<[u8]>, + { + let key = key.as_ref(); + let cur_time = get_current_time(); + unsafe { + expiry_dablooms_search( + self.get_raw(), + key.as_ptr() as *const i8, + key.len(), + cur_time.unwrap(), + ) > 0 + } + } + + pub fn count(&self) -> i32 { + let mut count: u64 = 0; + unsafe { + expiry_dablooms_element_count_get(self.get_raw(), &mut count); + } + return count as i32; + } +} + +impl Drop for ExpiryBloomFilter { + fn drop(&mut self) { + unsafe { expiry_dablooms_destroy(self.get_raw()) }; + } +} + +#[cfg(test)] +mod tests { + use libc::sleep; + + use crate::dablooms::*; + + #[test] + fn test_counting_bloom_filter() { + let mut bloom = CountingBloomFilter::new(1000, 0.01).unwrap(); + let key1 = "hello"; + let key2 = "world"; + let key3 = "rust"; + + assert!(bloom.check(key1) == false); + assert!(bloom.check(key2) == false); + assert!(bloom.check(key3) == false); + bloom.add(key1).unwrap(); + bloom.add(key2).unwrap(); + + drop(key1); + let key = "hello"; + + assert!(bloom.check(key) == true); + assert!(bloom.check(key2) == true); + assert!(bloom.check(key3) == false); + bloom.remove(key).unwrap(); + assert!(bloom.check(key1) == false); + assert!(bloom.check(key2) == true); + assert!(bloom.check(key3) == false); + } + + #[test] + fn test_scaling_bloom_filter() { + let mut bloom = ScalingBloomFilter::new(100, 0.05).unwrap(); + let key1 = "aaa"; + let key2 = "bbb"; + let id1 = 1; + let id2 = 2; + assert!(bloom.check(key1) == false); + assert!(bloom.check(key2) == false); + + bloom.add(key1, id1).unwrap(); + bloom.add(key2, id2).unwrap(); + assert!(bloom.check(key1) == true); + assert!(bloom.check(key2) == true); + bloom.remove(key1, id1).unwrap(); + assert!(bloom.check(key1) == false); + assert!(bloom.check(key2) == true); + } + + #[test] + fn test_expiry_dablooms_filter() { + let capacity: u32 = 100; + let error_rate: f64 = 0.05; + let expiry_secs: i32 = 3; // expiry time 1s + let key1 = "aaa"; + let key2 = "bbb"; + + let mut bloom = ExpiryBloomFilter::new(capacity, error_rate, expiry_secs).unwrap(); + + for _i in 1..7 { + bloom.add(key1).unwrap(); + assert!(bloom.count() == _i); + } + assert!(bloom.check(key1)); + + unsafe { + sleep(2); // sleep 2 sec | all key1's value not expired + } + assert!(bloom.check(key1)); + unsafe { + sleep(2); // sleep 4 sec | all key1's value expired + } + assert!(!bloom.check(key1)); + + for _i in 1..7 { + bloom.add(key2).unwrap(); + assert!(bloom.count() == _i); + } + } +} |
