diff options
-rw-r--r-- | Cargo.lock | 11 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | migrations/20240116060938_init_x500s.sql | 8 | ||||
-rw-r--r-- | src/handlers/x500_mapper.rs | 169 |
4 files changed, 174 insertions, 16 deletions
@@ -1247,6 +1247,15 @@ dependencies = [ ] [[package]] +name = "num_threads" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" +dependencies = [ + "libc", +] + +[[package]] name = "object" version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2149,6 +2158,8 @@ checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" dependencies = [ "deranged", "itoa", + "libc", + "num_threads", "powerfmt", "serde", "time-core", @@ -13,6 +13,6 @@ serde = { version = "1.0.195", features = ["derive"] } serenity = "0.12.0" sqlx = { version = "0.7.3", features = ["runtime-tokio", "sqlite", "time"] } stderrlog = "0.5.4" -time = "0.3.31" +time = { version = "0.3.31", features = ["local-offset"] } tokio = { version = "1.35.1", features = ["rt-multi-thread"] } toml = "0.8.8" diff --git a/migrations/20240116060938_init_x500s.sql b/migrations/20240116060938_init_x500s.sql index 3d5ca68..adc1a4b 100644 --- a/migrations/20240116060938_init_x500s.sql +++ b/migrations/20240116060938_init_x500s.sql @@ -6,8 +6,8 @@ CREATE TABLE IF NOT EXISTS student_x500s -- In the common case, we can just map users to their X.500s. CREATE TABLE IF NOT EXISTS uids_to_x500s_clean - ( uid TEXT NOT NULL - , x500 TEXT NOT NULL + ( uid INTEGER NOT NULL + , x500 TEXT NOT NULL , FOREIGN KEY(x500) REFERENCES student_x500s(x500) , UNIQUE(uid) , UNIQUE(x500) @@ -21,7 +21,7 @@ CREATE TABLE IF NOT EXISTS uids_to_x500s_clean -- all_seen_uids_to_x500s. This lets us get a list of "suspicious tuples" by -- subtracting uids_to_x500s_clean from all_seen_uids_to_x500s. CREATE TABLE IF NOT EXISTS all_seen_uids_to_x500s - ( uid TEXT NOT NULL - , x500 TEXT NOT NULL + ( uid INTEGER NOT NULL + , x500 TEXT NOT NULL , UNIQUE(uid, x500) ); diff --git a/src/handlers/x500_mapper.rs b/src/handlers/x500_mapper.rs index 1dd8955..debb7f1 100644 --- a/src/handlers/x500_mapper.rs +++ b/src/handlers/x500_mapper.rs @@ -1,31 +1,178 @@ +use anyhow::Result; +use futures::{ + future::{self, Either}, + stream, FutureExt, Stream, StreamExt, TryStreamExt, +}; use serenity::{ - all::{GuildMemberUpdateEvent, Member, UserId}, + all::{GuildInfo, GuildMemberUpdateEvent, Member, Ready, UserId}, async_trait, client::{Context, EventHandler}, + http::Http, +}; +use sqlx::SqlitePool; +use std::time::Duration; +use time::OffsetDateTime; +use tokio::{ + spawn, + time::{sleep_until, Instant}, }; -use sqlx::{Database, Pool}; /// A handler that notices people with an X.500 in their nicknames that matches a student's, and /// records it in the database. -pub struct X500Mapper<DB: Database>(pub Pool<DB>); +#[derive(Clone)] +pub struct X500Mapper(pub SqlitePool); -impl<DB: Database> X500Mapper<DB> { - async fn notice_member(&self, nick: &str, uid: UserId) { - dbg!((nick, uid)); +impl X500Mapper { + async fn log_stats(&self) { + let result = future::join( + sqlx::query!("SELECT COUNT(*) AS count FROM uids_to_x500s_clean").fetch_one(&self.0), + sqlx::query!("SELECT COUNT(*) AS count FROM student_x500s").fetch_one(&self.0), + ) + .map(|(result_have, result_total)| -> Result<_> { + Ok((result_have?.count, result_total?.count)) + }) + .await; + + match result { + Ok((have, total)) => log::info!("Now have users for {} / {} known X.500s", have, total), + Err(err) => log::error!("Failed to get stats about X.500s: {}", err), + } + } + + async fn look_for_everyone(&self, ctx: &Context) { + // Get all the members. + let members = get_all_members(&ctx.http).await; + log::info!("Got a list of {} members to consider", members.len()); + + // Filter for the ones whose names have exactly one parenthesized section, at the end, and + // parse it out. + let parsed_members = members + .into_iter() + .filter_map(|member| { + let x500 = parse_out_x500(member.display_name())?.to_string(); + Some((member, x500)) + }) + .collect::<Vec<_>>(); + log::info!("{} members had recognizable X.500s", parsed_members.len()); + + // Handle each one. + stream::iter(parsed_members) + .for_each(|(member, x500)| self.notice_x500(member.user.id, x500)) + .await; + + // Print some stats, everybody likes stats! + self.log_stats().await; + } + + async fn look_for_everyone_loop(self, ctx: Context) { + loop { + let next_start = Instant::now() + Duration::from_secs(12 * 60 * 60); + log::info!("Checking for new students..."); + self.look_for_everyone(&ctx).await; + log::info!( + "Waiting to check for new students until {}", + OffsetDateTime::now_local().unwrap_or_else(|_| OffsetDateTime::now_utc()) + + (next_start - Instant::now()) + ); + sleep_until(next_start).await + } + } + + async fn notice_x500(&self, uid: UserId, x500: String) { + dbg!((uid, x500)); } } #[async_trait] -impl<DB: Database> EventHandler for X500Mapper<DB> { +impl EventHandler for X500Mapper { + async fn ready(&self, ctx: Context, _data_about_bot: Ready) { + spawn(self.clone().look_for_everyone_loop(ctx)); + } + async fn guild_member_update( &self, _ctx: Context, _old_if_available: Option<Member>, - _new: Option<Member>, - event: GuildMemberUpdateEvent, + new: Option<Member>, + _event: GuildMemberUpdateEvent, ) { - if let Some(nick) = event.nick { - self.notice_member(&nick, event.user.id).await + if let Some(member) = new { + if let Some(x500) = parse_out_x500(member.display_name()) { + self.notice_x500(member.user.id, x500.to_string()).await; + self.log_stats().await; + } } } } + +fn get_all_guilds(http: &Http) -> impl '_ + Stream<Item = Result<GuildInfo>> { + // TODO: Paginate me! + http.get_guilds(None, Some(200)) + .map(|result| match result { + Ok(guilds) => Either::Left(stream::iter(guilds).map(Ok)), + Err(err) => Either::Right(stream::once(future::err(err.into()))), + }) + .flatten_stream() +} + +async fn get_all_members(http: &Http) -> Vec<Member> { + let (members, errs) = get_all_guilds(http) + .flat_map(|result| match result { + Ok(guild) => Either::Left(guild.id.members_iter(http).map_err(|err| err.into())), + Err(err) => Either::Right(stream::once(future::err(err))), + }) + .fold((Vec::new(), Vec::new()), |(mut oks, mut errs), result| { + match result { + Ok(member) => oks.push(member), + Err(err) => errs.push(err), + } + future::ready((oks, errs)) + }) + .await; + if let Some(err) = errs.first() { + log::error!( + "failed to get a list of all members: {}; proceeding with {}", + err, + members.len() + ); + } + members +} + +fn parse_out_x500(display_name: &str) -> Option<&str> { + enum State { + BeforeParens, + SawLParen(usize), + SawBothParens(usize, usize), + } + + let mut state = State::BeforeParens; + for (i, ch) in display_name.char_indices() { + state = match (state, ch) { + (State::BeforeParens, '(') => State::SawLParen(i + 1), + (State::SawLParen(li), ')') => State::SawBothParens(li, i), + + (State::SawLParen(_) | State::SawBothParens(_, _), '(') => return None, + (State::BeforeParens | State::SawBothParens(_, _), ')') => return None, + + (State::BeforeParens, _) => State::BeforeParens, + (State::SawLParen(li), ch) => { + if !ch.is_ascii_alphanumeric() { + return None; + } + State::SawLParen(li) + } + (State::SawBothParens(li, ri), ch) => { + if !ch.is_whitespace() { + return None; + } + State::SawBothParens(li, ri) + } + }; + } + + match state { + State::SawBothParens(li, ri) => Some(&display_name[li..ri]), + _ => None, + } +} |