lib.rs - source (original) (raw)

pyo3_asyncio_macros/

lib.rs

1#![forbid(unsafe_code, future_incompatible, rust_2018_idioms)]
2#![deny(missing_debug_implementations, nonstandard_style)]
3#![recursion_limit = "512"]
4
5mod tokio;
6
7use proc_macro::TokenStream;
8use quote::{quote, quote_spanned};
9use syn::spanned::Spanned;
10
11/// Enables an async main function that uses the async-std runtime.
12///
13/// # Examples
14///
15/// ```ignore
16/// #[pyo3_asyncio::async_std::main]
17/// async fn main() -> PyResult<()> {
18///     Ok(())
19/// }
20/// ```
21#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
22#[proc_macro_attribute]
23pub fn async_std_main(_attr: TokenStream, item: TokenStream) -> TokenStream {
24    let input = syn::parse_macro_input!(item as syn::ItemFn);
25
26    let ret = &input.sig.output;
27    let inputs = &input.sig.inputs;
28    let name = &input.sig.ident;
29    let body = &input.block;
30    let attrs = &input.attrs;
31    let vis = &input.vis;
32
33    if name != "main" {
34        return TokenStream::from(quote_spanned! { name.span() =>
35            compile_error!("only the main function can be tagged with #[async_std::main]"),
36        });
37    }
38
39    if input.sig.asyncness.is_none() {
40        return TokenStream::from(quote_spanned! { input.span() =>
41            compile_error!("the async keyword is missing from the function declaration"),
42        });
43    }
44
45    let result = quote! {
46        #vis fn main() {
47            #(#attrs)*
48            async fn main(#inputs) #ret {
49                #body
50            }
51
52            pyo3::prepare_freethreaded_python();
53
54            pyo3::Python::with_gil(|py| {
55                pyo3_asyncio::async_std::run(py, main())
56                    .map_err(|e| {
57                        e.print_and_set_sys_last_vars(py);
58                    })
59                    .unwrap();
60            });
61        }
62    };
63
64    result.into()
65}
66
67/// Enables an async main function that uses the tokio runtime.
68///
69/// # Arguments
70/// * `flavor` - selects the type of tokio runtime ["multi_thread", "current_thread"]
71/// * `worker_threads` - number of worker threads, defaults to the number of CPUs on the system
72///
73/// # Examples
74///
75/// Default configuration:
76/// ```ignore
77/// #[pyo3_asyncio::tokio::main]
78/// async fn main() -> PyResult<()> {
79///     Ok(())
80/// }
81/// ```
82///
83/// Current-thread scheduler:
84/// ```ignore
85/// #[pyo3_asyncio::tokio::main(flavor = "current_thread")]
86/// async fn main() -> PyResult<()> {
87///     Ok(())
88/// }
89/// ```
90///
91/// Multi-thread scheduler with custom worker thread count:
92/// ```ignore
93/// #[pyo3_asyncio::tokio::main(flavor = "multi_thread", worker_threads = 10)]
94/// async fn main() -> PyResult<()> {
95///     Ok(())
96/// }
97/// ```
98#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
99#[proc_macro_attribute]
100pub fn tokio_main(args: TokenStream, item: TokenStream) -> TokenStream {
101    tokio::main(args, item, true)
102}
103
104/// Registers an `async-std` test with the `pyo3-asyncio` test harness.
105///
106/// This attribute is meant to mirror the `#[test]` attribute and allow you to mark a function for
107/// testing within an integration test. Like the `#[async_std::test]` attribute, it will accept
108/// `async` test functions, but it will also accept blocking functions as well.
109///
110/// # Examples
111/// ```ignore
112/// use std::{time::Duration, thread};
113///
114/// use pyo3::prelude::*;
115///
116/// // async test function
117/// #[pyo3_asyncio::async_std::test]
118/// async fn test_async_sleep() -> PyResult<()> {
119///     async_std::task::sleep(Duration::from_secs(1)).await;
120///     Ok(())
121/// }
122///
123/// // blocking test function
124/// #[pyo3_asyncio::async_std::test]
125/// fn test_blocking_sleep() -> PyResult<()> {
126///     thread::sleep(Duration::from_secs(1));
127///     Ok(())
128/// }
129///
130/// // blocking test functions can optionally accept an event_loop parameter
131/// #[pyo3_asyncio::async_std::test]
132/// fn test_blocking_sleep_with_event_loop(event_loop: PyObject) -> PyResult<()> {
133///     thread::sleep(Duration::from_secs(1));
134///     Ok(())
135/// }
136/// ```
137#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
138#[proc_macro_attribute]
139pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
140    let input = syn::parse_macro_input!(item as syn::ItemFn);
141
142    let sig = &input.sig;
143    let name = &input.sig.ident;
144    let body = &input.block;
145    let vis = &input.vis;
146
147    let fn_impl = if input.sig.asyncness.is_none() {
148        // Optionally pass an event_loop parameter to blocking tasks
149        let task = if sig.inputs.is_empty() {
150            quote! {
151                Box::pin(pyo3_asyncio::async_std::re_exports::spawn_blocking(move || {
152                    #name()
153                }))
154            }
155        } else {
156            quote! {
157                let event_loop = Python::with_gil(|py| {
158                    pyo3_asyncio::async_std::get_current_loop(py).unwrap().into()
159                });
160                Box::pin(pyo3_asyncio::async_std::re_exports::spawn_blocking(move || {
161                    #name(event_loop)
162                }))
163            }
164        };
165
166        quote! {
167            #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
168                #sig {
169                    #body
170                }
171
172                #task
173            }
174        }
175    } else {
176        quote! {
177            #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
178                #sig {
179                    #body
180                }
181
182                Box::pin(#name())
183            }
184        }
185    };
186
187    let result = quote! {
188        #fn_impl
189
190        pyo3_asyncio::inventory::submit! {
191            pyo3_asyncio::testing::Test {
192                name: concat!(std::module_path!(), "::", stringify!(#name)),
193                test_fn: &#name
194            }
195        }
196    };
197
198    result.into()
199}
200
201/// Registers a `tokio` test with the `pyo3-asyncio` test harness.
202///
203/// This attribute is meant to mirror the `#[test]` attribute and allow you to mark a function for
204/// testing within an integration test. Like the `#[tokio::test]` attribute, it will accept `async`
205/// test functions, but it will also accept blocking functions as well.
206///
207/// # Examples
208/// ```ignore
209/// use std::{time::Duration, thread};
210///
211/// use pyo3::prelude::*;
212///
213/// // async test function
214/// #[pyo3_asyncio::tokio::test]
215/// async fn test_async_sleep() -> PyResult<()> {
216///     tokio::time::sleep(Duration::from_secs(1)).await;
217///     Ok(())
218/// }
219///
220/// // blocking test function
221/// #[pyo3_asyncio::tokio::test]
222/// fn test_blocking_sleep() -> PyResult<()> {
223///     thread::sleep(Duration::from_secs(1));
224///     Ok(())
225/// }
226///
227/// // blocking test functions can optionally accept an event_loop parameter
228/// #[pyo3_asyncio::tokio::test]
229/// fn test_blocking_sleep_with_event_loop(event_loop: PyObject) -> PyResult<()> {
230///     thread::sleep(Duration::from_secs(1));
231///     Ok(())
232/// }
233/// ```
234#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
235#[proc_macro_attribute]
236pub fn tokio_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
237    let input = syn::parse_macro_input!(item as syn::ItemFn);
238
239    let sig = &input.sig;
240    let name = &input.sig.ident;
241    let body = &input.block;
242    let vis = &input.vis;
243
244    let fn_impl = if input.sig.asyncness.is_none() {
245        // Optionally pass an event_loop parameter to blocking tasks
246        let task = if sig.inputs.is_empty() {
247            quote! {
248                Box::pin(async move {
249                    match pyo3_asyncio::tokio::get_runtime().spawn_blocking(move || #name()).await {
250                        Ok(result) => result,
251                        Err(e) => {
252                            assert!(e.is_panic());
253                            let panic = e.into_panic();
254                            let panic_message = if let Some(s) = panic.downcast_ref::<&str>() {
255                                s.to_string()
256                            } else if let Some(s) = panic.downcast_ref::<String>() {
257                                s.clone()
258                            } else {
259                                "unknown error".into()
260                            };
261                            Err(pyo3_asyncio::err::RustPanic::new_err(format!("rust future panicked: {}", panic_message)))
262                        }
263                    }
264                })
265            }
266        } else {
267            quote! {
268                let event_loop = Python::with_gil(|py| {
269                    pyo3_asyncio::tokio::get_current_loop(py).unwrap().into()
270                });
271                Box::pin(async move {
272                    match pyo3_asyncio::tokio::get_runtime().spawn_blocking(move || #name(event_loop)).await {
273                        Ok(result) => result,
274                        Err(e) => {
275                            assert!(e.is_panic());
276                            let panic = e.into_panic();
277                            let panic_message = if let Some(s) = panic.downcast_ref::<&str>() {
278                                s.to_string()
279                            } else if let Some(s) = panic.downcast_ref::<String>() {
280                                s.clone()
281                            } else {
282                                "unknown error".into()
283                            };
284                            Err(pyo3_asyncio::err::RustPanic::new_err(format!("rust future panicked: {}", panic_message)))
285                        }
286                    }
287                })
288            }
289        };
290
291        quote! {
292            #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
293                #sig {
294                    #body
295                }
296
297                #task
298            }
299        }
300    } else {
301        quote! {
302            #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
303                #sig {
304                    #body
305                }
306
307                Box::pin(#name())
308            }
309        }
310    };
311
312    let result = quote! {
313        #fn_impl
314
315        pyo3_asyncio::inventory::submit! {
316            pyo3_asyncio::testing::Test {
317                name: concat!(std::module_path!(), "::", stringify!(#name)),
318                test_fn: &#name
319            }
320        }
321    };
322
323    result.into()
324}