]> git.lizzy.rs Git - rust.git/blob - crates/rust-analyzer/tests/heavy_tests/support.rs
Merge #4248
[rust.git] / crates / rust-analyzer / tests / heavy_tests / support.rs
1 use std::{
2     cell::{Cell, RefCell},
3     fs,
4     path::{Path, PathBuf},
5     sync::Once,
6     time::Duration,
7 };
8
9 use crossbeam_channel::{after, select, Receiver};
10 use lsp_server::{Connection, Message, Notification, Request};
11 use lsp_types::{
12     notification::{DidOpenTextDocument, Exit},
13     request::Shutdown,
14     DidOpenTextDocumentParams, TextDocumentIdentifier, TextDocumentItem, Url, WorkDoneProgress,
15 };
16 use serde::Serialize;
17 use serde_json::{to_string_pretty, Value};
18 use tempfile::TempDir;
19 use test_utils::{find_mismatch, parse_fixture};
20
21 use req::{ProgressParams, ProgressParamsValue};
22 use rust_analyzer::{
23     config::{ClientCapsConfig, Config},
24     main_loop, req,
25 };
26
27 pub struct Project<'a> {
28     fixture: &'a str,
29     with_sysroot: bool,
30     tmp_dir: Option<TempDir>,
31     roots: Vec<PathBuf>,
32     config: Option<Box<dyn Fn(&mut Config)>>,
33 }
34
35 impl<'a> Project<'a> {
36     pub fn with_fixture(fixture: &str) -> Project {
37         Project { fixture, tmp_dir: None, roots: vec![], with_sysroot: false, config: None }
38     }
39
40     pub fn tmp_dir(mut self, tmp_dir: TempDir) -> Project<'a> {
41         self.tmp_dir = Some(tmp_dir);
42         self
43     }
44
45     pub fn root(mut self, path: &str) -> Project<'a> {
46         self.roots.push(path.into());
47         self
48     }
49
50     pub fn with_sysroot(mut self, sysroot: bool) -> Project<'a> {
51         self.with_sysroot = sysroot;
52         self
53     }
54
55     pub fn with_config(mut self, config: impl Fn(&mut Config) + 'static) -> Project<'a> {
56         self.config = Some(Box::new(config));
57         self
58     }
59
60     pub fn server(self) -> Server {
61         let tmp_dir = self.tmp_dir.unwrap_or_else(|| TempDir::new().unwrap());
62         static INIT: Once = Once::new();
63         INIT.call_once(|| {
64             env_logger::builder().is_test(true).try_init().unwrap();
65             ra_prof::init_from(crate::PROFILE);
66         });
67
68         let mut paths = vec![];
69
70         for entry in parse_fixture(self.fixture) {
71             let path = tmp_dir.path().join(entry.meta);
72             fs::create_dir_all(path.parent().unwrap()).unwrap();
73             fs::write(path.as_path(), entry.text.as_bytes()).unwrap();
74             paths.push((path, entry.text));
75         }
76
77         let roots = self.roots.into_iter().map(|root| tmp_dir.path().join(root)).collect();
78
79         let mut config = Config {
80             client_caps: ClientCapsConfig {
81                 location_link: true,
82                 code_action_literals: true,
83                 ..Default::default()
84             },
85             with_sysroot: self.with_sysroot,
86             ..Config::default()
87         };
88
89         if let Some(f) = &self.config {
90             f(&mut config)
91         }
92
93         Server::new(tmp_dir, config, roots, paths)
94     }
95 }
96
97 pub fn project(fixture: &str) -> Server {
98     Project::with_fixture(fixture).server()
99 }
100
101 pub struct Server {
102     req_id: Cell<u64>,
103     messages: RefCell<Vec<Message>>,
104     _thread: jod_thread::JoinHandle<()>,
105     client: Connection,
106     /// XXX: remove the tempdir last
107     dir: TempDir,
108 }
109
110 impl Server {
111     fn new(
112         dir: TempDir,
113         config: Config,
114         roots: Vec<PathBuf>,
115         files: Vec<(PathBuf, String)>,
116     ) -> Server {
117         let path = dir.path().to_path_buf();
118
119         let roots = if roots.is_empty() { vec![path] } else { roots };
120         let (connection, client) = Connection::memory();
121
122         let _thread = jod_thread::Builder::new()
123             .name("test server".to_string())
124             .spawn(move || main_loop(roots, config, connection).unwrap())
125             .expect("failed to spawn a thread");
126
127         let res =
128             Server { req_id: Cell::new(1), dir, messages: Default::default(), client, _thread };
129
130         for (path, text) in files {
131             res.notification::<DidOpenTextDocument>(DidOpenTextDocumentParams {
132                 text_document: TextDocumentItem {
133                     uri: Url::from_file_path(path).unwrap(),
134                     language_id: "rust".to_string(),
135                     version: 0,
136                     text,
137                 },
138             })
139         }
140         res
141     }
142
143     pub fn doc_id(&self, rel_path: &str) -> TextDocumentIdentifier {
144         let path = self.dir.path().join(rel_path);
145         TextDocumentIdentifier { uri: Url::from_file_path(path).unwrap() }
146     }
147
148     pub fn notification<N>(&self, params: N::Params)
149     where
150         N: lsp_types::notification::Notification,
151         N::Params: Serialize,
152     {
153         let r = Notification::new(N::METHOD.to_string(), params);
154         self.send_notification(r)
155     }
156
157     pub fn request<R>(&self, params: R::Params, expected_resp: Value)
158     where
159         R: lsp_types::request::Request,
160         R::Params: Serialize,
161     {
162         let actual = self.send_request::<R>(params);
163         if let Some((expected_part, actual_part)) = find_mismatch(&expected_resp, &actual) {
164             panic!(
165                 "JSON mismatch\nExpected:\n{}\nWas:\n{}\nExpected part:\n{}\nActual part:\n{}\n",
166                 to_string_pretty(&expected_resp).unwrap(),
167                 to_string_pretty(&actual).unwrap(),
168                 to_string_pretty(expected_part).unwrap(),
169                 to_string_pretty(actual_part).unwrap(),
170             );
171         }
172     }
173
174     pub fn send_request<R>(&self, params: R::Params) -> Value
175     where
176         R: lsp_types::request::Request,
177         R::Params: Serialize,
178     {
179         let id = self.req_id.get();
180         self.req_id.set(id + 1);
181
182         let r = Request::new(id.into(), R::METHOD.to_string(), params);
183         self.send_request_(r)
184     }
185     fn send_request_(&self, r: Request) -> Value {
186         let id = r.id.clone();
187         self.client.sender.send(r.into()).unwrap();
188         while let Some(msg) = self.recv() {
189             match msg {
190                 Message::Request(req) if req.method == "window/workDoneProgress/create" => (),
191                 Message::Request(req) => panic!("unexpected request: {:?}", req),
192                 Message::Notification(_) => (),
193                 Message::Response(res) => {
194                     assert_eq!(res.id, id);
195                     if let Some(err) = res.error {
196                         panic!("error response: {:#?}", err);
197                     }
198                     return res.result.unwrap();
199                 }
200             }
201         }
202         panic!("no response");
203     }
204     pub fn wait_until_workspace_is_loaded(&self) {
205         self.wait_for_message_cond(1, &|msg: &Message| match msg {
206             Message::Notification(n) if n.method == "$/progress" => {
207                 match n.clone().extract::<ProgressParams>("$/progress").unwrap() {
208                     ProgressParams {
209                         token: req::ProgressToken::String(ref token),
210                         value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(_)),
211                     } if token == "rustAnalyzer/startup" => true,
212                     _ => false,
213                 }
214             }
215             _ => false,
216         })
217     }
218     fn wait_for_message_cond(&self, n: usize, cond: &dyn Fn(&Message) -> bool) {
219         let mut total = 0;
220         for msg in self.messages.borrow().iter() {
221             if cond(msg) {
222                 total += 1
223             }
224         }
225         while total < n {
226             let msg = self.recv().expect("no response");
227             if cond(&msg) {
228                 total += 1;
229             }
230         }
231     }
232     fn recv(&self) -> Option<Message> {
233         recv_timeout(&self.client.receiver).map(|msg| {
234             self.messages.borrow_mut().push(msg.clone());
235             msg
236         })
237     }
238     fn send_notification(&self, not: Notification) {
239         self.client.sender.send(Message::Notification(not)).unwrap();
240     }
241
242     pub fn path(&self) -> &Path {
243         self.dir.path()
244     }
245 }
246
247 impl Drop for Server {
248     fn drop(&mut self) {
249         self.request::<Shutdown>((), Value::Null);
250         self.notification::<Exit>(());
251     }
252 }
253
254 fn recv_timeout(receiver: &Receiver<Message>) -> Option<Message> {
255     let timeout = Duration::from_secs(120);
256     select! {
257         recv(receiver) -> msg => msg.ok(),
258         recv(after(timeout)) -> _ => panic!("timed out"),
259     }
260 }