]> git.lizzy.rs Git - rust.git/blob - crates/ra_proc_macro/src/process.rs
proc_macro: add ability to log to stderr and view output in vscode
[rust.git] / crates / ra_proc_macro / src / process.rs
1 //! Handle process life-time and message passing for proc-macro client
2
3 use crossbeam_channel::{bounded, Receiver, Sender};
4 use ra_tt::Subtree;
5
6 use crate::msg::{ErrorCode, Message, Request, Response, ResponseError};
7 use crate::rpc::{ExpansionResult, ExpansionTask, ListMacrosResult, ListMacrosTask, ProcMacroKind};
8
9 use io::{BufRead, BufReader};
10 use std::{
11     convert::{TryFrom, TryInto},
12     ffi::{OsStr, OsString},
13     io::{self, Write},
14     path::{Path, PathBuf},
15     process::{Child, Command, Stdio},
16     sync::{Arc, Weak},
17 };
18
19 #[derive(Debug, Default)]
20 pub(crate) struct ProcMacroProcessSrv {
21     inner: Option<Weak<Sender<Task>>>,
22 }
23
24 #[derive(Debug)]
25 pub(crate) struct ProcMacroProcessThread {
26     // XXX: drop order is significant
27     sender: Arc<Sender<Task>>,
28     handle: jod_thread::JoinHandle<()>,
29 }
30
31 impl ProcMacroProcessSrv {
32     pub fn run(
33         process_path: PathBuf,
34         args: impl IntoIterator<Item = impl AsRef<OsStr>>,
35     ) -> io::Result<(ProcMacroProcessThread, ProcMacroProcessSrv)> {
36         let process = Process::run(process_path, args)?;
37
38         let (task_tx, task_rx) = bounded(0);
39         let handle = jod_thread::spawn(move || {
40             client_loop(task_rx, process);
41         });
42
43         let task_tx = Arc::new(task_tx);
44         let srv = ProcMacroProcessSrv { inner: Some(Arc::downgrade(&task_tx)) };
45         let thread = ProcMacroProcessThread { handle, sender: task_tx };
46
47         Ok((thread, srv))
48     }
49
50     pub fn find_proc_macros(
51         &self,
52         dylib_path: &Path,
53     ) -> Result<Vec<(String, ProcMacroKind)>, ra_tt::ExpansionError> {
54         let task = ListMacrosTask { lib: dylib_path.to_path_buf() };
55
56         let result: ListMacrosResult = self.send_task(Request::ListMacro(task))?;
57         Ok(result.macros)
58     }
59
60     pub fn custom_derive(
61         &self,
62         dylib_path: &Path,
63         subtree: &Subtree,
64         derive_name: &str,
65     ) -> Result<Subtree, ra_tt::ExpansionError> {
66         let task = ExpansionTask {
67             macro_body: subtree.clone(),
68             macro_name: derive_name.to_string(),
69             attributes: None,
70             lib: dylib_path.to_path_buf(),
71         };
72
73         let result: ExpansionResult = self.send_task(Request::ExpansionMacro(task))?;
74         Ok(result.expansion)
75     }
76
77     pub fn send_task<R>(&self, req: Request) -> Result<R, ra_tt::ExpansionError>
78     where
79         R: TryFrom<Response, Error = &'static str>,
80     {
81         let sender = match &self.inner {
82             None => return Err(ra_tt::ExpansionError::Unknown("No sender is found.".to_string())),
83             Some(it) => it,
84         };
85
86         let (result_tx, result_rx) = bounded(0);
87         let sender = match sender.upgrade() {
88             None => {
89                 return Err(ra_tt::ExpansionError::Unknown("Proc macro process is closed.".into()))
90             }
91             Some(it) => it,
92         };
93         sender.send(Task { req: req.into(), result_tx }).unwrap();
94         let res = result_rx
95             .recv()
96             .map_err(|_| ra_tt::ExpansionError::Unknown("Proc macro thread is closed.".into()))?;
97
98         match res {
99             Some(Response::Error(err)) => {
100                 return Err(ra_tt::ExpansionError::ExpansionError(err.message));
101             }
102             Some(res) => Ok(res.try_into().map_err(|err| {
103                 ra_tt::ExpansionError::Unknown(format!(
104                     "Fail to get response, reason : {:#?} ",
105                     err
106                 ))
107             })?),
108             None => Err(ra_tt::ExpansionError::Unknown("Empty result".into())),
109         }
110     }
111 }
112
113 fn client_loop(task_rx: Receiver<Task>, mut process: Process) {
114     let (mut stdin, mut stdout) = match process.stdio() {
115         None => return,
116         Some(it) => it,
117     };
118
119     for task in task_rx {
120         let Task { req, result_tx } = task;
121
122         match send_request(&mut stdin, &mut stdout, req) {
123             Ok(res) => result_tx.send(res).unwrap(),
124             Err(_err) => {
125                 let res = Response::Error(ResponseError {
126                     code: ErrorCode::ServerErrorEnd,
127                     message: "Server closed".into(),
128                 });
129                 result_tx.send(res.into()).unwrap();
130                 // Restart the process
131                 if process.restart().is_err() {
132                     break;
133                 }
134                 let stdio = match process.stdio() {
135                     None => break,
136                     Some(it) => it,
137                 };
138                 stdin = stdio.0;
139                 stdout = stdio.1;
140             }
141         }
142     }
143 }
144
145 struct Task {
146     req: Request,
147     result_tx: Sender<Option<Response>>,
148 }
149
150 struct Process {
151     path: PathBuf,
152     args: Vec<OsString>,
153     child: Child,
154 }
155
156 impl Drop for Process {
157     fn drop(&mut self) {
158         let _ = self.child.kill();
159     }
160 }
161
162 impl Process {
163     fn run(
164         path: PathBuf,
165         args: impl IntoIterator<Item = impl AsRef<OsStr>>,
166     ) -> io::Result<Process> {
167         let args = args.into_iter().map(|s| s.as_ref().into()).collect();
168         let child = mk_child(&path, &args)?;
169         Ok(Process { path, args, child })
170     }
171
172     fn restart(&mut self) -> io::Result<()> {
173         let _ = self.child.kill();
174         self.child = mk_child(&self.path, &self.args)?;
175         Ok(())
176     }
177
178     fn stdio(&mut self) -> Option<(impl Write, impl BufRead)> {
179         let stdin = self.child.stdin.take()?;
180         let stdout = self.child.stdout.take()?;
181         let read = BufReader::new(stdout);
182
183         Some((stdin, read))
184     }
185 }
186
187 fn mk_child(path: &Path, args: impl IntoIterator<Item = impl AsRef<OsStr>>) -> io::Result<Child> {
188     Command::new(&path)
189         .args(args)
190         .stdin(Stdio::piped())
191         .stdout(Stdio::piped())
192         .stderr(Stdio::inherit())
193         .spawn()
194 }
195
196 fn send_request(
197     mut writer: &mut impl Write,
198     mut reader: &mut impl BufRead,
199     req: Request,
200 ) -> io::Result<Option<Response>> {
201     req.write(&mut writer)?;
202     Ok(Response::read(&mut reader)?)
203 }