]> git.lizzy.rs Git - rust.git/blobdiff - crates/proc_macro_api/src/process.rs
Remove proc macro management thread
[rust.git] / crates / proc_macro_api / src / process.rs
index 51ffcaa786d82b3ba0fb6a54eba18d663ac4ab5b..592c1282c0bce575400ac8236cb141f693cff8c4 100644 (file)
@@ -3,52 +3,48 @@
 use std::{
     convert::{TryFrom, TryInto},
     ffi::{OsStr, OsString},
+    fmt,
     io::{self, BufRead, BufReader, Write},
     path::{Path, PathBuf},
-    process::{Child, Command, Stdio},
-    sync::{Arc, Weak},
+    process::{Child, ChildStdin, ChildStdout, Command, Stdio},
+    sync::Mutex,
 };
 
-use crossbeam_channel::{bounded, Receiver, Sender};
-use tt::Subtree;
+use stdx::JodChild;
 
 use crate::{
     msg::{ErrorCode, Message, Request, Response, ResponseError},
-    rpc::{ExpansionResult, ExpansionTask, ListMacrosResult, ListMacrosTask, ProcMacroKind},
+    rpc::{ListMacrosResult, ListMacrosTask, ProcMacroKind},
 };
 
-#[derive(Debug, Default)]
 pub(crate) struct ProcMacroProcessSrv {
-    inner: Option<Weak<Sender<Task>>>,
+    process: Mutex<Process>,
+    stdio: Mutex<(ChildStdin, BufReader<ChildStdout>)>,
 }
 
