aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/bin/lambo.rs10
-rw-r--r--src/config.rs4
-rw-r--r--src/handlers/mod.rs5
-rw-r--r--src/handlers/x500_mapper.rs47
4 files changed, 48 insertions, 18 deletions
diff --git a/src/bin/lambo.rs b/src/bin/lambo.rs
index e3e81a6..1134480 100644
--- a/src/bin/lambo.rs
+++ b/src/bin/lambo.rs
@@ -3,7 +3,7 @@ use clap::{value_parser, ArgAction, Parser};
use lambo::{config::Config, handlers::*};
use serenity::{all::GatewayIntents, Client};
use sqlx::sqlite::SqlitePoolOptions;
-use std::path::PathBuf;
+use std::{path::PathBuf, sync::Arc};
use stderrlog::StdErrLog;
#[derive(Debug, Parser)]
@@ -69,7 +69,13 @@ async fn main() -> Result<()> {
})?;
// Create the handlers.
- let handler = MultiHandler(vec![Box::new(PresenceSetter), Box::new(X500Mapper(db))]);
+ let handler = MultiHandler(vec![
+ Box::new(PresenceSetter),
+ Box::new(X500Mapper {
+ config: Arc::new(config.x500_mapper),
+ db,
+ }),
+ ]);
// Start up the client.
let intents = GatewayIntents::default() | GatewayIntents::GUILD_MEMBERS;
diff --git a/src/config.rs b/src/config.rs
index 14a4c19..47695e3 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,11 +1,15 @@
+use crate::handlers::X500MapperConfig;
use anyhow::{Context, Result};
use serde::Deserialize;
use std::{fs, path::Path};
#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
pub struct Config {
pub database_url: String,
pub discord_token: String,
+
+ pub x500_mapper: X500MapperConfig,
}
impl Config {
diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs
index f2b8a4f..46bd57b 100644
--- a/src/handlers/mod.rs
+++ b/src/handlers/mod.rs
@@ -21,7 +21,10 @@ use std::collections::HashMap;
mod presence_setter;
mod x500_mapper;
-pub use self::{presence_setter::PresenceSetter, x500_mapper::X500Mapper};
+pub use self::{
+ presence_setter::PresenceSetter,
+ x500_mapper::{X500Mapper, X500MapperConfig},
+};
/// An EventHandler that proxies events to each of its contained handlers concurrently (but not in
/// parallel).
diff --git a/src/handlers/x500_mapper.rs b/src/handlers/x500_mapper.rs
index c228f70..36ce386 100644
--- a/src/handlers/x500_mapper.rs
+++ b/src/handlers/x500_mapper.rs
@@ -3,14 +3,15 @@ use futures::{
future::{self, Either},
stream, FutureExt, Stream, StreamExt, TryStreamExt,
};
+use serde::Deserialize;
use serenity::{
- all::{GuildInfo, GuildMemberUpdateEvent, Member, Ready, UserId},
+ all::{GuildInfo, GuildMemberUpdateEvent, Member, Ready, RoleId},
async_trait,
client::{Context, EventHandler},
http::Http,
};
use sqlx::SqlitePool;
-use std::time::Duration;
+use std::{sync::Arc, time::Duration};
use time::OffsetDateTime;
use tokio::{
spawn,
@@ -20,13 +21,22 @@ use tokio::{
/// A handler that notices people with an X.500 in their nicknames that matches a student's, and
/// records it in the database.
#[derive(Clone)]
-pub struct X500Mapper(pub SqlitePool);
+pub struct X500Mapper {
+ pub config: Arc<X500MapperConfig>,
+ pub db: SqlitePool,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct X500MapperConfig {
+ pub students_role: Option<RoleId>,
+}
impl X500Mapper {
async fn log_stats(&self) {
let result = future::join(
- sqlx::query!("SELECT COUNT(*) AS count FROM uids_to_x500s_clean").fetch_one(&self.0),
- sqlx::query!("SELECT COUNT(*) AS count FROM student_x500s").fetch_one(&self.0),
+ sqlx::query!("SELECT COUNT(*) AS count FROM uids_to_x500s_clean").fetch_one(&self.db),
+ sqlx::query!("SELECT COUNT(*) AS count FROM student_x500s").fetch_one(&self.db),
)
.map(|(result_have, result_total)| -> Result<_> {
Ok((result_have?.count, result_total?.count))
@@ -57,7 +67,7 @@ impl X500Mapper {
// Handle each one.
stream::iter(parsed_members)
- .for_each(|(member, x500)| self.record_x500(member.user.id, x500))
+ .for_each(|(member, x500)| self.record_x500(ctx, member, x500))
.await;
// Print some stats, everybody likes stats!
@@ -78,18 +88,17 @@ impl X500Mapper {
}
}
- async fn record_x500(&self, uid: UserId, x500: String) {
- let x500 = x500.to_lowercase();
+ async fn record_x500(&self, ctx: &Context, member: Member, x500: String) {
let x500 = &x500;
+ let uid = member.user.id;
let future = async move {
let uid = i64::from(uid);
-
sqlx::query!(
"INSERT OR IGNORE INTO all_seen_uids_to_x500s (uid, x500) VALUES (?, ?)",
uid,
x500
)
- .execute(&self.0)
+ .execute(&self.db)
.await
.context("failed to insert into all_seen_uids_to_x500s")?;
@@ -98,10 +107,18 @@ impl X500Mapper {
uid,
x500
)
- .execute(&self.0)
+ .execute(&self.db)
.await
.context("failed to insert into uids_to_x500s_clean")?;
+ if let Some(role) = self.config.students_role {
+ log::info!("adding the role {} to {}", role, member.display_name());
+ member
+ .add_role(&ctx.http, role)
+ .await
+ .context("failed to add student role")?
+ }
+
Ok::<_, anyhow::Error>(())
};
@@ -125,14 +142,14 @@ impl EventHandler for X500Mapper {
async fn guild_member_update(
&self,
- _ctx: Context,
+ ctx: Context,
_old_if_available: Option<Member>,
new: Option<Member>,
_event: GuildMemberUpdateEvent,
) {
if let Some(member) = new {
if let Some(x500) = parse_out_x500(member.display_name()) {
- self.record_x500(member.user.id, x500.to_string()).await;
+ self.record_x500(&ctx, member, x500).await;
self.log_stats().await;
}
}
@@ -173,7 +190,7 @@ async fn get_all_members(http: &Http) -> Vec<Member> {
members
}
-fn parse_out_x500(display_name: &str) -> Option<&str> {
+fn parse_out_x500(display_name: &str) -> Option<String> {
enum State {
BeforeParens,
SawLParen(usize),
@@ -206,7 +223,7 @@ fn parse_out_x500(display_name: &str) -> Option<&str> {
}
match state {
- State::SawBothParens(li, ri) => Some(&display_name[li..ri]),
+ State::SawBothParens(li, ri) => Some(display_name[li..ri].to_lowercase()),
_ => None,
}
}