lib.rs - source (original) (raw)
pyo3_asyncio_macros/
lib.rs
1#![forbid(unsafe_code, future_incompatible, rust_2018_idioms)]
2#![deny(missing_debug_implementations, nonstandard_style)]
3#![recursion_limit = "512"]
4
5mod tokio;
6
7use proc_macro::TokenStream;
8use quote::{quote, quote_spanned};
9use syn::spanned::Spanned;
10
11/// Enables an async main function that uses the async-std runtime.
12///
13/// # Examples
14///
15/// ```ignore
16/// #[pyo3_asyncio::async_std::main]
17/// async fn main() -> PyResult<()> {
18/// Ok(())
19/// }
20/// ```
21#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
22#[proc_macro_attribute]
23pub fn async_std_main(_attr: TokenStream, item: TokenStream) -> TokenStream {
24 let input = syn::parse_macro_input!(item as syn::ItemFn);
25
26 let ret = &input.sig.output;
27 let inputs = &input.sig.inputs;
28 let name = &input.sig.ident;
29 let body = &input.block;
30 let attrs = &input.attrs;
31 let vis = &input.vis;
32
33 if name != "main" {
34 return TokenStream::from(quote_spanned! { name.span() =>
35 compile_error!("only the main function can be tagged with #[async_std::main]"),
36 });
37 }
38
39 if input.sig.asyncness.is_none() {
40 return TokenStream::from(quote_spanned! { input.span() =>
41 compile_error!("the async keyword is missing from the function declaration"),
42 });
43 }
44
45 let result = quote! {
46 #vis fn main() {
47 #(#attrs)*
48 async fn main(#inputs) #ret {
49 #body
50 }
51
52 pyo3::prepare_freethreaded_python();
53
54 pyo3::Python::with_gil(|py| {
55 pyo3_asyncio::async_std::run(py, main())
56 .map_err(|e| {
57 e.print_and_set_sys_last_vars(py);
58 })
59 .unwrap();
60 });
61 }
62 };
63
64 result.into()
65}
66
67/// Enables an async main function that uses the tokio runtime.
68///
69/// # Arguments
70/// * `flavor` - selects the type of tokio runtime ["multi_thread", "current_thread"]
71/// * `worker_threads` - number of worker threads, defaults to the number of CPUs on the system
72///
73/// # Examples
74///
75/// Default configuration:
76/// ```ignore
77/// #[pyo3_asyncio::tokio::main]
78/// async fn main() -> PyResult<()> {
79/// Ok(())
80/// }
81/// ```
82///
83/// Current-thread scheduler:
84/// ```ignore
85/// #[pyo3_asyncio::tokio::main(flavor = "current_thread")]
86/// async fn main() -> PyResult<()> {
87/// Ok(())
88/// }
89/// ```
90///
91/// Multi-thread scheduler with custom worker thread count:
92/// ```ignore
93/// #[pyo3_asyncio::tokio::main(flavor = "multi_thread", worker_threads = 10)]
94/// async fn main() -> PyResult<()> {
95/// Ok(())
96/// }
97/// ```
98#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
99#[proc_macro_attribute]
100pub fn tokio_main(args: TokenStream, item: TokenStream) -> TokenStream {
101 tokio::main(args, item, true)
102}
103
104/// Registers an `async-std` test with the `pyo3-asyncio` test harness.
105///
106/// This attribute is meant to mirror the `#[test]` attribute and allow you to mark a function for
107/// testing within an integration test. Like the `#[async_std::test]` attribute, it will accept
108/// `async` test functions, but it will also accept blocking functions as well.
109///
110/// # Examples
111/// ```ignore
112/// use std::{time::Duration, thread};
113///
114/// use pyo3::prelude::*;
115///
116/// // async test function
117/// #[pyo3_asyncio::async_std::test]
118/// async fn test_async_sleep() -> PyResult<()> {
119/// async_std::task::sleep(Duration::from_secs(1)).await;
120/// Ok(())
121/// }
122///
123/// // blocking test function
124/// #[pyo3_asyncio::async_std::test]
125/// fn test_blocking_sleep() -> PyResult<()> {
126/// thread::sleep(Duration::from_secs(1));
127/// Ok(())
128/// }
129///
130/// // blocking test functions can optionally accept an event_loop parameter
131/// #[pyo3_asyncio::async_std::test]
132/// fn test_blocking_sleep_with_event_loop(event_loop: PyObject) -> PyResult<()> {
133/// thread::sleep(Duration::from_secs(1));
134/// Ok(())
135/// }
136/// ```
137#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
138#[proc_macro_attribute]
139pub fn async_std_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
140 let input = syn::parse_macro_input!(item as syn::ItemFn);
141
142 let sig = &input.sig;
143 let name = &input.sig.ident;
144 let body = &input.block;
145 let vis = &input.vis;
146
147 let fn_impl = if input.sig.asyncness.is_none() {
148 // Optionally pass an event_loop parameter to blocking tasks
149 let task = if sig.inputs.is_empty() {
150 quote! {
151 Box::pin(pyo3_asyncio::async_std::re_exports::spawn_blocking(move || {
152 #name()
153 }))
154 }
155 } else {
156 quote! {
157 let event_loop = Python::with_gil(|py| {
158 pyo3_asyncio::async_std::get_current_loop(py).unwrap().into()
159 });
160 Box::pin(pyo3_asyncio::async_std::re_exports::spawn_blocking(move || {
161 #name(event_loop)
162 }))
163 }
164 };
165
166 quote! {
167 #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
168 #sig {
169 #body
170 }
171
172 #task
173 }
174 }
175 } else {
176 quote! {
177 #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
178 #sig {
179 #body
180 }
181
182 Box::pin(#name())
183 }
184 }
185 };
186
187 let result = quote! {
188 #fn_impl
189
190 pyo3_asyncio::inventory::submit! {
191 pyo3_asyncio::testing::Test {
192 name: concat!(std::module_path!(), "::", stringify!(#name)),
193 test_fn: &#name
194 }
195 }
196 };
197
198 result.into()
199}
200
201/// Registers a `tokio` test with the `pyo3-asyncio` test harness.
202///
203/// This attribute is meant to mirror the `#[test]` attribute and allow you to mark a function for
204/// testing within an integration test. Like the `#[tokio::test]` attribute, it will accept `async`
205/// test functions, but it will also accept blocking functions as well.
206///
207/// # Examples
208/// ```ignore
209/// use std::{time::Duration, thread};
210///
211/// use pyo3::prelude::*;
212///
213/// // async test function
214/// #[pyo3_asyncio::tokio::test]
215/// async fn test_async_sleep() -> PyResult<()> {
216/// tokio::time::sleep(Duration::from_secs(1)).await;
217/// Ok(())
218/// }
219///
220/// // blocking test function
221/// #[pyo3_asyncio::tokio::test]
222/// fn test_blocking_sleep() -> PyResult<()> {
223/// thread::sleep(Duration::from_secs(1));
224/// Ok(())
225/// }
226///
227/// // blocking test functions can optionally accept an event_loop parameter
228/// #[pyo3_asyncio::tokio::test]
229/// fn test_blocking_sleep_with_event_loop(event_loop: PyObject) -> PyResult<()> {
230/// thread::sleep(Duration::from_secs(1));
231/// Ok(())
232/// }
233/// ```
234#[cfg(not(test))] // NOTE: exporting main breaks tests, we should file an issue.
235#[proc_macro_attribute]
236pub fn tokio_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
237 let input = syn::parse_macro_input!(item as syn::ItemFn);
238
239 let sig = &input.sig;
240 let name = &input.sig.ident;
241 let body = &input.block;
242 let vis = &input.vis;
243
244 let fn_impl = if input.sig.asyncness.is_none() {
245 // Optionally pass an event_loop parameter to blocking tasks
246 let task = if sig.inputs.is_empty() {
247 quote! {
248 Box::pin(async move {
249 match pyo3_asyncio::tokio::get_runtime().spawn_blocking(move || #name()).await {
250 Ok(result) => result,
251 Err(e) => {
252 assert!(e.is_panic());
253 let panic = e.into_panic();
254 let panic_message = if let Some(s) = panic.downcast_ref::<&str>() {
255 s.to_string()
256 } else if let Some(s) = panic.downcast_ref::<String>() {
257 s.clone()
258 } else {
259 "unknown error".into()
260 };
261 Err(pyo3_asyncio::err::RustPanic::new_err(format!("rust future panicked: {}", panic_message)))
262 }
263 }
264 })
265 }
266 } else {
267 quote! {
268 let event_loop = Python::with_gil(|py| {
269 pyo3_asyncio::tokio::get_current_loop(py).unwrap().into()
270 });
271 Box::pin(async move {
272 match pyo3_asyncio::tokio::get_runtime().spawn_blocking(move || #name(event_loop)).await {
273 Ok(result) => result,
274 Err(e) => {
275 assert!(e.is_panic());
276 let panic = e.into_panic();
277 let panic_message = if let Some(s) = panic.downcast_ref::<&str>() {
278 s.to_string()
279 } else if let Some(s) = panic.downcast_ref::<String>() {
280 s.clone()
281 } else {
282 "unknown error".into()
283 };
284 Err(pyo3_asyncio::err::RustPanic::new_err(format!("rust future panicked: {}", panic_message)))
285 }
286 }
287 })
288 }
289 };
290
291 quote! {
292 #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
293 #sig {
294 #body
295 }
296
297 #task
298 }
299 }
300 } else {
301 quote! {
302 #vis fn #name() -> std::pin::Pin<Box<dyn std::future::Future<Output = pyo3::PyResult<()>> + Send>> {
303 #sig {
304 #body
305 }
306
307 Box::pin(#name())
308 }
309 }
310 };
311
312 let result = quote! {
313 #fn_impl
314
315 pyo3_asyncio::inventory::submit! {
316 pyo3_asyncio::testing::Test {
317 name: concat!(std::module_path!(), "::", stringify!(#name)),
318 test_fn: &#name
319 }
320 }
321 };
322
323 result.into()
324}