-#[derive(Debug)]
-pub(crate) struct ProcMacroProcessThread {
-    // XXX: drop order is significant
-    sender: Arc<Sender<Task>>,
-    handle: jod_thread::JoinHandle<()>,
+impl fmt::Debug for ProcMacroProcessSrv {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ProcMacroProcessSrv").field("process", &self.process).finish()
+    }
 }
 
 impl ProcMacroProcessSrv {
-    pub fn run(
+    pub(crate) fn run(
         process_path: PathBuf,
         args: impl IntoIterator<Item = impl AsRef<OsStr>>,
-    ) -> io::Result<(ProcMacroProcessThread, ProcMacroProcessSrv)> {
-        let process = Process::run(process_path, args)?;
-
-        let (task_tx, task_rx) = bounded(0);
-        let handle = jod_thread::spawn(move || {
-            client_loop(task_rx, process);
-        });
+    ) -> io::Result<ProcMacroProcessSrv> {
+        let mut process = Process::run(process_path, args)?;
+        let (stdin, stdout) = process.stdio().expect("couldn't access child stdio");
 
-        let task_tx = Arc::new(task_tx);
-        let srv = ProcMacroProcessSrv { inner: Some(Arc::downgrade(&task_tx)) };
-        let thread = ProcMacroProcessThread { handle, sender: task_tx };
+        let srv = ProcMacroProcessSrv {
+            process: Mutex::new(process),
+            stdio: Mutex::new((stdin, stdout)),
+        };
 
-        Ok((thread, srv))
+        Ok(srv)
     }
 
-    pub fn find_proc_macros(
+    pub(crate) fn find_proc_macros(
         &self,
         dylib_path: &Path,
     ) -> Result<Vec<(String, ProcMacroKind)>, tt::ExpansionError> {
@@ -58,48 +54,34 @@ pub fn find_proc_macros(
         Ok(result.macros)
     }
 
-    pub fn custom_derive(
-        &self,
-        dylib_path: &Path,
-        subtree: &Subtree,
-        derive_name: &str,
-    ) -> Result<Subtree, tt::ExpansionError> {
-        let task = ExpansionTask {
-            macro_body: subtree.clone(),
-            macro_name: derive_name.to_string(),
-            attributes: None,
-            lib: dylib_path.to_path_buf(),
-        };
-
-        let result: ExpansionResult = self.send_task(Request::ExpansionMacro(task))?;
-        Ok(result.expansion)
-    }
-
-    pub fn send_task<R>(&self, req: Request) -> Result<R, tt::ExpansionError>
+    pub(crate) fn send_task<R>(&self, req: Request) -> Result<R, tt::ExpansionError>
     where
         R: TryFrom<Response, Error = &'static str>,
     {
-        let sender = match &self.inner {
-            None => return Err(tt::ExpansionError::Unknown("No sender is found.".to_string())),
-            Some(it) => it,
-        };
-
-        let (result_tx, result_rx) = bounded(0);
-        let sender = match sender.upgrade() {
-            None => {
-                return Err(tt::ExpansionError::Unknown("Proc macro process is closed.".into()))
+        let mut guard = self.stdio.lock().unwrap_or_else(|e| e.into_inner());
+        let stdio = &mut *guard;
+        let (stdin, stdout) = (&mut stdio.0, &mut stdio.1);
+
+        let mut buf = String::new();
+        let res = match send_request(stdin, stdout, req, &mut buf) {
+            Ok(res) => res,
+            Err(err) => {
+                let mut process = self.process.lock().unwrap_or_else(|e| e.into_inner());
+                log::error!(
+                    "proc macro server crashed, server process state: {:?}, server request error: {:?}",
+                    process.child.try_wait(),
+                    err
+                );
+                let res = Response::Error(ResponseError {
+                    code: ErrorCode::ServerErrorEnd,
+                    message: "proc macro server crashed".into(),
+                });
+                Some(res)
             }
-            Some(it) => it,
         };
-        sender.send(Task { req, result_tx }).unwrap();
-        let res = result_rx
-            .recv()
-            .map_err(|_| tt::ExpansionError::Unknown("Proc macro thread is closed.".into()))?;
 
         match res {
-            Some(Response::Error(err)) => {
-                return Err(tt::ExpansionError::ExpansionError(err.message));
-            }
+            Some(Response::Error(err)) => Err(tt::ExpansionError::ExpansionError(err.message)),
             Some(res) => Ok(res.try_into().map_err(|err| {
                 tt::ExpansionError::Unknown(format!("Fail to get response, reason : {:#?} ", err))
             })?),
@@ -108,53 +90,9 @@ pub fn send_task<R>(&self, req: Request) -> Result<R, tt::ExpansionError>
     }
 }
 
-fn client_loop(task_rx: Receiver<Task>, mut process: Process) {
-    let (mut stdin, mut stdout) = match process.stdio() {
-        None => return,
-        Some(it) => it,
-    };
-
-    for task in task_rx {
-        let Task { req, result_tx } = task;
-
-        match send_request(&mut stdin, &mut stdout, req) {
-            Ok(res) => result_tx.send(res).unwrap(),
-            Err(_err) => {
-                let res = Response::Error(ResponseError {
-                    code: ErrorCode::ServerErrorEnd,
-                    message: "Server closed".into(),
-                });
-                result_tx.send(res.into()).unwrap();
-                // Restart the process
-                if process.restart().is_err() {
-                    break;
-                }
-                let stdio = match process.stdio() {
-                    None => break,
-                    Some(it) => it,
-                };
-                stdin = stdio.0;
-                stdout = stdio.1;
-            }
-        }
-    }
-}
-
-struct Task {
-    req: Request,
-    result_tx: Sender<Option<Response>>,
-}
-
+#[derive(Debug)]
 struct Process {
-    path: PathBuf,
-    args: Vec<OsString>,
-    child: Child,
-}
-
-impl Drop for Process {
-    fn drop(&mut self) {
-        let _ = self.child.kill();
-    }
+    child: JodChild,
 }
 
 impl Process {
@@ -162,18 +100,12 @@ fn run(
         path: PathBuf,
         args: impl IntoIterator<Item = impl AsRef<OsStr>>,
     ) -> io::Result<Process> {
-        let args = args.into_iter().map(|s| s.as_ref().into()).collect();
-        let child = mk_child(&path, &args)?;
-        Ok(Process { path, args, child })
-    }
-
-    fn restart(&mut self) -> io::Result<()> {
-        let _ = self.child.kill();
-        self.child = mk_child(&self.path, &self.args)?;
-        Ok(())
+        let args: Vec<OsString> = args.into_iter().map(|s| s.as_ref().into()).collect();
+        let child = JodChild(mk_child(&path, &args)?);
+        Ok(Process { child })
     }
 
-    fn stdio(&mut self) -> Option<(impl Write, impl BufRead)> {
+    fn stdio(&mut self) -> Option<(ChildStdin, BufReader<ChildStdout>)> {
         let stdin = self.child.stdin.take()?;
         let stdout = self.child.stdout.take()?;
         let read = BufReader::new(stdout);
@@ -195,7 +127,8 @@ fn send_request(
     mut writer: &mut impl Write,
     mut reader: &mut impl BufRead,
     req: Request,
+    buf: &mut String,
 ) -> io::Result<Option<Response>> {
     req.write(&mut writer)?;
-    Ok(Response::read(&mut reader)?)
+    Response::read(&mut reader, buf)
 }