aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorNathan Ringo <nathan@remexre.com>2024-01-18 10:58:36 -0600
committerNathan Ringo <nathan@remexre.com>2024-01-18 10:58:36 -0600
commit00d0bfced902e97eeae5257c14134d4bc7efc710 (patch)
treeee026f328614e03aec3ed373d9f2e6c8e255f834 /src
parent7017762a4a38266aa88976be141f7bd663647edc (diff)
Commands to interact with discocaml, associated IPC.
Diffstat (limited to 'src')
-rw-r--r--src/bin/lambo.rs4
-rw-r--r--src/commands/discocaml.rs177
-rw-r--r--src/commands/mod.rs47
-rw-r--r--src/config.rs3
-rw-r--r--src/handlers/commands.rs29
-rw-r--r--src/handlers/mod.rs2
-rw-r--r--src/handlers/x500_mapper.rs11
-rw-r--r--src/lib.rs2
-rw-r--r--src/utils.rs290
9 files changed, 559 insertions, 6 deletions
diff --git a/src/bin/lambo.rs b/src/bin/lambo.rs
index f446308..6b5f349 100644
--- a/src/bin/lambo.rs
+++ b/src/bin/lambo.rs
@@ -58,6 +58,10 @@ async fn main() -> Result<()> {
// Create the handlers.
let handler = MultiHandler(vec![
+ Box::new(Commands {
+ config: config.commands,
+ db: db.clone(),
+ }),
Box::new(PresenceSetter),
Box::new(X500Mapper {
config: Arc::new(config.x500_mapper),
diff --git a/src/commands/discocaml.rs b/src/commands/discocaml.rs
new file mode 100644
index 0000000..93ddfaa
--- /dev/null
+++ b/src/commands/discocaml.rs
@@ -0,0 +1,177 @@
+use anyhow::{anyhow, bail, Context as _, Error, Result};
+use serde::{Deserialize, Serialize};
+use serde_json::de::Deserializer;
+use serenity::{
+ all::{
+ CommandDataOptionValue, CommandInteraction, CommandOptionType, CommandType, Member, RoleId,
+ },
+ builder::{
+ CreateCommand, CreateCommandOption, CreateInteractionResponse,
+ CreateInteractionResponseMessage,
+ },
+ client::Context,
+ model::Permissions,
+};
+use sqlx::SqlitePool;
+use std::process::Stdio;
+use tokio::{io::AsyncWriteExt, process::Command};
+
+use crate::utils::EnumAsArray;
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct DiscocamlConfig {
+ pub command: Vec<String>,
+ pub role: RoleId,
+}
+
+#[derive(Debug, Serialize)]
+struct DiscocamlRequest {
+ expr: String,
+ command: DiscocamlCommand,
+}
+
+#[derive(Debug, Serialize)]
+enum DiscocamlCommand {
+ Roundtrip,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+enum DiscocamlResponse {
+ Error(String),
+}
+
+pub fn command() -> CreateCommand {
+ CreateCommand::new("discocaml")
+ .kind(CommandType::ChatInput)
+ .default_member_permissions(Permissions::empty())
+ .dm_permission(true)
+ .description("Sends this expression to the disco!")
+ .add_option(
+ CreateCommandOption::new(
+ CommandOptionType::String,
+ "expr",
+ "The expression to operate on.",
+ )
+ .required(true),
+ )
+}
+
+async fn check_permissions(config: &DiscocamlConfig, member: &Option<Box<Member>>) -> Result<()> {
+ if let Some(member) = member {
+ if !member.roles.contains(&config.role) {
+ bail!("This command can only be used by <@&{}>.", config.role)
+ }
+ Ok(())
+ } else {
+ bail!("This command cannot be used in DMs.")
+ }
+}
+
+async fn respond_with_error(ctx: &Context, command: &CommandInteraction, err: &Error) {
+ let msg = CreateInteractionResponseMessage::new().content(format!(":no_entry_sign: {}", err));
+ if let Err(err) = command
+ .create_response(ctx, CreateInteractionResponse::Message(msg))
+ .await
+ {
+ log::error!(
+ "failed to respond to command that failed permissions check: {:?}",
+ err
+ )
+ }
+}
+
+async fn run_discocaml(
+ config: &DiscocamlConfig,
+ req: &DiscocamlRequest,
+) -> Result<DiscocamlResponse> {
+ let mut child = Command::new(&config.command[0])
+ .args(&config.command)
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .kill_on_drop(true)
+ .spawn()
+ .context("failed to start discocaml")?;
+
+ let mut stdin = child.stdin.take().unwrap();
+ let req = serde_json::to_string(req).context("failed to serialize request to discocaml")?;
+ stdin
+ .write_all(req.as_bytes())
+ .await
+ .context("failed to write request to discocaml")?;
+ stdin
+ .shutdown()
+ .await
+ .context("failed to close pipe to discocaml")?;
+ drop(stdin);
+ drop(req);
+
+ let output = child
+ .wait_with_output()
+ .await
+ .context("failed to wait for discocaml to complete")?;
+ if !output.status.success() {
+ bail!("discocaml exited with non-zero status {:?}", output.status)
+ }
+
+ let mut de = Deserializer::from_slice(&output.stdout);
+ let out = DiscocamlResponse::deserialize(EnumAsArray(&mut de))
+ .context("failed to parse response from discocaml")?;
+ de.end()
+ .context("failed to parse response from discocaml")?;
+ Ok(out)
+}
+
+pub async fn handle_command(
+ ctx: &Context,
+ config: &DiscocamlConfig,
+ _db: &SqlitePool,
+ command: &CommandInteraction,
+) -> Result<()> {
+ // Check that the required role was present.
+ if let Err(err) = check_permissions(config, &command.member).await {
+ respond_with_error(ctx, command, &err).await;
+ return Err(err.context("permissions check failed"));
+ }
+
+ // Parse the expression out.
+ let mut expr = None;
+ for option in &command.data.options {
+ match (&option.name as &str, &option.value) {
+ ("expr", CommandDataOptionValue::String(value)) => expr = Some(value),
+ _ => {
+ let err = anyhow!("unknown option {:?}", option);
+ respond_with_error(ctx, command, &err).await;
+ return Err(err);
+ }
+ }
+ }
+ let expr = if let Some(expr) = expr {
+ expr
+ } else {
+ let err = anyhow!("missing option {:?}", "expr");
+ respond_with_error(ctx, command, &err).await;
+ return Err(err);
+ };
+
+ // Round-trip the expression through discocaml.
+ let req = DiscocamlRequest {
+ expr: expr.to_string(),
+ command: DiscocamlCommand::Roundtrip,
+ };
+ let res = match run_discocaml(config, &req).await {
+ Ok(res) => res,
+ Err(err) => {
+ let err = err.context("failed to run discocaml");
+ respond_with_error(ctx, command, &err).await;
+ return Err(err);
+ }
+ };
+
+ let msg = CreateInteractionResponseMessage::new().content(format!("`{:?}`", res));
+ command
+ .create_response(&ctx, CreateInteractionResponse::Message(msg))
+ .await
+ .context("failed to respond")
+}
diff --git a/src/commands/mod.rs b/src/commands/mod.rs
new file mode 100644
index 0000000..083b200
--- /dev/null
+++ b/src/commands/mod.rs
@@ -0,0 +1,47 @@
+mod discocaml;
+
+use crate::commands::discocaml::DiscocamlConfig;
+use anyhow::{Context as _, Result};
+use serde::Deserialize;
+use serenity::{
+ all::{Command, Interaction},
+ client::Context,
+};
+use sqlx::SqlitePool;
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct CommandsConfig {
+ pub discocaml: DiscocamlConfig,
+}
+
+pub async fn set_commands(ctx: &Context) -> Result<()> {
+ Command::set_global_commands(&ctx.http, vec![discocaml::command()])
+ .await
+ .context("failed to set commands")?;
+ Ok(())
+}
+
+pub async fn handle_interaction(
+ ctx: &Context,
+ config: &CommandsConfig,
+ db: &SqlitePool,
+ interaction: &Interaction,
+) -> Result<()> {
+ match interaction {
+ Interaction::Command(cmd) => match &cmd.data.name as &str {
+ "discocaml" => discocaml::handle_command(ctx, &config.discocaml, db, cmd)
+ .await
+ .context("failed to handle discocaml command"),
+ _ => {
+ log::warn!("unexpected interaction: {:?}", interaction);
+ Ok(())
+ }
+ },
+
+ _ => {
+ log::warn!("unexpected interaction: {:?}", interaction);
+ Ok(())
+ }
+ }
+}
diff --git a/src/config.rs b/src/config.rs
index 47695e3..11742d9 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,4 +1,4 @@
-use crate::handlers::X500MapperConfig;
+use crate::{commands::CommandsConfig, handlers::X500MapperConfig};
use anyhow::{Context, Result};
use serde::Deserialize;
use std::{fs, path::Path};
@@ -9,6 +9,7 @@ pub struct Config {
pub database_url: String,
pub discord_token: String,
+ pub commands: CommandsConfig,
pub x500_mapper: X500MapperConfig,
}
diff --git a/src/handlers/commands.rs b/src/handlers/commands.rs
new file mode 100644
index 0000000..d07bdcb
--- /dev/null
+++ b/src/handlers/commands.rs
@@ -0,0 +1,29 @@
+use crate::commands::{handle_interaction, set_commands, CommandsConfig};
+use serenity::{
+ all::{Interaction, Ready},
+ async_trait,
+ client::{Context, EventHandler},
+};
+use sqlx::SqlitePool;
+
+/// A handler that sets up the commands.
+pub struct Commands {
+ pub config: CommandsConfig,
+ pub db: SqlitePool,
+}
+
+#[async_trait]
+impl EventHandler for Commands {
+ async fn ready(&self, ctx: Context, _data_about_bot: Ready) {
+ if let Err(err) = set_commands(&ctx).await {
+ log::error!("failed to set commands: {:?}", err)
+ }
+ }
+
+ async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
+ if let Err(err) = handle_interaction(&ctx, &self.config, &self.db, &interaction).await {
+ log::error!("failed to handle interaction: {:?}", err);
+ log::error!("failed interaction was: {:?}", interaction);
+ }
+ }
+}
diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs
index 46bd57b..4c09fec 100644
--- a/src/handlers/mod.rs
+++ b/src/handlers/mod.rs
@@ -18,10 +18,12 @@ use serenity::{
};
use std::collections::HashMap;
+mod commands;
mod presence_setter;
mod x500_mapper;
pub use self::{
+ commands::Commands,
presence_setter::PresenceSetter,
x500_mapper::{X500Mapper, X500MapperConfig},
};
diff --git a/src/handlers/x500_mapper.rs b/src/handlers/x500_mapper.rs
index 1ec35f2..acd498f 100644
--- a/src/handlers/x500_mapper.rs
+++ b/src/handlers/x500_mapper.rs
@@ -90,9 +90,9 @@ impl X500Mapper {
async fn record_x500(&self, ctx: &Context, member: Member, x500: String) {
let x500 = &x500;
- let uid = member.user.id;
+ let member = &member;
let future = async move {
- let uid = i64::from(uid);
+ let uid = i64::from(member.user.id);
sqlx::query!(
"INSERT OR IGNORE INTO all_seen_uids_to_x500s (uid, x500) VALUES (?, ?)",
uid,
@@ -112,8 +112,8 @@ impl X500Mapper {
.context("failed to insert into uids_to_x500s_clean")?;
if let Some(role) = self.config.students_role {
- log::info!("adding the role {} to {}", role, member.display_name());
if !member.roles.contains(&role) {
+ log::info!("adding the role {} to {:?}", role, member.display_name());
member
.add_role(&ctx.http, role)
.await
@@ -127,8 +127,9 @@ impl X500Mapper {
match future.await {
Ok(()) => (),
Err(err) => log::error!(
- "failed to record that the user with UID {} had the X.500 {}: {:?}",
- uid,
+ "failed to record that the user with UID {} ({:?}) had the X.500 {}: {:?}",
+ member.user.id,
+ member.display_name(),
x500,
err
),
diff --git a/src/lib.rs b/src/lib.rs
index 31c4954..748eee4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,7 @@
+pub mod commands;
pub mod config;
pub mod handlers;
+pub mod utils;
use anyhow::{Context, Result};
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..c8b3457
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,290 @@
+use serde::{
+ de::{self, DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor},
+ forward_to_deserialize_any,
+ ser::{self, Impossible, Serialize, SerializeSeq, SerializeTupleVariant, Serializer},
+};
+use std::fmt;
+
+/// A wrapper for serializing and deserializing enums in the way Yojson expects. Based off of
+/// [this](https://github.com/serde-rs/serde/issues/979#issuecomment-350340431).
+pub struct EnumAsArray<T>(pub T);
+
+fn expected_enum<T, E>() -> Result<T, E>
+where
+ E: ser::Error,
+{
+ Err(ser::Error::custom(
+ "EnumAsArray expected to operate on an enum type",
+ ))
+}
+
+impl<S> Serializer for EnumAsArray<S>
+where
+ S: Serializer,
+{
+ type Ok = S::Ok;
+ type Error = S::Error;
+
+ type SerializeSeq = Impossible<Self::Ok, Self::Error>;
+ type SerializeTuple = Impossible<Self::Ok, Self::Error>;
+ type SerializeTupleStruct = Impossible<Self::Ok, Self::Error>;
+ type SerializeTupleVariant = EnumAsArray<S::SerializeSeq>;
+ type SerializeMap = Impossible<Self::Ok, Self::Error>;
+ type SerializeStruct = Impossible<Self::Ok, Self::Error>;
+ type SerializeStructVariant = Impossible<Self::Ok, Self::Error>;
+
+ fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_some<T: ?Sized + Serialize>(self, _value: &T) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_unit_variant(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ variant: &'static str,
+ ) -> Result<Self::Ok, Self::Error> {
+ let mut seq = self.0.serialize_seq(Some(1))?;
+ seq.serialize_element(variant)?;
+ seq.end()
+ }
+ fn serialize_newtype_struct<T: ?Sized + Serialize>(
+ self,
+ _name: &'static str,
+ _value: &T,
+ ) -> Result<Self::Ok, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_newtype_variant<T: ?Sized + Serialize>(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ variant: &'static str,
+ value: &T,
+ ) -> Result<Self::Ok, Self::Error> {
+ let mut seq = self.0.serialize_seq(Some(2))?;
+ seq.serialize_element(variant)?;
+ seq.serialize_element(value)?;
+ seq.end()
+ }
+ fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_tuple_struct(
+ self,
+ _name: &'static str,
+ _len: usize,
+ ) -> Result<Self::SerializeTupleStruct, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_tuple_variant(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ variant: &'static str,
+ len: usize,
+ ) -> Result<Self::SerializeTupleVariant, Self::Error> {
+ let mut seq = self.0.serialize_seq(Some(1 + len))?;
+ seq.serialize_element(variant)?;
+ Ok(EnumAsArray(seq))
+ }
+ fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_struct(
+ self,
+ _name: &'static str,
+ _len: usize,
+ ) -> Result<Self::SerializeStruct, Self::Error> {
+ expected_enum()
+ }
+ fn serialize_struct_variant(
+ self,
+ _name: &'static str,
+ _variant_index: u32,
+ _variant: &'static str,
+ _len: usize,
+ ) -> Result<Self::SerializeStructVariant, Self::Error> {
+ Err(ser::Error::custom(
+ "enum_as_array::serialize does not support struct variants",
+ ))
+ }
+}
+
+impl<S> SerializeTupleVariant for EnumAsArray<S>
+where
+ S: SerializeSeq,
+{
+ type Ok = S::Ok;
+ type Error = S::Error;
+
+ fn serialize_field<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Self::Error> {
+ self.0.serialize_element(value)
+ }
+ fn end(self) -> Result<Self::Ok, Self::Error> {
+ self.0.end()
+ }
+}
+
+impl<'de, D> Deserializer<'de> for EnumAsArray<D>
+where
+ D: Deserializer<'de>,
+{
+ type Error = D::Error;
+
+ fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+ where
+ V: Visitor<'de>,
+ {
+ Err(de::Error::custom(
+ "EnumAsArray expected to operate on an enum type",
+ ))
+ }
+
+ fn deserialize_enum<V>(
+ self,
+ _name: &'static str,
+ _variants: &'static [&'static str],
+ visitor: V,
+ ) -> Result<V::Value, Self::Error>
+ where
+ V: Visitor<'de>,
+ {
+ self.0.deserialize_seq(EnumAsArray(visitor))
+ }
+
+ forward_to_deserialize_any! {
+ bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
+ byte_buf option unit unit_struct newtype_struct seq tuple
+ tuple_struct map struct identifier ignored_any
+ }
+}
+
+impl<'de, V> Visitor<'de> for EnumAsArray<V>
+where
+ V: Visitor<'de>,
+{
+ type Value = V::Value;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ self.0.expecting(formatter)
+ }
+
+ fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
+ where
+ A: SeqAccess<'de>,
+ {
+ self.0.visit_enum(EnumAsArray(seq))
+ }
+}
+
+impl<'de, A> EnumAccess<'de> for EnumAsArray<A>
+where
+ A: SeqAccess<'de>,
+{
+ type Error = A::Error;
+ type Variant = Self;
+
+ fn variant_seed<V>(mut self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
+ where
+ V: DeserializeSeed<'de>,
+ {
+ match self.0.next_element_seed(seed)? {
+ Some(first) => Ok((first, self)),
+ None => Err(de::Error::custom("expected at least one element")),
+ }
+ }
+}
+
+impl<'de, A> VariantAccess<'de> for EnumAsArray<A>
+where
+ A: SeqAccess<'de>,
+{
+ type Error = A::Error;
+
+ fn unit_variant(self) -> Result<(), Self::Error> {
+ Ok(())
+ }
+
+ fn newtype_variant_seed<T>(mut self, seed: T) -> Result<T::Value, Self::Error>
+ where
+ T: DeserializeSeed<'de>,
+ {
+ match self.0.next_element_seed(seed)? {
+ Some(newtype) => Ok(newtype),
+ None => Err(de::Error::custom("missing newtype variant value")),
+ }
+ }
+
+ fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
+ where
+ V: Visitor<'de>,
+ {
+ visitor.visit_seq(self.0)
+ }
+
+ fn struct_variant<V>(
+ self,
+ _fields: &'static [&'static str],
+ _visitor: V,
+ ) -> Result<V::Value, Self::Error>
+ where
+ V: Visitor<'de>,
+ {
+ Err(de::Error::custom(
+ "EnumAsArray does not support struct variants",
+ ))
+ }
+}