]> git.lizzy.rs Git - rust.git/blob - crates/proc_macro_api/src/msg.rs
Merge #7068
[rust.git] / crates / proc_macro_api / src / msg.rs
1 //! Defines messages for cross-process message passing based on `ndjson` wire protocol
2
3 use std::{
4     convert::TryFrom,
5     io::{self, BufRead, Write},
6 };
7
8 use serde::{de::DeserializeOwned, Deserialize, Serialize};
9
10 use crate::{
11     rpc::{ListMacrosResult, ListMacrosTask},
12     ExpansionResult, ExpansionTask,
13 };
14
15 #[derive(Debug, Serialize, Deserialize, Clone)]
16 pub enum Request {
17     ListMacro(ListMacrosTask),
18     ExpansionMacro(ExpansionTask),
19 }
20
21 #[derive(Debug, Serialize, Deserialize, Clone)]
22 pub enum Response {
23     Error(ResponseError),
24     ListMacro(ListMacrosResult),
25     ExpansionMacro(ExpansionResult),
26 }
27
28 macro_rules! impl_try_from_response {
29     ($ty:ty, $tag:ident) => {
30         impl TryFrom<Response> for $ty {
31             type Error = &'static str;
32             fn try_from(value: Response) -> Result<Self, Self::Error> {
33                 match value {
34                     Response::$tag(res) => Ok(res),
35                     _ => Err(concat!("Failed to convert response to ", stringify!($tag))),
36                 }
37             }
38         }
39     };
40 }
41
42 impl_try_from_response!(ListMacrosResult, ListMacro);
43 impl_try_from_response!(ExpansionResult, ExpansionMacro);
44
45 #[derive(Debug, Serialize, Deserialize, Clone)]
46 pub struct ResponseError {
47     pub code: ErrorCode,
48     pub message: String,
49 }
50
51 #[derive(Debug, Serialize, Deserialize, Clone)]
52 pub enum ErrorCode {
53     ServerErrorEnd,
54     ExpansionError,
55 }
56
57 pub trait Message: Serialize + DeserializeOwned {
58     fn read(inp: &mut impl BufRead) -> io::Result<Option<Self>> {
59         Ok(match read_json(inp)? {
60             None => None,
61             Some(text) => {
62                 let mut deserializer = serde_json::Deserializer::from_str(&text);
63                 // Note that some proc-macro generate very deep syntax tree
64                 // We have to disable the current limit of serde here
65                 deserializer.disable_recursion_limit();
66                 Some(Self::deserialize(&mut deserializer)?)
67             }
68         })
69     }
70     fn write(self, out: &mut impl Write) -> io::Result<()> {
71         let text = serde_json::to_string(&self)?;
72         write_json(out, &text)
73     }
74 }
75
76 impl Message for Request {}
77 impl Message for Response {}
78
79 fn read_json(inp: &mut impl BufRead) -> io::Result<Option<String>> {
80     let mut buf = String::new();
81     inp.read_line(&mut buf)?;
82     buf.pop(); // Remove traling '\n'
83     Ok(match buf.len() {
84         0 => None,
85         _ => Some(buf),
86     })
87 }
88
89 fn write_json(out: &mut impl Write, msg: &str) -> io::Result<()> {
90     log::debug!("> {}", msg);
91     out.write_all(msg.as_bytes())?;
92     out.write_all(b"\n")?;
93     out.flush()?;
94     Ok(())
95 }