]> git.lizzy.rs Git - rust.git/blob - crates/rust-analyzer/tests/heavy_tests/support.rs
Be more explicit about absolute paths at various places
[rust.git] / crates / rust-analyzer / tests / heavy_tests / support.rs
1 use std::{
2     cell::{Cell, RefCell},
3     fs,
4     path::{Path, PathBuf},
5     sync::Once,
6     time::Duration,
7 };
8
9 use crossbeam_channel::{after, select, Receiver};
10 use lsp_server::{Connection, Message, Notification, Request};
11 use lsp_types::{
12     notification::Exit, request::Shutdown, TextDocumentIdentifier, Url, WorkDoneProgress,
13 };
14 use lsp_types::{ProgressParams, ProgressParamsValue};
15 use serde::Serialize;
16 use serde_json::{to_string_pretty, Value};
17 use tempfile::TempDir;
18 use test_utils::{find_mismatch, Fixture};
19
20 use ra_db::AbsPathBuf;
21 use ra_project_model::ProjectManifest;
22 use rust_analyzer::{
23     config::{ClientCapsConfig, Config, FilesConfig, FilesWatcher, LinkedProject},
24     main_loop,
25 };
26
27 pub struct Project<'a> {
28     fixture: &'a str,
29     with_sysroot: bool,
30     tmp_dir: Option<TempDir>,
31     roots: Vec<PathBuf>,
32     config: Option<Box<dyn Fn(&mut Config)>>,
33 }
34
35 impl<'a> Project<'a> {
36     pub fn with_fixture(fixture: &str) -> Project {
37         Project { fixture, tmp_dir: None, roots: vec![], with_sysroot: false, config: None }
38     }
39
40     pub fn tmp_dir(mut self, tmp_dir: TempDir) -> Project<'a> {
41         self.tmp_dir = Some(tmp_dir);
42         self
43     }
44
45     pub(crate) fn root(mut self, path: &str) -> Project<'a> {
46         self.roots.push(path.into());
47         self
48     }
49
50     pub fn with_sysroot(mut self, sysroot: bool) -> Project<'a> {
51         self.with_sysroot = sysroot;
52         self
53     }
54
55     pub fn with_config(mut self, config: impl Fn(&mut Config) + 'static) -> Project<'a> {
56         self.config = Some(Box::new(config));
57         self
58     }
59
60     pub fn server(self) -> Server {
61         let tmp_dir = self.tmp_dir.unwrap_or_else(|| TempDir::new().unwrap());
62         static INIT: Once = Once::new();
63         INIT.call_once(|| {
64             env_logger::builder().is_test(true).try_init().unwrap();
65             ra_prof::init_from(crate::PROFILE);
66         });
67
68         for entry in Fixture::parse(self.fixture) {
69             let path = tmp_dir.path().join(&entry.path['/'.len_utf8()..]);
70             fs::create_dir_all(path.parent().unwrap()).unwrap();
71             fs::write(path.as_path(), entry.text.as_bytes()).unwrap();
72         }
73
74         let tmp_dir_path = AbsPathBuf::assert(tmp_dir.path().to_path_buf());
75         let mut roots =
76             self.roots.into_iter().map(|root| tmp_dir_path.join(root)).collect::<Vec<_>>();
77         if roots.is_empty() {
78             roots.push(tmp_dir_path.clone());
79         }
80         let linked_projects = roots
81             .into_iter()
82             .map(|it| ProjectManifest::discover_single(&it).unwrap())
83             .map(LinkedProject::from)
84             .collect::<Vec<_>>();
85
86         let mut config = Config {
87             client_caps: ClientCapsConfig {
88                 location_link: true,
89                 code_action_literals: true,
90                 work_done_progress: true,
91                 ..Default::default()
92             },
93             with_sysroot: self.with_sysroot,
94             linked_projects,
95             files: FilesConfig { watcher: FilesWatcher::Client, exclude: Vec::new() },
96             ..Config::new(tmp_dir_path)
97         };
98         if let Some(f) = &self.config {
99             f(&mut config)
100         }
101
102         Server::new(tmp_dir, config)
103     }
104 }
105
106 pub fn project(fixture: &str) -> Server {
107     Project::with_fixture(fixture).server()
108 }
109
110 pub struct Server {
111     req_id: Cell<u64>,
112     messages: RefCell<Vec<Message>>,
113     _thread: jod_thread::JoinHandle<()>,
114     client: Connection,
115     /// XXX: remove the tempdir last
116     dir: TempDir,
117 }
118
119 impl Server {
120     fn new(dir: TempDir, config: Config) -> Server {
121         let (connection, client) = Connection::memory();
122
123         let _thread = jod_thread::Builder::new()
124             .name("test server".to_string())
125             .spawn(move || main_loop(config, connection).unwrap())
126             .expect("failed to spawn a thread");
127
128         Server { req_id: Cell::new(1), dir, messages: Default::default(), client, _thread }
129     }
130
131     pub fn doc_id(&self, rel_path: &str) -> TextDocumentIdentifier {
132         let path = self.dir.path().join(rel_path);
133         TextDocumentIdentifier { uri: Url::from_file_path(path).unwrap() }
134     }
135
136     pub fn notification<N>(&self, params: N::Params)
137     where
138         N: lsp_types::notification::Notification,
139         N::Params: Serialize,
140     {
141         let r = Notification::new(N::METHOD.to_string(), params);
142         self.send_notification(r)
143     }
144
145     pub fn request<R>(&self, params: R::Params, expected_resp: Value)
146     where
147         R: lsp_types::request::Request,
148         R::Params: Serialize,
149     {
150         let actual = self.send_request::<R>(params);
151         if let Some((expected_part, actual_part)) = find_mismatch(&expected_resp, &actual) {
152             panic!(
153                 "JSON mismatch\nExpected:\n{}\nWas:\n{}\nExpected part:\n{}\nActual part:\n{}\n",
154                 to_string_pretty(&expected_resp).unwrap(),
155                 to_string_pretty(&actual).unwrap(),
156                 to_string_pretty(expected_part).unwrap(),
157                 to_string_pretty(actual_part).unwrap(),
158             );
159         }
160     }
161
162     pub fn send_request<R>(&self, params: R::Params) -> Value
163     where
164         R: lsp_types::request::Request,
165         R::Params: Serialize,
166     {
167         let id = self.req_id.get();
168         self.req_id.set(id + 1);
169
170         let r = Request::new(id.into(), R::METHOD.to_string(), params);
171         self.send_request_(r)
172     }
173     fn send_request_(&self, r: Request) -> Value {
174         let id = r.id.clone();
175         self.client.sender.send(r.into()).unwrap();
176         while let Some(msg) = self.recv() {
177             match msg {
178                 Message::Request(req) => {
179                     if req.method != "window/workDoneProgress/create"
180                         && !(req.method == "client/registerCapability"
181                             && req.params.to_string().contains("workspace/didChangeWatchedFiles"))
182                     {
183                         panic!("unexpected request: {:?}", req)
184                     }
185                 }
186                 Message::Notification(_) => (),
187                 Message::Response(res) => {
188                     assert_eq!(res.id, id);
189                     if let Some(err) = res.error {
190                         panic!("error response: {:#?}", err);
191                     }
192                     return res.result.unwrap();
193                 }
194             }
195         }
196         panic!("no response");
197     }
198     pub fn wait_until_workspace_is_loaded(&self) {
199         self.wait_for_message_cond(1, &|msg: &Message| match msg {
200             Message::Notification(n) if n.method == "$/progress" => {
201                 match n.clone().extract::<ProgressParams>("$/progress").unwrap() {
202                     ProgressParams {
203                         token: lsp_types::ProgressToken::String(ref token),
204                         value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(_)),
205                     } if token == "rustAnalyzer/roots scanned" => true,
206                     _ => false,
207                 }
208             }
209             _ => false,
210         })
211     }
212     fn wait_for_message_cond(&self, n: usize, cond: &dyn Fn(&Message) -> bool) {
213         let mut total = 0;
214         for msg in self.messages.borrow().iter() {
215             if cond(msg) {
216                 total += 1
217             }
218         }
219         while total < n {
220             let msg = self.recv().expect("no response");
221             if cond(&msg) {
222                 total += 1;
223             }
224         }
225     }
226     fn recv(&self) -> Option<Message> {
227         recv_timeout(&self.client.receiver).map(|msg| {
228             self.messages.borrow_mut().push(msg.clone());
229             msg
230         })
231     }
232     fn send_notification(&self, not: Notification) {
233         self.client.sender.send(Message::Notification(not)).unwrap();
234     }
235
236     pub fn path(&self) -> &Path {
237         self.dir.path()
238     }
239 }
240
241 impl Drop for Server {
242     fn drop(&mut self) {
243         self.request::<Shutdown>((), Value::Null);
244         self.notification::<Exit>(());
245     }
246 }
247
248 fn recv_timeout(receiver: &Receiver<Message>) -> Option<Message> {
249     let timeout = Duration::from_secs(120);
250     select! {
251         recv(receiver) -> msg => msg.ok(),
252         recv(after(timeout)) -> _ => panic!("timed out"),
253     }
254 }