1use std::{
6 collections::{
7 HashMap,
8 hash_map::{self, Entry},
9 },
10 io::Cursor,
11 mem::replace,
12 ops::Deref,
13 sync::{Arc, atomic::AtomicU16},
14 time::Duration,
15};
16
17use bytes::{BufMut, Bytes, BytesMut};
18use tokio::{
19 io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
20 sync::{
21 RwLock, SetOnce,
22 watch::{Receiver, Sender, channel},
23 },
24};
25use tracing::{Instrument, Level, debug, error, event, info, trace};
26
27use crate::{
28 AnyUdpRecv, AnyUdpSend,
29 error::{SError, SResult},
30 msgs::squic::SunnyCredential,
31 msgs::{
32 SDecode, SEncode,
33 socks5::SocksAddr,
34 squic::{SQPacketDatagramHeader, SQReq, SQUdpControlHeader},
35 },
36 quic::QuicConnection,
37};
38
39pub mod inbound;
40pub mod outbound;
41
/// A QUIC connection bundled with the per-connection state used by this module.
///
/// `Clone` is derived. The `authed` flag lives behind an `Arc` and both id stores keep
/// their state in `Arc`s, so those parts are shared between clones.
#[derive(Clone)]
pub struct SQConn<T: QuicConnection> {
    /// The wrapped connection; also reachable through the `Deref` impl.
    pub(crate) conn: T,
    /// Set-once authentication flag: `auth_sunny` sets it to `true` after sending the
    /// request, and `wait_sunny_auth` waits on it (a `false` value is reported as an
    /// authentication failure).
    pub(crate) authed: Arc<SetOnce<bool>>,
    /// Ids allocated by the sending side (`handle_udp_send`), one per destination.
    pub(crate) send_id_store: IDStore<()>,
    /// Ids registered by the receiving side, mapped to the UDP sender and address that
    /// payloads carrying the id are forwarded with.
    pub(crate) recv_id_store: IDStore<(AnyUdpSend, SocksAddr)>,
}
52
53async fn wait_sunny_auth<T: QuicConnection>(conn: &SQConn<T>) -> SResult<()> {
54 match tokio::time::timeout(Duration::from_millis(3200), conn.authed.wait()).await {
55 Ok(true) => Ok(()),
56 Ok(false) => Err(SError::SunnyAuthError("Wrong psassword/username".into())),
57 Err(_) => Err(SError::SunnyAuthError("timeout".into())),
58 }
59}
60
61pub(crate) async fn auth_sunny<T: QuicConnection>(
62 conn: &SQConn<T>,
63 user_hash: SunnyCredential,
64) -> SResult<()> {
65 if conn.authed.get().is_none() {
66 let (mut send, _recv, _id) = conn.open_bi().await?;
67 SQReq::SQAuthenticate(user_hash).encode(&mut send).await?;
68 debug!("authentication request sent");
69 conn.authed.set(true).expect("repeated authentication");
70 }
71 Ok(())
72}
73
74impl<T: QuicConnection> Deref for SQConn<T> {
75 type Target = T;
76
77 fn deref(&self) -> &Self::Target {
78 &self.conn
79 }
80}
81
/// Slot stored per id: `Ok(value)` once the value is known, or `Err(sender)` holding the
/// watch sender used to wake tasks that asked for the id before its value was stored.
type IDStoreVal<T> = Result<T, Sender<()>>;
/// Shared table keyed by `u16` id.
///
/// Both fields are `Arc`s, so clones of an `IDStore` operate on the same counter and map.
#[derive(Clone, Default)]
pub(crate) struct IDStore<T = (AnyUdpSend, SocksAddr)> {
    /// Counter from which `fetch_new_id` draws candidate ids.
    pub(crate) id_counter: Arc<AtomicU16>,
    /// The id table, guarded by an async read/write lock.
    pub(crate) inner: Arc<RwLock<HashMap<u16, IDStoreVal<T>>>>,
}
93
94impl<T> IDStore<T>
95where
96 T: Clone,
97{
98 async fn get_socket_or_notify(&self, id: u16) -> Result<T, Receiver<()>> {
99 if let Some(r) = self.inner.read().await.get(&id) {
100 r.clone().map_err(|x| x.subscribe())
101 } else {
102 match self.inner.write().await.entry(id) {
105 Entry::Occupied(occupied_entry) => {
106 occupied_entry.get().clone().map_err(|x| x.subscribe())
107 }
108 Entry::Vacant(vacant_entry) => {
109 let (s, r) = channel(());
110 vacant_entry.insert(Err(s));
111 Err(r)
112 }
113 }
114 }
115 }
116 async fn try_get_socket(&self, id: u16) -> Option<T> {
117 if let Some(r) = self.inner.read().await.get(&id) {
118 match r {
119 Ok(s) => Some(s.clone()),
120 Err(_) => None,
121 }
122 } else {
123 None
124 }
125 }
126 async fn get_socket_or_wait(&self, id: u16) -> Result<T, SError> {
127 match self.get_socket_or_notify(id).await {
128 Ok(r) => Ok(r),
129 Err(mut n) => {
130 n.changed()
132 .await
133 .map_err(|_| SError::UDPSessionClosed("notify sender dropped".to_string()))?;
134 let ret = self
136 .try_get_socket(id)
137 .await
138 .ok_or(SError::UDPSessionClosed("UDP session closed".to_string()))?;
139 Ok(ret)
140 }
141 }
142 }
143 async fn store_socket(&self, id: u16, val: T) {
144 let mut h = self.inner.write().await;
145 trace!("receiving side alive socket number: {}", h.len());
146 let r = h.get_mut(&id);
147 if let Some(s) = r {
148 match s {
149 Ok(_) => {
150 error!("id:{} already exists", id);
151 }
152 Err(_) => {
153 let notify = replace(s, Ok(val));
154 match notify {
156 Ok(_) => {
157 panic!("should be notify"); }
159 Err(n) => {
160 n.send(()).unwrap();
161 event!(Level::TRACE, "notify socket id:{}", id);
162 }
163 }
164 }
165 }
166 } else {
167 h.insert(id, Ok(val));
168 }
169 }
170 async fn fetch_new_id(&self, val: T) -> u16 {
171 let mut inner = self.inner.write().await;
172 trace!("sending side socket number: {}", inner.len());
173 let mut r;
174 loop {
175 r = self
176 .id_counter
177 .fetch_add(1, std::sync::atomic::Ordering::SeqCst); if let Entry::Vacant(e) = inner.entry(r) {
179 e.insert(Ok(val));
180 break;
181 }
182 }
183 r
184 }
185}
186
/// Sending-side state of one UDP association.
///
/// The ids it allocated are removed from the shared store when the session is dropped.
struct AssociateSendSession<W: AsyncWrite> {
    /// Shared store the ids are allocated from.
    id_store: IDStore<()>,
    /// Destination address -> id allocated for it by this session.
    dst_map: HashMap<SocksAddr, u16>,
    /// Destination address -> unidirectional stream opened for it (filled only when
    /// payloads are sent over streams).
    unistream_map: HashMap<SocksAddr, W>,
}
196impl<W: AsyncWrite> AssociateSendSession<W> {
197 pub async fn get_id_or_insert(&mut self, addr: &SocksAddr) -> (u16, bool) {
198 if let Some(id) = self.dst_map.get(addr) {
199 (*id, false)
200 } else {
201 let id = self.id_store.fetch_new_id(()).await;
202 self.dst_map.insert(addr.clone(), id);
203 trace!("send session: insert id:{}, addr:{}", id, addr);
204 (id, true)
205 }
206 }
207}
208
209impl<W: AsyncWrite> Drop for AssociateSendSession<W> {
210 fn drop(&mut self) {
211 let id_store = self.id_store.inner.clone();
212 let id_remove = self.dst_map.clone();
213 tokio::spawn(
214 async move {
215 let mut id_store = id_store.write().await;
216 let len = id_store.len();
217 id_remove.values().for_each(|k| {
218 id_store.remove(k);
219 });
220 let decrease = len - id_store.len();
221 event!(
222 Level::TRACE,
223 "AssociateSendSession dropped, session id size:{}, {} ids cleaned",
224 id_remove.len(),
225 decrease
226 );
227 }
228 .in_current_span(),
229 );
230 }
231}
/// Receiving-side state of one UDP association.
///
/// The ids it registered are removed from the shared store when the session is dropped.
struct AssociateRecvSession {
    /// Shared store mapping ids to the UDP sender and destination address.
    id_store: IDStore<(AnyUdpSend, SocksAddr)>,
    /// Ids registered through this session, with the destination announced for each.
    id_map: HashMap<u16, SocksAddr>,
}
241impl AssociateRecvSession {
242 pub async fn store_socket(&mut self, id: u16, dst: SocksAddr, socks: AnyUdpSend) {
243 if let hash_map::Entry::Vacant(e) = self.id_map.entry(id) {
244 self.id_store.store_socket(id, (socks, dst.clone())).await;
245 trace!("recv session: insert id:{}, addr:{}", id, dst);
246 e.insert(dst);
247 }
248 }
249}
250
251impl Drop for AssociateRecvSession {
252 fn drop(&mut self) {
253 let id_store = self.id_store.inner.clone();
254 let id_remove = self.id_map.clone();
255 tokio::spawn(
256 async move {
257 let mut id_store = id_store.write().await;
258 let len = id_store.len();
259
260 id_remove.keys().for_each(|k| {
261 id_store.remove(k);
262 });
263 let decrease = len - id_store.len();
264 event!(
265 Level::TRACE,
266 "AssociateRecvSession dropped, session id size:{}, {} ids cleaned",
267 id_remove.len(),
268 decrease
269 );
270 }
271 .in_current_span(),
272 );
273 }
274}
275
276pub async fn handle_udp_send<C: QuicConnection>(
280 mut send: C::SendStream,
281 udp_recv: AnyUdpRecv,
282 conn: SQConn<C>,
283 over_stream: bool,
284) -> Result<(), SError> {
285 let mut down_stream = udp_recv;
286 let mut session = AssociateSendSession {
287 id_store: conn.send_id_store.clone(),
288 dst_map: Default::default(),
289 unistream_map: Default::default(),
290 };
291 let quic_conn = conn.conn.clone();
292 loop {
293 let (bytes, dst) = down_stream.recv_from().await?;
294 let (id, is_new) = session.get_id_or_insert(&dst).await;
295 let ctl_header = SQUdpControlHeader {
297 dst: dst.clone(),
298 id,
299 };
300 let dg_header = SQPacketDatagramHeader { id };
301 if over_stream && !session.unistream_map.contains_key(&dst) {
302 let (uni, _id) = conn.open_uni().await?;
303 session.unistream_map.insert(dst.clone(), uni);
304 }
305
306 let fut1 = async {
307 if is_new {
308 ctl_header.encode(&mut send).await?;
309 }
310 Ok(()) as Result<(), SError>
312 };
313 let fut2 = async {
314 let mut content = BytesMut::with_capacity(2000);
315 let mut head = Vec::<u8>::new();
316 dg_header.clone().encode(&mut head).await?;
317
318 if over_stream {
319 let conn = session.unistream_map.get_mut(&dst).unwrap();
321 let mut head = Vec::<u8>::new();
322 if is_new {
323 dg_header.encode(&mut head).await?
324 }
325 (bytes.len() as u16).encode(&mut head).await?;
326 conn.write_all(&head).await?;
327 conn.write_all(&bytes).await?;
328 } else {
329 content.put(Bytes::from(head));
330 content.put(bytes);
331 let content = content.freeze();
332 quic_conn.send_datagram(content).await?;
333 }
334 Ok(())
335 };
336 tokio::try_join!(fut1, fut2)?;
337 }
338 #[allow(unreachable_code)]
339 Ok(())
340}
341
342pub async fn handle_udp_recv_ctrl<C: QuicConnection>(
346 mut recv: C::RecvStream,
347 udp_socket: AnyUdpSend,
348 conn: SQConn<C>,
349) -> Result<(), SError> {
350 let mut session = AssociateRecvSession {
351 id_store: conn.recv_id_store.clone(),
352 id_map: Default::default(),
353 };
354 loop {
355 let SQUdpControlHeader { id, dst } = SQUdpControlHeader::decode(&mut recv).await?;
356 trace!("udp control header received: id:{},dst:{}", id, dst);
357 session.store_socket(id, dst, udp_socket.clone()).await;
358 }
359 #[allow(unreachable_code)]
360 Ok(())
361}
362
363pub async fn handle_udp_packet_recv<C: QuicConnection>(conn: SQConn<C>) -> Result<(), SError> {
368 let id_store = conn.recv_id_store.clone();
369
370 wait_sunny_auth(&conn).await?;
371 loop {
372 tokio::select! {
373 b = conn.read_datagram() => {
374 let b = b?;
375 let b = BytesMut::from(b);
376 let mut cur = Cursor::new(b);
377 let SQPacketDatagramHeader{id} = SQPacketDatagramHeader::decode(&mut cur).await?;
378
379 match id_store.get_socket_or_notify(id).await {
380 Ok((udp,addr)) => {
381 let pos = cur.position() as usize;
382 let b = cur.into_inner().freeze();
383 udp.send_to(b.slice(pos..b.len()), addr.clone()).await?;
384 }
385 Err(mut notify) => {
386 let id_store = id_store.clone();
387 let src_addr = conn.remote_address();
388 event!(Level::TRACE, "resolving datagram id:{}",id);
389 tokio::spawn(async move {
391 let _ = notify.changed().await.map_err(|_|debug!("id:{} notifier dropped",id));
393 let (udp,addr) = id_store.try_get_socket(id).await.ok_or(SError::UDPSessionClosed("UDP session closed".to_string()))?;
395 info!("udp over datagram: id:{}: {}->{}",id, src_addr, addr);
396 let pos = cur.position() as usize;
397 let b = cur.into_inner().freeze();
398 let _ = udp.clone().send_to(b.slice(pos..b.len()), addr.clone()).await
399 .map_err(|x|error!("{}",x));
400 Ok(()) as Result<(), SError>
401 }.in_current_span());
402 }
403 }
404 }
405
406 r = async {
407 let (mut uni_stream, _id) = conn.accept_uni().await?;
408 trace!("unistream accepted");
409 let SQPacketDatagramHeader{id} = SQPacketDatagramHeader::decode(&mut uni_stream).await?;
410 event!(Level::TRACE, "resolving datagram id:{}",id);
411
412 let (udp,addr) = id_store.get_socket_or_wait(id).await?;
413
414 info!("udp over stream: id:{}: {}->{}",id, conn.remote_address(), addr);
415 Ok((uni_stream,udp.clone(),addr.clone())) as Result<(C::RecvStream,AnyUdpSend,SocksAddr),SError>
416 } => {
417
418 let (mut uni_stream,udp,addr) = match r {
419 Ok(r) => r,
420 Err(SError::UDPSessionClosed(_)) => {
421 continue;
422 }
423 Err(e) => {
424 return Err(e);
425 }
426 };
427
428 tokio::spawn(async move {
429 loop {
430 let l: usize = u16::decode(&mut uni_stream).await? as usize;
431 let mut b = BytesMut::with_capacity(l);
432 b.resize(l,0);
433 uni_stream.read_exact(&mut b).await?;
434 udp.send_to(b.freeze(), addr.clone()).await?;
435 }
436 #[allow(unreachable_code)]
437 (Ok(()) as Result<(), SError>)
438 }.in_current_span());
439 }
440 }
441 }
442 #[allow(unreachable_code)]
443 Ok(())
444}