Skip to content

RFC: Implement a handshake for VSock #1443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions src/devices/src/virtio/vsock/csm/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ pub struct VsockConnection<S: Read + Write + AsRawFd> {
/// Instant when this connection should be scheduled for immediate termination, due to some
/// timeout condition having been fulfilled.
expiry: Option<Instant>,
/// If this true, Reply the connection status before transfer data or close this connection.
need_reply: bool,
}

impl<S> VsockChannel for VsockConnection<S>
Expand Down Expand Up @@ -304,9 +306,30 @@ where
// Next up: receiving a response / confirmation for a host-initiated connection.
// We'll move to an Established state, and pass on the good news through the host
// stream.
ConnState::LocalInit if pkt.op() == uapi::VSOCK_OP_RESPONSE => {
self.expiry = None;
self.state = ConnState::Established;
ConnState::LocalInit
if pkt.op() == uapi::VSOCK_OP_RESPONSE || pkt.op() == uapi::VSOCK_OP_RST =>
{
let is_response = pkt.op() == uapi::VSOCK_OP_RESPONSE;
if self.need_reply {
self.need_reply = false;
if let Err(err) = self.send_bytes(if is_response { b"101\n" } else { b"503\n" })
{
// If we can't write to the host stream, that's an unrecoverable error, so
// we'll terminate this connection.
warn!(
"vsock: error writing to local stream (lp={}, pp={}): {:?}",
self.local_port, self.peer_port, err
);
if is_response {
self.kill();
}
return Ok(());
}
}
if is_response {
self.expiry = None;
self.state = ConnState::Established;
}
}

// The peer wants to shut down an established connection. If they have nothing
Expand Down Expand Up @@ -478,6 +501,7 @@ where
last_fwd_cnt_to_peer: Wrapping(0),
pending_rx: PendingRxSet::from(PendingRx::Response),
expiry: None,
need_reply: false,
}
}

Expand All @@ -488,6 +512,7 @@ where
peer_cid: u64,
local_port: u32,
peer_port: u32,
need_reply: bool,
) -> Self {
Self {
local_cid,
Expand All @@ -504,6 +529,7 @@ where
last_fwd_cnt_to_peer: Wrapping(0),
pending_rx: PendingRxSet::from(PendingRx::Request),
expiry: None,
need_reply,
}
}

Expand Down Expand Up @@ -738,7 +764,7 @@ mod tests {
Self::new(ConnState::Established)
}

fn new(conn_state: ConnState) -> Self {
fn new_maybe_need_reply(conn_state: ConnState, need_reply: bool) -> Self {
let vsock_test_ctx = TestContext::new();
let mut handler_ctx = vsock_test_ctx.create_epoll_handler_context();
let stream = TestStream::new();
Expand All @@ -756,7 +782,7 @@ mod tests {
PEER_BUF_ALLOC,
),
ConnState::LocalInit => VsockConnection::<TestStream>::new_local_init(
stream, LOCAL_CID, PEER_CID, LOCAL_PORT, PEER_PORT,
stream, LOCAL_CID, PEER_CID, LOCAL_PORT, PEER_PORT, need_reply,
),
ConnState::Established => {
let mut conn = VsockConnection::<TestStream>::new_peer_init(
Expand All @@ -782,6 +808,10 @@ mod tests {
}
}

fn new(conn_state: ConnState) -> Self {
Self::new_maybe_need_reply(conn_state, false)
}

fn set_stream(&mut self, stream: TestStream) {
self.conn.stream = stream;
}
Expand Down Expand Up @@ -826,6 +856,7 @@ mod tests {
fn test_peer_request() {
let mut ctx = CsmTestContext::new(ConnState::PeerInit);
assert!(ctx.conn.has_pending_rx());
assert_eq!(ctx.conn.need_reply, false);
ctx.recv();
// For peer-initiated requests, our connection should always yield a vsock reponse packet,
// in order to establish the connection.
Expand All @@ -844,6 +875,29 @@ mod tests {
#[test]
fn test_local_request() {
let mut ctx = CsmTestContext::new(ConnState::LocalInit);
assert_eq!(ctx.conn.need_reply, false);
// Host-initiated connections should first yield a connection request packet.
assert!(ctx.conn.has_pending_rx());
// Before yielding the connection request packet, the timeout kill timer shouldn't be
// armed.
assert!(!ctx.conn.will_expire());
ctx.recv();
assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST);
// Since the request might time-out, the kill timer should now be armed.
assert!(ctx.conn.will_expire());
assert!(!ctx.conn.has_expired());
ctx.init_pkt(uapi::VSOCK_OP_RESPONSE, 0);
ctx.send();
// Upon receiving a connection response, the connection should have transitioned to the
// established state, and the kill timer should've been disarmed.
assert_eq!(ctx.conn.state, ConnState::Established);
assert!(!ctx.conn.will_expire());
}

#[test]
fn test_local_request_need_reply() {
let mut ctx = CsmTestContext::new_maybe_need_reply(ConnState::LocalInit, true);
assert_eq!(ctx.conn.need_reply, true);
// Host-initiated connections should first yield a connection request packet.
assert!(ctx.conn.has_pending_rx());
// Before yielding the connection request packet, the timeout kill timer shouldn't be
Expand All @@ -865,6 +919,21 @@ mod tests {
#[test]
fn test_local_request_timeout() {
let mut ctx = CsmTestContext::new(ConnState::LocalInit);
assert_eq!(ctx.conn.need_reply, false);
ctx.recv();
assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST);
assert!(ctx.conn.will_expire());
assert!(!ctx.conn.has_expired());
std::thread::sleep(std::time::Duration::from_millis(
defs::CONN_REQUEST_TIMEOUT_MS,
));
assert!(ctx.conn.has_expired());
}

#[test]
fn test_local_request_timeout_need_reply() {
let mut ctx = CsmTestContext::new_maybe_need_reply(ConnState::LocalInit, true);
assert_eq!(ctx.conn.need_reply, true);
ctx.recv();
assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST);
assert!(ctx.conn.will_expire());
Expand Down
Loading