Skip to content

Commit 1a68e3d

Browse files
committed
switch to multitask
Signed-off-by: Marc-Antoine Perennou <[email protected]>
1 parent fcc220f commit 1a68e3d

File tree

5 files changed

+121
-19
lines changed

5 files changed

+121
-19
lines changed

Cargo.toml

+9-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ default = [
3030
"futures-lite",
3131
"kv-log-macro",
3232
"log",
33+
"multitask",
3334
"num_cpus",
3435
"pin-project-lite",
35-
"smol",
3636
]
3737
docs = ["attributes", "unstable", "default"]
3838
unstable = [
@@ -57,7 +57,7 @@ alloc = [
5757
"futures-core/alloc",
5858
"pin-project-lite",
5959
]
60-
tokio02 = ["smol/tokio02"]
60+
tokio02 = ["tokio"]
6161

6262
[dependencies]
6363
async-attributes = { version = "1.1.1", optional = true }
@@ -83,7 +83,7 @@ surf = { version = "1.0.3", optional = true }
8383
async-io = { version = "0.1.5", optional = true }
8484
blocking = { version = "0.5.0", optional = true }
8585
futures-lite = { version = "0.1.8", optional = true }
86-
smol = { version = "0.1.17", optional = true }
86+
multitask = { version = "0.2.0", optional = true }
8787

8888
[target.'cfg(target_arch = "wasm32")'.dependencies]
8989
futures-timer = { version = "3.0.2", optional = true, features = ["wasm-bindgen"] }
@@ -93,6 +93,12 @@ futures-channel = { version = "0.3.4", optional = true }
9393
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
9494
wasm-bindgen-test = "0.3.10"
9595

96+
[dependencies.tokio]
97+
version = "0.2"
98+
default-features = false
99+
features = ["rt-threaded"]
100+
optional = true
101+
96102
[dev-dependencies]
97103
femme = "1.3.0"
98104
rand = "0.7.3"

src/task/builder.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::task::{Context, Poll};
77
use pin_project_lite::pin_project;
88

99
use crate::io;
10-
use crate::task::{JoinHandle, Task, TaskLocalsWrapper};
10+
use crate::task::{self, JoinHandle, Task, TaskLocalsWrapper};
1111

1212
/// Task builder that configures the settings of a new task.
1313
#[derive(Debug, Default)]
@@ -61,9 +61,9 @@ impl Builder {
6161
});
6262

6363
let task = wrapped.tag.task().clone();
64-
let smol_task = smol::Task::spawn(wrapped).into();
64+
let handle = task::executor::spawn(wrapped);
6565

66-
Ok(JoinHandle::new(smol_task, task))
66+
Ok(JoinHandle::new(handle, task))
6767
}
6868

6969
/// Spawns a task locally with the configured settings.
@@ -81,9 +81,9 @@ impl Builder {
8181
});
8282

8383
let task = wrapped.tag.task().clone();
84-
let smol_task = smol::Task::local(wrapped).into();
84+
let handle = task::executor::local(wrapped);
8585

86-
Ok(JoinHandle::new(smol_task, task))
86+
Ok(JoinHandle::new(handle, task))
8787
}
8888

