diff --git a/server/src/account.rs b/server/src/account.rs index 3b56c25c..9b44c2c1 100644 --- a/server/src/account.rs +++ b/server/src/account.rs @@ -12,6 +12,7 @@ use names::{name as generate_name}; use construct::{Construct, construct_recover, construct_spawn}; use instance::{Instance, instance_delete}; use mtx::{Mtx, FREE_MTX}; +use pg::Db; use failure::Error; use failure::{err_msg, format_err}; @@ -48,7 +49,7 @@ pub fn select(tx: &mut Transaction, id: Uuid) -> Result { Ok(Account { id, name: row.get(1), credits, subscribed }) } -pub fn from_token(tx: &mut Transaction, token: String) -> Result { +pub fn from_token(db: &Db, token: String) -> Result { let query = " SELECT id, name, subscribed, credits FROM accounts @@ -56,7 +57,7 @@ pub fn from_token(tx: &mut Transaction, token: String) -> Result AND token_expiry > now(); "; - let result = tx + let result = db .query(query, &[&token])?; let row = result.iter().next() diff --git a/server/src/main.rs b/server/src/main.rs index 9f7f3460..e8398423 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -25,6 +25,7 @@ extern crate persistent; extern crate router; extern crate cookie; extern crate tungstenite; +extern crate crossbeam_channel; mod account; mod construct; diff --git a/server/src/net.rs b/server/src/net.rs index ef4e78da..3d01572a 100644 --- a/server/src/net.rs +++ b/server/src/net.rs @@ -137,7 +137,7 @@ fn login_res(token: String) -> Response { let v = Cookie::build(TOKEN_HEADER, token) .http_only(true) .same_site(SameSite::Strict) - .max_age(Duration::seconds(-1)) // 1 week aligns with db set + .max_age(Duration::weeks(1)) // 1 week aligns with db set .finish(); let mut res = Response::with(status::Ok); diff --git a/server/src/rpc.rs b/server/src/rpc.rs index e7fee12f..12d1fc83 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -66,7 +66,7 @@ enum RpcRequest { VboxReclaim { instance_id: Uuid, index: usize }, } -pub fn receive(data: Vec, db: &Db, _client: &mut WebSocket, begin: Instant, account: Option<&Account>) -> Result { +pub fn receive(data: Vec, db: &Db, _client: &mut WebSocket, begin: Instant, account: &Option) -> Result { // cast the msg to this type to receive method name match from_slice::(&data) { Ok(v) => { diff --git a/server/src/ws.rs b/server/src/ws.rs index 2b01c209..e63901b7 100644 --- a/server/src/ws.rs +++ b/server/src/ws.rs @@ -1,10 +1,16 @@ use std::time::{Instant}; use std::net::{TcpListener}; use std::thread::{spawn}; +use std::str; + +use cookie::Cookie; use tungstenite::server::accept_hdr; use tungstenite::Message::Binary; -use tungstenite::handshake::server::Request; +use tungstenite::handshake::server::{Request, ErrorResponse}; +use tungstenite::http::StatusCode; + +use crossbeam_channel::{unbounded, Sender}; use failure::Error; use failure::err_msg; @@ -12,8 +18,9 @@ use failure::err_msg; use serde_cbor::{to_vec}; use net::TOKEN_HEADER; -use rpc::{receive}; +use rpc; use pg::PgPool; +use account; #[derive(Debug,Clone,Serialize)] struct RpcError { @@ -25,42 +32,88 @@ pub fn start(pool: PgPool) { for stream in ws_server.incoming() { let ws_pool = pool.clone(); spawn(move || { + + let (acc_s, acc_r) = unbounded(); + let cb = |req: &Request| { - let token = req.headers.find_first(TOKEN_HEADER); - println!("{:?}", token); + let err = || ErrorResponse { + error_code: StatusCode::FORBIDDEN, + headers: None, + body: Some("Unauthorized".into()), + }; + + if let Some(cl) = req.headers.find_first("Cookie") { + let cookie_list = str::from_utf8(cl).or(Err(err()))?; + + for s in cookie_list.split(";").map(|s| s.trim()) { + let cookie = Cookie::parse(s).or(Err(err()))?; + + // got auth token + if cookie.name() == TOKEN_HEADER { + info!("{:?}", cookie.value().to_string()); + acc_s.send(Some(cookie.value().to_string())).or(Err(err()))?; + } + }; + }; + acc_s.send(None).unwrap(); Ok(None) }; + let mut websocket = accept_hdr(stream.unwrap(), cb).unwrap(); + + let account = match acc_r.recv().unwrap() { + Some(t) => { + let db = ws_pool.get() + .expect("unable to get db connection"); + match account::from_token(&db, t) { + Ok(a) => { + let state = to_vec(&rpc::RpcResult::AccountState(a.clone())).unwrap(); + websocket.write_message(Binary(state)).unwrap(); + Some(a) + }, + Err(e) => { + warn!("{:?}", e); + return; + }, + } + }, + None => None, + }; + loop { match websocket.read_message() { Ok(msg) => { - let begin = Instant::now(); - let db_connection = ws_pool.get() - .expect("unable to get db connection"); + match msg { + Binary(data) => { + let begin = Instant::now(); + let db_connection = ws_pool.get() + .expect("unable to get db connection"); - let data = msg.into_data(); - match receive(data, &db_connection, &mut websocket, begin, None) { - Ok(reply) => { - let response = to_vec(&reply) - .expect("failed to serialize response"); + match rpc::receive(data, &db_connection, &mut websocket, begin, &account) { + Ok(reply) => { + let response = to_vec(&reply) + .expect("failed to serialize response"); - if let Err(e) = websocket.write_message(Binary(response)) { - // connection closed - debug!("{:?}", e); - return; - }; + if let Err(e) = websocket.write_message(Binary(response)) { + // connection closed + debug!("{:?}", e); + return; + }; + }, + Err(e) => { + warn!("{:?}", e); + let response = to_vec(&RpcError { err: e.to_string() }) + .expect("failed to serialize error response"); + + if let Err(e) = websocket.write_message(Binary(response)) { + // connection closed + debug!("{:?}", e); + return; + }; + } + } }, - Err(e) => { - warn!("{:?}", e); - let response = to_vec(&RpcError { err: e.to_string() }) - .expect("failed to serialize error response"); - - if let Err(e) = websocket.write_message(Binary(response)) { - // connection closed - debug!("{:?}", e); - return; - }; - } + _ => (), } }, // connection is closed