mod cache; mod gamma; mod rules; use compound_error::CompoundError; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, Uri}; use kv_log_macro::{debug, error, info, warn}; use std::cmp::Ordering; use std::collections::HashMap; use std::net::SocketAddr; use std::path::PathBuf; use cache::UserCache; use gamma::{Credentials, User}; use rules::{Method, Permission, Rule}; #[derive(Debug)] struct Opt { /// Example: http://my-server:80/ proxy: String, /// Example: https://gamma.chalmers.it gamma: String, /// Example: "My special place" realm: String, rules: HashMap, } #[derive(Debug, CompoundError)] enum Error { Hyper(hyper::Error), IO(std::io::Error), } struct State { opt: Opt, cache: UserCache, } async fn proxy_pass(mut req: Request, state: &State) -> Result, Error> { let req_uri = req.uri().clone(); info!("{:#?}", req); let unauthorized = |msg| { info!("respoinding with 401 Unauthorized due to {}", msg); Ok(Response::builder() .status(401) .header( "WWW-Authenticate", format!(r#"Basic realm="{}", charset="UTF-8""#, state.opt.realm), ) .header("Docker-Distribution-Api-Version", "registry/2.0") .body(Body::from("Unauthorized")) .expect("infallible response")) }; let forbidden = |msg| { info!("respoinding with 403 Forbidden due to {}", msg); Ok(Response::builder() .status(403) .body(Body::from("Forbidden")) .expect("infallible response")) }; let login = match req.headers_mut().remove("Authorization") { None => { info!("request received", { uri: format!("{}", req_uri).as_str(), method: req.method().as_str(), provided_auth: false, }); None } Some(authorization) => { info!("request received", { uri: format!("{}", req_uri).as_str(), method: req.method().as_str(), provided_auth: true, }); let credentials = authorization .to_str() .ok() .and_then(|s| s.strip_prefix("Basic ")) .and_then(|s| base64::decode(s).ok()) .and_then(|b| String::from_utf8(b).ok()); match credentials.as_ref().and_then(|s| s.split_once(":")) { Some((user, pass)) => { let credentials = Credentials { username: user.to_string(), password: pass.to_string(), }; match state.cache.login(&state.opt, &credentials).await { Ok(user) => Some(user), Err(e) => { warn!("{}", e); return unauthorized("invalid login"); } } } None => { warn!("client did not provide valid a \"Authorization\" header"); None } } } }; match validate(&state.opt, &req, &login) { Validation::Allowed => { info!("success"); /* success! continue to proxy */ } Validation::NotAllowed => return forbidden("failed validation"), Validation::RequiresLogin => return unauthorized("not logged in"), } let proxy_uri: Uri = state.opt.proxy.parse().expect("proxy uri"); let mut new_uri = Uri::builder().authority(proxy_uri.authority().unwrap().clone()); new_uri = new_uri.scheme(proxy_uri.scheme_str().unwrap_or("http")); if let Some(paq) = req_uri.path_and_query().cloned() { new_uri = new_uri.path_and_query(paq); } *req.uri_mut() = new_uri.build().expect("uri"); let client = Client::new(); let mut error = None; let response = match client.request(req).await { Ok(response) => response, Err(e) => { error = Some(format!("{:?}", e)); Response::builder() .status(503) .body("503 Service Unavailable".into()) .expect("infallible response") } }; if let Some(e) = error { warn!("{}", e); } Ok(response) } fn get_opt() -> Result { let env = |name| std::env::var(name).map_err(|e| format!("{}: {}", name, e)); Ok(Opt { proxy: env("PROXY_HOST")?, realm: env("AUTH_REALM")?, gamma: env("GAMMA_HOST")?, rules: std::env::vars() .filter(|(name, _)| name.starts_with("AUTH_RULE_")) .map(|(name, rule)| Rule::parse(&rule).map(|rule| (name, rule))) .collect::>()?, }) } enum Validation { /// Validataion succeeded, client pay proceed Allowed, /// The client is not allowed to proceed, even if it provided credentials NotAllowed, /// The client should provide credentials and try again RequiresLogin, } fn validate(opt: &Opt, req: &Request, login: &Option) -> Validation { // filter irrelevant rules let mut applicable_rules: Vec<_> = opt .rules .iter() .filter(|(_, rule)| { let req_path: PathBuf = req.uri().path().into(); req_path.starts_with(&rule.path) }) .filter(|(_, rule)| match rule.method { Method::Any => true, Method::GET => req.method() == hyper::Method::GET, Method::POST => req.method() == hyper::Method::POST, Method::PUT => req.method() == hyper::Method::PUT, Method::DELETE => req.method() == hyper::Method::DELETE, Method::HEAD => req.method() == hyper::Method::HEAD, Method::OPTIONS => req.method() == hyper::Method::OPTIONS, Method::CONNECT => req.method() == hyper::Method::CONNECT, Method::PATCH => req.method() == hyper::Method::PATCH, Method::TRACE => req.method() == hyper::Method::TRACE, }) .collect(); // find the most relevant rule applicable_rules.sort_by(|&(a_name, a), &(b_name, b)| { // example priorities: // 1. /foo/bar/baz/ POST // 2. /foo/bar/baz/ Any // 3. /foo/bar Any let a_len = a.path.components().count(); let b_len = b.path.components().count(); a_len.cmp(&b_len).then_with(|| match (a.method, b.method) { (ma, mb) if ma == mb => panic!("conflicting rules, {} and {}", a_name, b_name), (Method::Any, _) => Ordering::Less, (_, Method::Any) => Ordering::Greater, _ => unreachable!("Rules for different methods can't conflict"), }) }); debug!("list of applicable rules: {:?}", applicable_rules); let (name, rule) = match applicable_rules.last() { Some(&last) => last, // No rules exist, so we default to not allowed None => return Validation::NotAllowed, }; info!("checking connecting against rule {}", name); fn check_logged_in(login: &Option) -> Validation { if login.is_some() { Validation::Allowed } else { Validation::RequiresLogin } } fn check_is_group(group: &str, login: &Option) -> Validation { match login.as_ref() { Some(user) => { info!("check_is_group", { user: user.username.as_str(), group: group }); let mut groups = user .groups .iter() .flat_map(|group| [&group.name, &group.super_group.name]); if groups.find(|&user_group| user_group == group).is_some() { Validation::Allowed } else { Validation::NotAllowed } } None => Validation::RequiresLogin, } } fn check_permission(permission: &Permission, login: &Option) -> Validation { let recurse = |perm| check_permission(perm, login); match permission { Permission::AllowAll => Validation::Allowed, Permission::AnyUser => check_logged_in(login), Permission::Group(group) => check_is_group(&group, login), Permission::And(a, b) => match (recurse(a), recurse(b)) { (Validation::Allowed, Validation::Allowed) => Validation::Allowed, (Validation::NotAllowed, _) | (_, Validation::NotAllowed) => Validation::NotAllowed, (Validation::RequiresLogin, _) | (_, Validation::RequiresLogin) => { Validation::RequiresLogin } }, Permission::Or(a, b) => match (recurse(a), recurse(b)) { (Validation::NotAllowed, Validation::NotAllowed) => Validation::NotAllowed, (Validation::Allowed, _) | (_, Validation::Allowed) => Validation::Allowed, (Validation::RequiresLogin, _) | (_, Validation::RequiresLogin) => { Validation::RequiresLogin } }, Permission::Not(rule) => match recurse(rule) { Validation::Allowed => Validation::NotAllowed, Validation::NotAllowed | Validation::RequiresLogin => Validation::Allowed, }, } } // validate request against the rule check_permission(&rule.permission, login) } #[tokio::main] async fn main() { //femme::start(); femme::with_level(femme::LevelFilter::Debug); let opt: Opt = match get_opt() { Ok(opt) => opt, Err(e) => { error!("{}", e); std::process::exit(1); } }; info!("{:#?}", opt); let cache = UserCache::new(); let state: &'static State = Box::leak(Box::new(State { opt, cache })); let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); info!("Listening on 0.0.0.0:3000"); let make_svc = make_service_fn(|_conn| async move { Ok::<_, Error>(service_fn(move |r| proxy_pass(r, state))) }); let server = Server::bind(&addr).serve(make_svc); if let Err(e) = server.await { error!("server error: {}", e); } }