use anyhow::{Context as _, Result}; use futures::{ future::{self, Either}, stream, FutureExt, Stream, StreamExt, TryStreamExt, }; use serde::Deserialize; use serenity::{ all::{GuildInfo, GuildMemberUpdateEvent, Member, Ready, RoleId}, async_trait, client::{Context, EventHandler}, http::Http, }; use sqlx::SqlitePool; use std::{sync::Arc, time::Duration}; use time::OffsetDateTime; use tokio::{ spawn, time::{sleep_until, Instant}, }; /// A handler that notices people with an X.500 in their nicknames that matches a student's, and /// records it in the database. #[derive(Clone)] pub struct X500Mapper { pub config: Arc, pub db: SqlitePool, } #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] pub struct X500MapperConfig { pub students_role: Option, } impl X500Mapper { async fn log_stats(&self) { let result = future::join( sqlx::query!("SELECT COUNT(*) AS count FROM uids_to_x500s_clean").fetch_one(&self.db), sqlx::query!("SELECT COUNT(*) AS count FROM student_x500s").fetch_one(&self.db), ) .map(|(result_have, result_total)| -> Result<_> { Ok((result_have?.count, result_total?.count)) }) .await; match result { Ok((have, total)) => log::info!("Now have users for {} / {} known X.500s", have, total), Err(err) => log::error!("Failed to get stats about X.500s: {:?}", err), } } async fn look_for_everyone(&self, ctx: &Context) { // Get all the members. let members = get_all_members(&ctx.http).await; log::info!("Got a list of {} members to consider", members.len()); // Filter for the ones whose names have exactly one parenthesized section, at the end, and // parse it out. let parsed_members = members .into_iter() .filter_map(|member| { let x500 = parse_out_x500(member.display_name())?.to_string(); Some((member, x500)) }) .collect::>(); log::info!("{} members had recognizable X.500s", parsed_members.len()); // Handle each one. stream::iter(parsed_members) .for_each(|(member, x500)| self.record_x500(ctx, member, x500)) .await; // Print some stats, everybody likes stats! self.log_stats().await; } async fn look_for_everyone_loop(self, ctx: Context) { loop { let next_start = Instant::now() + Duration::from_secs(12 * 60 * 60); log::info!("Checking for new students..."); self.look_for_everyone(&ctx).await; log::info!( "Waiting to check for new students until {}", OffsetDateTime::now_local().unwrap_or_else(|_| OffsetDateTime::now_utc()) + (next_start - Instant::now()) ); sleep_until(next_start).await } } async fn record_x500(&self, ctx: &Context, member: Member, x500: String) { let x500 = &x500; let uid = member.user.id; let future = async move { let uid = i64::from(uid); sqlx::query!( "INSERT OR IGNORE INTO all_seen_uids_to_x500s (uid, x500) VALUES (?, ?)", uid, x500 ) .execute(&self.db) .await .context("failed to insert into all_seen_uids_to_x500s")?; sqlx::query!( "INSERT OR IGNORE INTO uids_to_x500s_clean (uid, x500) VALUES (?, ?)", uid, x500 ) .execute(&self.db) .await .context("failed to insert into uids_to_x500s_clean")?; if let Some(role) = self.config.students_role { log::info!("adding the role {} to {}", role, member.display_name()); member .add_role(&ctx.http, role) .await .context("failed to add student role")? } Ok::<_, anyhow::Error>(()) }; match future.await { Ok(()) => (), Err(err) => log::error!( "failed to record that the user with UID {} had the X.500 {}: {:?}", uid, x500, err ), } } } #[async_trait] impl EventHandler for X500Mapper { async fn ready(&self, ctx: Context, _data_about_bot: Ready) { spawn(self.clone().look_for_everyone_loop(ctx)); } async fn guild_member_update( &self, ctx: Context, _old_if_available: Option, new: Option, _event: GuildMemberUpdateEvent, ) { if let Some(member) = new { if let Some(x500) = parse_out_x500(member.display_name()) { self.record_x500(&ctx, member, x500).await; self.log_stats().await; } } } } fn get_all_guilds(http: &Http) -> impl '_ + Stream> { // TODO: Paginate me! http.get_guilds(None, Some(200)) .map(|result| match result { Ok(guilds) => Either::Left(stream::iter(guilds).map(Ok)), Err(err) => Either::Right(stream::once(future::err(err.into()))), }) .flatten_stream() } async fn get_all_members(http: &Http) -> Vec { let (members, errs) = get_all_guilds(http) .flat_map(|result| match result { Ok(guild) => Either::Left(guild.id.members_iter(http).map_err(|err| err.into())), Err(err) => Either::Right(stream::once(future::err(err))), }) .fold((Vec::new(), Vec::new()), |(mut oks, mut errs), result| { match result { Ok(member) => oks.push(member), Err(err) => errs.push(err), } future::ready((oks, errs)) }) .await; if let Some(err) = errs.first() { log::error!( "failed to get a list of all members: {:?}; proceeding with {}", err, members.len() ); } members } fn parse_out_x500(display_name: &str) -> Option { enum State { BeforeParens, SawLParen(usize), SawBothParens(usize, usize), } let mut state = State::BeforeParens; for (i, ch) in display_name.char_indices() { state = match (state, ch) { (State::BeforeParens, '(') => State::SawLParen(i + 1), (State::SawLParen(li), ')') => State::SawBothParens(li, i), (State::SawLParen(_) | State::SawBothParens(_, _), '(') => return None, (State::BeforeParens | State::SawBothParens(_, _), ')') => return None, (State::BeforeParens, _) => State::BeforeParens, (State::SawLParen(li), ch) => { if !ch.is_ascii_alphanumeric() { return None; } State::SawLParen(li) } (State::SawBothParens(li, ri), ch) => { if !ch.is_whitespace() { return None; } State::SawBothParens(li, ri) } }; } match state { State::SawBothParens(li, ri) => Some(display_name[li..ri].to_lowercase()), _ => None, } }