shadowquic/squic/
mod.rs

//! This module is shared by sunnyquic and shadowquic.
//! It handles the general TCP/UDP proxying logic over a QUIC connection.
//! It also contains an optional authentication feature, used by sunnyquic only.
4
5use std::{
6    collections::{
7        HashMap,
8        hash_map::{self, Entry},
9    },
10    io::Cursor,
11    mem::replace,
12    ops::Deref,
13    sync::{Arc, atomic::AtomicU16},
14    time::Duration,
15};
16
17use bytes::{BufMut, Bytes, BytesMut};
18use tokio::{
19    io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
20    sync::{
21        RwLock, SetOnce,
22        watch::{Receiver, Sender, channel},
23    },
24};
25use tracing::{Instrument, Level, debug, error, event, info, trace};
26
27use crate::{
28    AnyUdpRecv, AnyUdpSend,
29    error::{SError, SResult},
30    msgs::squic::SunnyCredential,
31    msgs::{
32        SDecode, SEncode,
33        socks5::SocksAddr,
34        squic::{SQPacketDatagramHeader, SQReq, SQUdpControlHeader},
35    },
36    quic::QuicConnection,
37};
38
39pub mod inbound;
40pub mod outbound;
41
42/// SQuic connection, it is shared by shadowquic and sunnyquic and is a wrapper of quic connection.
43/// It contains a connection object and two ID store for managing UDP sockets.
44/// The IDStore stores the mapping between ids and the destionation addresses as well as associated sockets
45#[derive(Clone)]
46pub struct SQConn<T: QuicConnection> {
47    pub(crate) conn: T,
48    pub(crate) authed: Arc<SetOnce<bool>>,
49    pub(crate) send_id_store: IDStore<()>,
50    pub(crate) recv_id_store: IDStore<(AnyUdpSend, SocksAddr)>,
51}
52
53async fn wait_sunny_auth<T: QuicConnection>(conn: &SQConn<T>) -> SResult<()> {
54    match tokio::time::timeout(Duration::from_millis(3200), conn.authed.wait()).await {
55        Ok(true) => Ok(()),
56        Ok(false) => Err(SError::SunnyAuthError("Wrong psassword/username".into())),
57        Err(_) => Err(SError::SunnyAuthError("timeout".into())),
58    }
59}
60
61pub(crate) async fn auth_sunny<T: QuicConnection>(
62    conn: &SQConn<T>,
63    user_hash: SunnyCredential,
64) -> SResult<()> {
65    if conn.authed.get().is_none() {
66        let (mut send, _recv, _id) = conn.open_bi().await?;
67        SQReq::SQAuthenticate(user_hash).encode(&mut send).await?;
68        debug!("authentication request sent");
69        conn.authed.set(true).expect("repeated authentication");
70    }
71    Ok(())
72}
73
74impl<T: QuicConnection> Deref for SQConn<T> {
75    type Target = T;
76
77    fn deref(&self) -> &Self::Target {
78        &self.conn
79    }
80}
81
82// Use watch channel here. Notify is not suitable here
83// see https://github.com/tokio-rs/tokio/issues/3757
84type IDStoreVal<T> = Result<T, Sender<()>>;
85/// IDStore is a thread-safe store for managing UDP sockets and their associated ids.
86/// It uses a HashMap to store the mapping between ids and the destination addresses as well as associated sockets.
87/// It also uses an atomic counter to generate unique ids for new sockets.
88#[derive(Clone, Default)]
89pub(crate) struct IDStore<T = (AnyUdpSend, SocksAddr)> {
90    pub(crate) id_counter: Arc<AtomicU16>,
91    pub(crate) inner: Arc<RwLock<HashMap<u16, IDStoreVal<T>>>>,
92}
93
94impl<T> IDStore<T>
95where
96    T: Clone,
97{
98    async fn get_socket_or_notify(&self, id: u16) -> Result<T, Receiver<()>> {
99        if let Some(r) = self.inner.read().await.get(&id) {
100            r.clone().map_err(|x| x.subscribe())
101        } else {
102            // Need to recheck
103            // During change from read lock to write lock, hashmap may be modified
104            match self.inner.write().await.entry(id) {
105                Entry::Occupied(occupied_entry) => {
106                    occupied_entry.get().clone().map_err(|x| x.subscribe())
107                }
108                Entry::Vacant(vacant_entry) => {
109                    let (s, r) = channel(());
110                    vacant_entry.insert(Err(s));
111                    Err(r)
112                }
113            }
114        }
115    }
116    async fn try_get_socket(&self, id: u16) -> Option<T> {
117        if let Some(r) = self.inner.read().await.get(&id) {
118            match r {
119                Ok(s) => Some(s.clone()),
120                Err(_) => None,
121            }
122        } else {
123            None
124        }
125    }
126    async fn get_socket_or_wait(&self, id: u16) -> Result<T, SError> {
127        match self.get_socket_or_notify(id).await {
128            Ok(r) => Ok(r),
129            Err(mut n) => {
130                // This may fail is UDP session is closed right at this moment.
131                n.changed()
132                    .await
133                    .map_err(|_| SError::UDPSessionClosed("notify sender dropped".to_string()))?;
134                //
135                let ret = self
136                    .try_get_socket(id)
137                    .await
138                    .ok_or(SError::UDPSessionClosed("UDP session closed".to_string()))?;
139                Ok(ret)
140            }
141        }
142    }
143    async fn store_socket(&self, id: u16, val: T) {
144        let mut h = self.inner.write().await;
145        trace!("receiving side alive socket number: {}", h.len());
146        let r = h.get_mut(&id);
147        if let Some(s) = r {
148            match s {
149                Ok(_) => {
150                    error!("id:{} already exists", id);
151                }
152                Err(_) => {
153                    let notify = replace(s, Ok(val));
154                    //let _ = notify.map_err(|x| x.notify_one());
155                    match notify {
156                        Ok(_) => {
157                            panic!("should be notify"); // should never happen
158                        }
159                        Err(n) => {
160                            n.send(()).unwrap();
161                            event!(Level::TRACE, "notify socket id:{}", id);
162                        }
163                    }
164                }
165            }
166        } else {
167            h.insert(id, Ok(val));
168        }
169    }
170    async fn fetch_new_id(&self, val: T) -> u16 {
171        let mut inner = self.inner.write().await;
172        trace!("sending side socket number: {}", inner.len());
173        let mut r;
174        loop {
175            r = self
176                .id_counter
177                .fetch_add(1, std::sync::atomic::Ordering::SeqCst); // Wrapping occured if overflow
178            if let Entry::Vacant(e) = inner.entry(r) {
179                e.insert(Ok(val));
180                break;
181            }
182        }
183        r
184    }
185}
186
187/// AssociateSendSession is a session for sending UDP packets.
188/// It is created for each association task
189/// The local dst_map works as a inverse map from destination to id
190/// When session ended, the ids created by this session will be removed from the IDStore.
191struct AssociateSendSession<W: AsyncWrite> {
192    id_store: IDStore<()>,
193    dst_map: HashMap<SocksAddr, u16>,
194    unistream_map: HashMap<SocksAddr, W>,
195}
196impl<W: AsyncWrite> AssociateSendSession<W> {
197    pub async fn get_id_or_insert(&mut self, addr: &SocksAddr) -> (u16, bool) {
198        if let Some(id) = self.dst_map.get(addr) {
199            (*id, false)
200        } else {
201            let id = self.id_store.fetch_new_id(()).await;
202            self.dst_map.insert(addr.clone(), id);
203            trace!("send session: insert id:{}, addr:{}", id, addr);
204            (id, true)
205        }
206    }
207}
208
209impl<W: AsyncWrite> Drop for AssociateSendSession<W> {
210    fn drop(&mut self) {
211        let id_store = self.id_store.inner.clone();
212        let id_remove = self.dst_map.clone();
213        tokio::spawn(
214            async move {
215                let mut id_store = id_store.write().await;
216                let len = id_store.len();
217                id_remove.values().for_each(|k| {
218                    id_store.remove(k);
219                });
220                let decrease = len - id_store.len();
221                event!(
222                    Level::TRACE,
223                    "AssociateSendSession dropped, session id size:{}, {} ids cleaned",
224                    id_remove.len(),
225                    decrease
226                );
227            }
228            .in_current_span(),
229        );
230    }
231}
232/// AssociateRecvSession is a session for receiving UDP ctrl stream.
233/// It is created for each association task
234/// There are two usages for id_map
235/// First, it works as local cache avoiding using global store repeatedly which is more expensive
236/// Second. it records ids created by this session and clean those ids when session ended.
237struct AssociateRecvSession {
238    id_store: IDStore<(AnyUdpSend, SocksAddr)>,
239    id_map: HashMap<u16, SocksAddr>,
240}
241impl AssociateRecvSession {
242    pub async fn store_socket(&mut self, id: u16, dst: SocksAddr, socks: AnyUdpSend) {
243        if let hash_map::Entry::Vacant(e) = self.id_map.entry(id) {
244            self.id_store.store_socket(id, (socks, dst.clone())).await;
245            trace!("recv session: insert id:{}, addr:{}", id, dst);
246            e.insert(dst);
247        }
248    }
249}
250
251impl Drop for AssociateRecvSession {
252    fn drop(&mut self) {
253        let id_store = self.id_store.inner.clone();
254        let id_remove = self.id_map.clone();
255        tokio::spawn(
256            async move {
257                let mut id_store = id_store.write().await;
258                let len = id_store.len();
259
260                id_remove.keys().for_each(|k| {
261                    id_store.remove(k);
262                });
263                let decrease = len - id_store.len();
264                event!(
265                    Level::TRACE,
266                    "AssociateRecvSession dropped, session id size:{}, {} ids cleaned",
267                    id_remove.len(),
268                    decrease
269                );
270            }
271            .in_current_span(),
272        );
273    }
274}
275
276/// Handle udp packets send
277/// It watches the udp socket and sends the packets to the quic connection.
278/// This function is symetrical for both clients and servers.
279pub async fn handle_udp_send<C: QuicConnection>(
280    mut send: C::SendStream,
281    udp_recv: AnyUdpRecv,
282    conn: SQConn<C>,
283    over_stream: bool,
284) -> Result<(), SError> {
285    let mut down_stream = udp_recv;
286    let mut session = AssociateSendSession {
287        id_store: conn.send_id_store.clone(),
288        dst_map: Default::default(),
289        unistream_map: Default::default(),
290    };
291    let quic_conn = conn.conn.clone();
292    loop {
293        let (bytes, dst) = down_stream.recv_from().await?;
294        let (id, is_new) = session.get_id_or_insert(&dst).await;
295        //let span = trace_span!("udp", id = id);
296        let ctl_header = SQUdpControlHeader {
297            dst: dst.clone(),
298            id,
299        };
300        let dg_header = SQPacketDatagramHeader { id };
301        if over_stream && !session.unistream_map.contains_key(&dst) {
302            let (uni, _id) = conn.open_uni().await?;
303            session.unistream_map.insert(dst.clone(), uni);
304        }
305
306        let fut1 = async {
307            if is_new {
308                ctl_header.encode(&mut send).await?;
309            }
310            //trace!("udp control header sent");
311            Ok(()) as Result<(), SError>
312        };
313        let fut2 = async {
314            let mut content = BytesMut::with_capacity(2000);
315            let mut head = Vec::<u8>::new();
316            dg_header.clone().encode(&mut head).await?;
317
318            if over_stream {
319                // Must be opened and inserted.
320                let conn = session.unistream_map.get_mut(&dst).unwrap();
321                let mut head = Vec::<u8>::new();
322                if is_new {
323                    dg_header.encode(&mut head).await?
324                }
325                (bytes.len() as u16).encode(&mut head).await?;
326                conn.write_all(&head).await?;
327                conn.write_all(&bytes).await?;
328            } else {
329                content.put(Bytes::from(head));
330                content.put(bytes);
331                let content = content.freeze();
332                quic_conn.send_datagram(content).await?;
333            }
334            Ok(())
335        };
336        tokio::try_join!(fut1, fut2)?;
337    }
338    #[allow(unreachable_code)]
339    Ok(())
340}
341
342/// Handle udp ctrl stream receive task
343/// it retrieves the dst id pair from the bistream and records related socket and address
344/// This function is symetrical for both clients and servers.
345pub async fn handle_udp_recv_ctrl<C: QuicConnection>(
346    mut recv: C::RecvStream,
347    udp_socket: AnyUdpSend,
348    conn: SQConn<C>,
349) -> Result<(), SError> {
350    let mut session = AssociateRecvSession {
351        id_store: conn.recv_id_store.clone(),
352        id_map: Default::default(),
353    };
354    loop {
355        let SQUdpControlHeader { id, dst } = SQUdpControlHeader::decode(&mut recv).await?;
356        trace!("udp control header received: id:{},dst:{}", id, dst);
357        session.store_socket(id, dst, udp_socket.clone()).await;
358    }
359    #[allow(unreachable_code)]
360    Ok(())
361}
362
363/// Handle udp packet receive task
364/// It watches udp packets from quic connection and sends them to the udp socket.
365/// The udp socket could be downstream(inbound) or upstream(outbound)
366/// This function is symetrical for both clients and servers.
367pub async fn handle_udp_packet_recv<C: QuicConnection>(conn: SQConn<C>) -> Result<(), SError> {
368    let id_store = conn.recv_id_store.clone();
369
370    wait_sunny_auth(&conn).await?;
371    loop {
372        tokio::select! {
373            b = conn.read_datagram() => {
374                let b = b?;
375                let b = BytesMut::from(b);
376                let mut cur = Cursor::new(b);
377                let SQPacketDatagramHeader{id} = SQPacketDatagramHeader::decode(&mut cur).await?;
378
379                match id_store.get_socket_or_notify(id).await {
380                 Ok((udp,addr)) =>  {
381                    let pos = cur.position() as usize;
382                    let b = cur.into_inner().freeze();
383                    udp.send_to(b.slice(pos..b.len()), addr.clone()).await?;
384                }
385                Err(mut notify) =>  {
386                    let id_store = id_store.clone();
387                    let src_addr = conn.remote_address();
388                    event!(Level::TRACE, "resolving datagram id:{}",id);
389                    // Might spawn too many tasks
390                    tokio::spawn(async move {
391                        // It's safe to sender to be dropped
392                        let _ = notify.changed().await.map_err(|_|debug!("id:{} notifier dropped",id));
393                        // session may be closed
394                        let (udp,addr) = id_store.try_get_socket(id).await.ok_or(SError::UDPSessionClosed("UDP session closed".to_string()))?;
395                        info!("udp over datagram: id:{}: {}->{}",id, src_addr, addr);
396                        let pos = cur.position() as usize;
397                        let b = cur.into_inner().freeze();
398                        let _ = udp.clone().send_to(b.slice(pos..b.len()), addr.clone()).await
399                        .map_err(|x|error!("{}",x));
400                        Ok(()) as Result<(), SError>
401                     }.in_current_span());
402                }
403            }
404            }
405
406            r = async {
407                let (mut uni_stream, _id) = conn.accept_uni().await?;
408                trace!("unistream accepted");
409                let SQPacketDatagramHeader{id} = SQPacketDatagramHeader::decode(&mut uni_stream).await?;
410                event!(Level::TRACE, "resolving datagram id:{}",id);
411
412                let (udp,addr) = id_store.get_socket_or_wait(id).await?;
413
414                info!("udp over stream: id:{}: {}->{}",id, conn.remote_address(), addr);
415                Ok((uni_stream,udp.clone(),addr.clone())) as Result<(C::RecvStream,AnyUdpSend,SocksAddr),SError>
416            } => {
417
418                let  (mut uni_stream,udp,addr) = match r {
419                    Ok(r) => r,
420                    Err(SError::UDPSessionClosed(_)) => {
421                        continue;
422                    }
423                    Err(e) => {
424                        return Err(e);
425                    }
426                };
427
428                tokio::spawn(async move {
429                    loop {
430                        let l: usize = u16::decode(&mut uni_stream).await? as usize;
431                        let mut b = BytesMut::with_capacity(l);
432                        b.resize(l,0);
433                        uni_stream.read_exact(&mut b).await?;
434                        udp.send_to(b.freeze(), addr.clone()).await?;
435                    }
436                    #[allow(unreachable_code)]
437                    (Ok(()) as Result<(), SError>)
438                }.in_current_span());
439            }
440        }
441    }
442    #[allow(unreachable_code)]
443    Ok(())
444}