diff options
author | Nathan Ringo <nathan@remexre.com> | 2024-01-18 10:58:36 -0600 |
---|---|---|
committer | Nathan Ringo <nathan@remexre.com> | 2024-01-18 10:58:36 -0600 |
commit | 00d0bfced902e97eeae5257c14134d4bc7efc710 (patch) | |
tree | ee026f328614e03aec3ed373d9f2e6c8e255f834 | |
parent | 7017762a4a38266aa88976be141f7bd663647edc (diff) |
Commands to interact with discocaml, associated IPC.
-rw-r--r-- | Cargo.lock | 11 | ||||
-rw-r--r-- | Cargo.toml | 3 | ||||
-rw-r--r-- | discocaml/default.nix | 6 | ||||
-rw-r--r-- | discocaml/dune | 6 | ||||
-rw-r--r-- | discocaml/main.ml | 27 | ||||
-rw-r--r-- | flake.nix | 13 | ||||
-rwxr-xr-x | sandboxed-discocaml.sh | 21 | ||||
-rw-r--r-- | src/bin/lambo.rs | 4 | ||||
-rw-r--r-- | src/commands/discocaml.rs | 177 | ||||
-rw-r--r-- | src/commands/mod.rs | 47 | ||||
-rw-r--r-- | src/config.rs | 3 | ||||
-rw-r--r-- | src/handlers/commands.rs | 29 | ||||
-rw-r--r-- | src/handlers/mod.rs | 2 | ||||
-rw-r--r-- | src/handlers/x500_mapper.rs | 11 | ||||
-rw-r--r-- | src/lib.rs | 2 | ||||
-rw-r--r-- | src/utils.rs | 290 |
16 files changed, 632 insertions, 20 deletions
@@ -1014,6 +1014,7 @@ dependencies = [ "futures", "log", "serde", + "serde_json", "serenity", "simple_logger", "sqlx", @@ -1709,6 +1710,15 @@ dependencies = [ ] [[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] name = "signature" version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2176,6 +2186,7 @@ dependencies = [ "mio", "num_cpus", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.48.0", @@ -10,9 +10,10 @@ csv = "1.3.0" futures = "0.3.30" log = "0.4.20" serde = { version = "1.0.195", features = ["derive"] } +serde_json = "1.0.111" serenity = "0.12.0" simple_logger = "4.3.3" sqlx = { version = "0.7.3", features = ["runtime-tokio", "sqlite", "time"] } time = { version = "0.3.31", features = ["local-offset"] } -tokio = { version = "1.35.1", features = ["rt-multi-thread"] } +tokio = { version = "1.35.1", features = ["process", "rt-multi-thread"] } toml = "0.8.8" diff --git a/discocaml/default.nix b/discocaml/default.nix index 0045cdf..392c146 100644 --- a/discocaml/default.nix +++ b/discocaml/default.nix @@ -1,10 +1,10 @@ -{ buildDunePackage, ocaml-compiler-libs, ppx_deriving, ppx_import, ppxlib }: +{ buildDunePackage, yojson, ocaml-compiler-libs, ppx_deriving +, ppx_deriving_yojson, ppx_import, ppxlib }: buildDunePackage { pname = "discocaml"; version = "0.1.0"; minimalOcamlVersion = "5.1"; src = ./.; - nativeBuildInputs = [ ]; - buildInputs = [ ppx_deriving ppx_import ppxlib ]; + buildInputs = [ ppx_deriving ppx_deriving_yojson ppx_import ppxlib yojson ]; } diff --git a/discocaml/dune b/discocaml/dune index 02fcc09..695d78f 100644 --- a/discocaml/dune +++ b/discocaml/dune @@ -1,7 +1,9 @@ (executable - (libraries compiler-libs.common) + (flags + (:standard -cclib -static -cclib -lm)) + (libraries compiler-libs.common yojson) (name main) (package discocaml) (preprocess - (staged_pps ppx_import ppx_deriving.show)) + (staged_pps ppx_import ppx_deriving.show ppx_deriving_yojson)) (public_name discocaml)) diff --git a/discocaml/main.ml b/discocaml/main.ml index 2cc97b5..d9895f1 100644 --- a/discocaml/main.ml +++ b/discocaml/main.ml @@ -1,3 +1,14 @@ +type command = [ `Roundtrip ] + +let command_of_yojson = function + | `String "Roundtrip" -> Ok `Roundtrip + | _ -> Error "invalid command" + +type request = { expr : string; command : command } +[@@deriving of_yojson { exn = true }] + +type response = [ `Error of string ] [@@deriving to_yojson { exn = true }] + (* type position = [%import: Lexing.position] [@@deriving show] @@ -14,14 +25,22 @@ type structure_item_desc = [%import: Parsetree.structure_item_desc] type structure_item = [%import: Parsetree.structure_item] [@@deriving show] type structure = [%import: Parsetree.structure] [@@deriving show] type toplevel_phrase = [%import: Parsetree.toplevel_phrase] [@@deriving show] -*) let parse ~path (src : string) = let buf = Lexing.from_string src in buf.lex_start_p <- { buf.lex_start_p with pos_fname = path }; buf.lex_curr_p <- { buf.lex_curr_p with pos_fname = path }; - Parse.use_file buf + Parse.expression buf + +let () = + parse ~path:"main.ml" " print_endline ((\"Hello, world!\") )" + |> Format.fprintf Format.std_formatter "\n%a\n" Pprintast.expression +*) + +let handle_request { expr; command } : response = + match command with `Roundtrip -> `Error ("TODO: " ^ expr) let () = - parse ~path:"main.ml" "let () = print_endline ((\"Hello, world!\") )" - |> List.iter (Pprintast.toplevel_phrase Format.std_formatter) + Yojson.Safe.from_channel stdin + |> request_of_yojson_exn |> handle_request |> response_to_yojson + |> Yojson.Safe.to_channel stdout @@ -5,10 +5,10 @@ }; outputs = { self, fenix, flake-utils, nixpkgs }: - flake-utils.lib.eachDefaultSystem (system: + flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: let pkgs = nixpkgs.legacyPackages.${system}; - ocamlPkgs = pkgs.ocaml-ng.ocamlPackages_5_1; + ocamlPkgs = pkgs.pkgsMusl.ocaml-ng.ocamlPackages_5_1; toolchain = fenix.packages.${system}.stable.withComponents [ "cargo" "rustc" @@ -21,8 +21,13 @@ in rec { devShells.default = pkgs.mkShell { inputsFrom = builtins.attrValues packages; - nativeBuildInputs = - [ pkgs.cargo-watch ocamlPkgs.ocaml-lsp pkgs.sqlite pkgs.sqlx-cli ]; + nativeBuildInputs = [ + pkgs.bubblewrap + pkgs.cargo-watch + ocamlPkgs.ocaml-lsp + pkgs.sqlite + pkgs.sqlx-cli + ]; }; packages = { diff --git a/sandboxed-discocaml.sh b/sandboxed-discocaml.sh new file mode 100755 index 0000000..04bd744 --- /dev/null +++ b/sandboxed-discocaml.sh @@ -0,0 +1,21 @@ +#!/bin/sh +set -eu + +tmp=$(mktemp) +cleanup() +{ + if [ -e "$tmp" ]; then + rm "$tmp" + fi +} +trap cleanup EXIT + +rm "$tmp" +nix build -o "$tmp" .#discocaml +bindir="$(realpath "$tmp")/bin" + +exec \ +timeout 10 \ +env -i \ +"$(which bwrap)" --unshare-all --ro-bind "$bindir" "/" \ +"/discocaml" "$@" diff --git a/src/bin/lambo.rs b/src/bin/lambo.rs index f446308..6b5f349 100644 --- a/src/bin/lambo.rs +++ b/src/bin/lambo.rs @@ -58,6 +58,10 @@ async fn main() -> Result<()> { // Create the handlers. let handler = MultiHandler(vec![ + Box::new(Commands { + config: config.commands, + db: db.clone(), + }), Box::new(PresenceSetter), Box::new(X500Mapper { config: Arc::new(config.x500_mapper), diff --git a/src/commands/discocaml.rs b/src/commands/discocaml.rs new file mode 100644 index 0000000..93ddfaa --- /dev/null +++ b/src/commands/discocaml.rs @@ -0,0 +1,177 @@ +use anyhow::{anyhow, bail, Context as _, Error, Result}; +use serde::{Deserialize, Serialize}; +use serde_json::de::Deserializer; +use serenity::{ + all::{ + CommandDataOptionValue, CommandInteraction, CommandOptionType, CommandType, Member, RoleId, + }, + builder::{ + CreateCommand, CreateCommandOption, CreateInteractionResponse, + CreateInteractionResponseMessage, + }, + client::Context, + model::Permissions, +}; +use sqlx::SqlitePool; +use std::process::Stdio; +use tokio::{io::AsyncWriteExt, process::Command}; + +use crate::utils::EnumAsArray; + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct DiscocamlConfig { + pub command: Vec<String>, + pub role: RoleId, +} + +#[derive(Debug, Serialize)] +struct DiscocamlRequest { + expr: String, + command: DiscocamlCommand, +} + +#[derive(Debug, Serialize)] +enum DiscocamlCommand { + Roundtrip, +} + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +enum DiscocamlResponse { + Error(String), +} + +pub fn command() -> CreateCommand { + CreateCommand::new("discocaml") + .kind(CommandType::ChatInput) + .default_member_permissions(Permissions::empty()) + .dm_permission(true) + .description("Sends this expression to the disco!") + .add_option( + CreateCommandOption::new( + CommandOptionType::String, + "expr", + "The expression to operate on.", + ) + .required(true), + ) +} + +async fn check_permissions(config: &DiscocamlConfig, member: &Option<Box<Member>>) -> Result<()> { + if let Some(member) = member { + if !member.roles.contains(&config.role) { + bail!("This command can only be used by <@&{}>.", config.role) + } + Ok(()) + } else { + bail!("This command cannot be used in DMs.") + } +} + +async fn respond_with_error(ctx: &Context, command: &CommandInteraction, err: &Error) { + let msg = CreateInteractionResponseMessage::new().content(format!(":no_entry_sign: {}", err)); + if let Err(err) = command + .create_response(ctx, CreateInteractionResponse::Message(msg)) + .await + { + log::error!( + "failed to respond to command that failed permissions check: {:?}", + err + ) + } +} + +async fn run_discocaml( + config: &DiscocamlConfig, + req: &DiscocamlRequest, +) -> Result<DiscocamlResponse> { + let mut child = Command::new(&config.command[0]) + .args(&config.command) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .kill_on_drop(true) + .spawn() + .context("failed to start discocaml")?; + + let mut stdin = child.stdin.take().unwrap(); + let req = serde_json::to_string(req).context("failed to serialize request to discocaml")?; + stdin + .write_all(req.as_bytes()) + .await + .context("failed to write request to discocaml")?; + stdin + .shutdown() + .await + .context("failed to close pipe to discocaml")?; + drop(stdin); + drop(req); + + let output = child + .wait_with_output() + .await + .context("failed to wait for discocaml to complete")?; + if !output.status.success() { + bail!("discocaml exited with non-zero status {:?}", output.status) + } + + let mut de = Deserializer::from_slice(&output.stdout); + let out = DiscocamlResponse::deserialize(EnumAsArray(&mut de)) + .context("failed to parse response from discocaml")?; + de.end() + .context("failed to parse response from discocaml")?; + Ok(out) +} + +pub async fn handle_command( + ctx: &Context, + config: &DiscocamlConfig, + _db: &SqlitePool, + command: &CommandInteraction, +) -> Result<()> { + // Check that the required role was present. + if let Err(err) = check_permissions(config, &command.member).await { + respond_with_error(ctx, command, &err).await; + return Err(err.context("permissions check failed")); + } + + // Parse the expression out. + let mut expr = None; + for option in &command.data.options { + match (&option.name as &str, &option.value) { + ("expr", CommandDataOptionValue::String(value)) => expr = Some(value), + _ => { + let err = anyhow!("unknown option {:?}", option); + respond_with_error(ctx, command, &err).await; + return Err(err); + } + } + } + let expr = if let Some(expr) = expr { + expr + } else { + let err = anyhow!("missing option {:?}", "expr"); + respond_with_error(ctx, command, &err).await; + return Err(err); + }; + + // Round-trip the expression through discocaml. + let req = DiscocamlRequest { + expr: expr.to_string(), + command: DiscocamlCommand::Roundtrip, + }; + let res = match run_discocaml(config, &req).await { + Ok(res) => res, + Err(err) => { + let err = err.context("failed to run discocaml"); + respond_with_error(ctx, command, &err).await; + return Err(err); + } + }; + + let msg = CreateInteractionResponseMessage::new().content(format!("`{:?}`", res)); + command + .create_response(&ctx, CreateInteractionResponse::Message(msg)) + .await + .context("failed to respond") +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs new file mode 100644 index 0000000..083b200 --- /dev/null +++ b/src/commands/mod.rs @@ -0,0 +1,47 @@ +mod discocaml; + +use crate::commands::discocaml::DiscocamlConfig; +use anyhow::{Context as _, Result}; +use serde::Deserialize; +use serenity::{ + all::{Command, Interaction}, + client::Context, +}; +use sqlx::SqlitePool; + +#[derive(Debug, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct CommandsConfig { + pub discocaml: DiscocamlConfig, +} + +pub async fn set_commands(ctx: &Context) -> Result<()> { + Command::set_global_commands(&ctx.http, vec![discocaml::command()]) + .await + .context("failed to set commands")?; + Ok(()) +} + +pub async fn handle_interaction( + ctx: &Context, + config: &CommandsConfig, + db: &SqlitePool, + interaction: &Interaction, +) -> Result<()> { + match interaction { + Interaction::Command(cmd) => match &cmd.data.name as &str { + "discocaml" => discocaml::handle_command(ctx, &config.discocaml, db, cmd) + .await + .context("failed to handle discocaml command"), + _ => { + log::warn!("unexpected interaction: {:?}", interaction); + Ok(()) + } + }, + + _ => { + log::warn!("unexpected interaction: {:?}", interaction); + Ok(()) + } + } +} diff --git a/src/config.rs b/src/config.rs index 47695e3..11742d9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use crate::handlers::X500MapperConfig; +use crate::{commands::CommandsConfig, handlers::X500MapperConfig}; use anyhow::{Context, Result}; use serde::Deserialize; use std::{fs, path::Path}; @@ -9,6 +9,7 @@ pub struct Config { pub database_url: String, pub discord_token: String, + pub commands: CommandsConfig, pub x500_mapper: X500MapperConfig, } diff --git a/src/handlers/commands.rs b/src/handlers/commands.rs new file mode 100644 index 0000000..d07bdcb --- /dev/null +++ b/src/handlers/commands.rs @@ -0,0 +1,29 @@ +use crate::commands::{handle_interaction, set_commands, CommandsConfig}; +use serenity::{ + all::{Interaction, Ready}, + async_trait, + client::{Context, EventHandler}, +}; +use sqlx::SqlitePool; + +/// A handler that sets up the commands. +pub struct Commands { + pub config: CommandsConfig, + pub db: SqlitePool, +} + +#[async_trait] +impl EventHandler for Commands { + async fn ready(&self, ctx: Context, _data_about_bot: Ready) { + if let Err(err) = set_commands(&ctx).await { + log::error!("failed to set commands: {:?}", err) + } + } + + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + if let Err(err) = handle_interaction(&ctx, &self.config, &self.db, &interaction).await { + log::error!("failed to handle interaction: {:?}", err); + log::error!("failed interaction was: {:?}", interaction); + } + } +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 46bd57b..4c09fec 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -18,10 +18,12 @@ use serenity::{ }; use std::collections::HashMap; +mod commands; mod presence_setter; mod x500_mapper; pub use self::{ + commands::Commands, presence_setter::PresenceSetter, x500_mapper::{X500Mapper, X500MapperConfig}, }; diff --git a/src/handlers/x500_mapper.rs b/src/handlers/x500_mapper.rs index 1ec35f2..acd498f 100644 --- a/src/handlers/x500_mapper.rs +++ b/src/handlers/x500_mapper.rs @@ -90,9 +90,9 @@ impl X500Mapper { async fn record_x500(&self, ctx: &Context, member: Member, x500: String) { let x500 = &x500; - let uid = member.user.id; + let member = &member; let future = async move { - let uid = i64::from(uid); + let uid = i64::from(member.user.id); sqlx::query!( "INSERT OR IGNORE INTO all_seen_uids_to_x500s (uid, x500) VALUES (?, ?)", uid, @@ -112,8 +112,8 @@ impl X500Mapper { .context("failed to insert into uids_to_x500s_clean")?; if let Some(role) = self.config.students_role { - log::info!("adding the role {} to {}", role, member.display_name()); if !member.roles.contains(&role) { + log::info!("adding the role {} to {:?}", role, member.display_name()); member .add_role(&ctx.http, role) .await @@ -127,8 +127,9 @@ impl X500Mapper { match future.await { Ok(()) => (), Err(err) => log::error!( - "failed to record that the user with UID {} had the X.500 {}: {:?}", - uid, + "failed to record that the user with UID {} ({:?}) had the X.500 {}: {:?}", + member.user.id, + member.display_name(), x500, err ), @@ -1,5 +1,7 @@ +pub mod commands; pub mod config; pub mod handlers; +pub mod utils; use anyhow::{Context, Result}; diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..c8b3457 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,290 @@ +use serde::{ + de::{self, DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, + forward_to_deserialize_any, + ser::{self, Impossible, Serialize, SerializeSeq, SerializeTupleVariant, Serializer}, +}; +use std::fmt; + +/// A wrapper for serializing and deserializing enums in the way Yojson expects. Based off of +/// [this](https://github.com/serde-rs/serde/issues/979#issuecomment-350340431). +pub struct EnumAsArray<T>(pub T); + +fn expected_enum<T, E>() -> Result<T, E> +where + E: ser::Error, +{ + Err(ser::Error::custom( + "EnumAsArray expected to operate on an enum type", + )) +} + +impl<S> Serializer for EnumAsArray<S> +where + S: Serializer, +{ + type Ok = S::Ok; + type Error = S::Error; + + type SerializeSeq = Impossible<Self::Ok, Self::Error>; + type SerializeTuple = Impossible<Self::Ok, Self::Error>; + type SerializeTupleStruct = Impossible<Self::Ok, Self::Error>; + type SerializeTupleVariant = EnumAsArray<S::SerializeSeq>; + type SerializeMap = Impossible<Self::Ok, Self::Error>; + type SerializeStruct = Impossible<Self::Ok, Self::Error>; + type SerializeStructVariant = Impossible<Self::Ok, Self::Error>; + + fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_i16(self, _v: i16) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_i32(self, _v: i32) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_i64(self, _v: i64) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_u8(self, _v: u8) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_u16(self, _v: u16) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_u64(self, _v: u64) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_f32(self, _v: f32) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_f64(self, _v: f64) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_str(self, _v: &str) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_bytes(self, _v: &[u8]) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_none(self) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_some<T: ?Sized + Serialize>(self, _value: &T) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_unit(self) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result<Self::Ok, Self::Error> { + let mut seq = self.0.serialize_seq(Some(1))?; + seq.serialize_element(variant)?; + seq.end() + } + fn serialize_newtype_struct<T: ?Sized + Serialize>( + self, + _name: &'static str, + _value: &T, + ) -> Result<Self::Ok, Self::Error> { + expected_enum() + } + fn serialize_newtype_variant<T: ?Sized + Serialize>( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result<Self::Ok, Self::Error> { + let mut seq = self.0.serialize_seq(Some(2))?; + seq.serialize_element(variant)?; + seq.serialize_element(value)?; + seq.end() + } + fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> { + expected_enum() + } + fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> { + expected_enum() + } + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result<Self::SerializeTupleStruct, Self::Error> { + expected_enum() + } + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleVariant, Self::Error> { + let mut seq = self.0.serialize_seq(Some(1 + len))?; + seq.serialize_element(variant)?; + Ok(EnumAsArray(seq)) + } + fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> { + expected_enum() + } + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result<Self::SerializeStruct, Self::Error> { + expected_enum() + } + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result<Self::SerializeStructVariant, Self::Error> { + Err(ser::Error::custom( + "enum_as_array::serialize does not support struct variants", + )) + } +} + +impl<S> SerializeTupleVariant for EnumAsArray<S> +where + S: SerializeSeq, +{ + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_field<T: ?Sized + Serialize>(&mut self, value: &T) -> Result<(), Self::Error> { + self.0.serialize_element(value) + } + fn end(self) -> Result<Self::Ok, Self::Error> { + self.0.end() + } +} + +impl<'de, D> Deserializer<'de> for EnumAsArray<D> +where + D: Deserializer<'de>, +{ + type Error = D::Error; + + fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(de::Error::custom( + "EnumAsArray expected to operate on an enum type", + )) + } + + fn deserialize_enum<V>( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + self.0.deserialize_seq(EnumAsArray(visitor)) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct identifier ignored_any + } +} + +impl<'de, V> Visitor<'de> for EnumAsArray<V> +where + V: Visitor<'de>, +{ + type Value = V::Value; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + self.0.expecting(formatter) + } + + fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error> + where + A: SeqAccess<'de>, + { + self.0.visit_enum(EnumAsArray(seq)) + } +} + +impl<'de, A> EnumAccess<'de> for EnumAsArray<A> +where + A: SeqAccess<'de>, +{ + type Error = A::Error; + type Variant = Self; + + fn variant_seed<V>(mut self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: DeserializeSeed<'de>, + { + match self.0.next_element_seed(seed)? { + Some(first) => Ok((first, self)), + None => Err(de::Error::custom("expected at least one element")), + } + } +} + +impl<'de, A> VariantAccess<'de> for EnumAsArray<A> +where + A: SeqAccess<'de>, +{ + type Error = A::Error; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed<T>(mut self, seed: T) -> Result<T::Value, Self::Error> + where + T: DeserializeSeed<'de>, + { + match self.0.next_element_seed(seed)? { + Some(newtype) => Ok(newtype), + None => Err(de::Error::custom("missing newtype variant value")), + } + } + + fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + visitor.visit_seq(self.0) + } + + fn struct_variant<V>( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result<V::Value, Self::Error> + where + V: Visitor<'de>, + { + Err(de::Error::custom( + "EnumAsArray does not support struct variants", + )) + } +} |