8989
/// Spawns a task locally with the configured settings.
@@ -166,8 +166,8 @@ impl Builder {
166166
unsafe {
167167
TaskLocalsWrapper::set_current(&wrapped.tag, || {
168168
let res = if should_run {
169-
// The first call should use run.
170-
smol::run(wrapped)
169+
// The first call should run the executor
170+
task::executor::run(wrapped)
171171
} else {
172172
futures_lite::future::block_on(wrapped)
173173
};

src/task/executor.rs

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
use std::cell::RefCell;
2+
use std::future::Future;
3+
use std::task::{Context, Poll};
4+
5+
static GLOBAL_EXECUTOR: once_cell::sync::Lazy<multitask::Executor> = once_cell::sync::Lazy::new(multitask::Executor::new);
6+
7+
struct Executor {
8+
local_executor: multitask::LocalExecutor,
9+
parker: async_io::parking::Parker,
10+
}
11+
12+
thread_local! {
13+
static EXECUTOR: RefCell<Executor> = RefCell::new({
14+
let (parker, unparker) = async_io::parking::pair();
15+
let local_executor = multitask::LocalExecutor::new(move || unparker.unpark());
16+
Executor { local_executor, parker }
17+
});
18+
}
19+
20+
pub(crate) fn spawn<F, T>(future: F) -> multitask::Task<T>
21+
where
22+
F: Future<Output = T> + Send + 'static,
23+
T: Send + 'static,
24+
{
25+
GLOBAL_EXECUTOR.spawn(future)
26+
}
27+
28+
#[cfg(feature = "unstable")]
29+
pub(crate) fn local<F, T>(future: F) -> multitask::Task<T>
30+
where
31+
F: Future<Output = T> + 'static,
32+
T: 'static,
33+
{
34+
EXECUTOR.with(|executor| executor.borrow().local_executor.spawn(future))
35+
}
36+
37+
pub(crate) fn run<F, T>(future: F) -> T
38+
where
39+
F: Future<Output = T>,
40+
{
41+
enter(|| EXECUTOR.with(|executor| {
42+
let executor = executor.borrow();
43+
let unparker = executor.parker.unparker();
44+
let global_ticker = GLOBAL_EXECUTOR.ticker(move || unparker.unpark());
45+
let unparker = executor.parker.unparker();
46+
let waker = async_task::waker_fn(move || unparker.unpark());
47+
let cx = &mut Context::from_waker(&waker);
48+
pin_utils::pin_mut!(future);
49+
loop {
50+
if let Poll::Ready(res) = future.as_mut().poll(cx) {
51+
return res;
52+
}
53+
if let Ok(false) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| executor.local_executor.tick() || global_ticker.tick())) {
54+
executor.parker.park();
55+
}
56+
}
57+
}))
58+
}
59+
60+
/// Enters the tokio context if the `tokio` feature is enabled.
61+
fn enter<T>(f: impl FnOnce() -> T) -> T {
62+
#[cfg(not(feature = "tokio02"))]
63+
return f();
64+
65+
#[cfg(feature = "tokio02")]
66+
{
67+
use std::cell::Cell;
68+
use tokio::runtime::Runtime;
69+
70+
thread_local! {
71+
/// The level of nested `enter` calls we are in, to ensure that the outermost always
72+
/// has a runtime spawned.
73+
static NESTING: Cell<usize> = Cell::new(0);
74+
}
75+
76+
/// The global tokio runtime.
77+
static RT: once_cell::sync::Lazy<Runtime> = once_cell::sync::Lazy::new(|| Runtime::new().expect("cannot initialize tokio"));
78+
79+
NESTING.with(|nesting| {
80+
let res = if nesting.get() == 0 {
81+
nesting.replace(1);
82+
RT.enter(f)
83+
} else {
84+
nesting.replace(nesting.get() + 1);
85+
f()
86+
};
87+
nesting.replace(nesting.get() - 1);
88+
res
89+
})
90+
}
91+
}

src/task/join_handle.rs

+12-9
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub struct JoinHandle<T> {
1818
}
1919

2020
#[cfg(not(target_os = "unknown"))]
21-
type InnerHandle<T> = async_task::JoinHandle<T, ()>;
21+
type InnerHandle<T> = multitask::Task<T>;
2222
#[cfg(target_arch = "wasm32")]
2323
type InnerHandle<T> = futures_channel::oneshot::Receiver<T>;
2424

@@ -54,8 +54,7 @@ impl<T> JoinHandle<T> {
5454
#[cfg(not(target_os = "unknown"))]
5555
pub async fn cancel(mut self) -> Option<T> {
5656
let handle = self.handle.take().unwrap();
57-
handle.cancel();
58-
handle.await
57+
handle.cancel().await
5958
}
6059

6160
/// Cancel this task.
@@ -67,15 +66,19 @@ impl<T> JoinHandle<T> {
6766
}
6867
}
6968

69+
#[cfg(not(target_os = "unknown"))]
70+
impl<T> Drop for JoinHandle<T> {
71+
fn drop(&mut self) {
72+
if let Some(handle) = self.handle.take() {
73+
handle.detach();
74+
}
75+
}
76+
}
77+
7078
impl<T> Future for JoinHandle<T> {
7179
type Output = T;
7280

7381
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
74-
match Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx) {
75-
Poll::Pending => Poll::Pending,
76-
Poll::Ready(output) => {
77-
Poll::Ready(output.expect("cannot await the result of a panicked task"))
78-
}
79-
}
82+
Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx)
8083
}
8184
}

src/task/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ cfg_default! {
148148
mod block_on;
149149
mod builder;
150150
mod current;
151+
#[cfg(not(target_os = "unknown"))]
152+
mod executor;
151153
mod join_handle;
152154
mod sleep;
153155
#[cfg(not(target_os = "unknown"))]

0 commit comments

Comments
 (0)