318 lines
10 KiB
Rust
318 lines
10 KiB
Rust
mod cache;
|
|
mod gamma;
|
|
mod rules;
|
|
|
|
use compound_error::CompoundError;
|
|
use hyper::service::{make_service_fn, service_fn};
|
|
use hyper::{Body, Client, Request, Response, Server, Uri};
|
|
use kv_log_macro::{debug, error, info, warn};
|
|
use std::cmp::Ordering;
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::path::PathBuf;
|
|
|
|
use cache::UserCache;
|
|
use gamma::{Credentials, User};
|
|
use rules::{Method, Permission, Rule};
|
|
|
|
#[derive(Debug)]
|
|
struct Opt {
|
|
/// Example: http://my-server:80/
|
|
proxy: String,
|
|
|
|
/// Example: https://gamma.chalmers.it
|
|
gamma: String,
|
|
|
|
/// Example: "My special place"
|
|
realm: String,
|
|
|
|
rules: HashMap<String, Rule>,
|
|
}
|
|
|
|
#[derive(Debug, CompoundError)]
|
|
enum Error {
|
|
Hyper(hyper::Error),
|
|
IO(std::io::Error),
|
|
}
|
|
struct State {
|
|
opt: Opt,
|
|
cache: UserCache,
|
|
}
|
|
|
|
async fn proxy_pass(mut req: Request<Body>, state: &State) -> Result<Response<Body>, Error> {
|
|
let req_uri = req.uri().clone();
|
|
|
|
info!("{:#?}", req);
|
|
|
|
let unauthorized = |msg| {
|
|
info!("respoinding with 401 Unauthorized due to {}", msg);
|
|
Ok(Response::builder()
|
|
.status(401)
|
|
.header(
|
|
"WWW-Authenticate",
|
|
format!(r#"Basic realm="{}", charset="UTF-8""#, state.opt.realm),
|
|
)
|
|
.header("Docker-Distribution-Api-Version", "registry/2.0")
|
|
.body(Body::from("Unauthorized"))
|
|
.expect("infallible response"))
|
|
};
|
|
|
|
let forbidden = |msg| {
|
|
info!("respoinding with 403 Forbidden due to {}", msg);
|
|
Ok(Response::builder()
|
|
.status(403)
|
|
.body(Body::from("Forbidden"))
|
|
.expect("infallible response"))
|
|
};
|
|
|
|
let login = match req.headers_mut().remove("Authorization") {
|
|
None => {
|
|
info!("request received", {
|
|
uri: format!("{}", req_uri).as_str(),
|
|
method: req.method().as_str(),
|
|
provided_auth: false,
|
|
});
|
|
None
|
|
}
|
|
Some(authorization) => {
|
|
info!("request received", {
|
|
uri: format!("{}", req_uri).as_str(),
|
|
method: req.method().as_str(),
|
|
provided_auth: true,
|
|
});
|
|
|
|
let credentials = authorization
|
|
.to_str()
|
|
.ok()
|
|
.and_then(|s| s.strip_prefix("Basic "))
|
|
.and_then(|s| base64::decode(s).ok())
|
|
.and_then(|b| String::from_utf8(b).ok());
|
|
|
|
match credentials.as_ref().and_then(|s| s.split_once(":")) {
|
|
Some((user, pass)) => {
|
|
let credentials = Credentials {
|
|
username: user.to_string(),
|
|
password: pass.to_string(),
|
|
};
|
|
match state.cache.login(&state.opt, &credentials).await {
|
|
Ok(user) => Some(user),
|
|
Err(e) => {
|
|
warn!("{}", e);
|
|
return unauthorized("invalid login");
|
|
}
|
|
}
|
|
}
|
|
None => {
|
|
warn!("client did not provide valid a \"Authorization\" header");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
match validate(&state.opt, &req, &login) {
|
|
Validation::Allowed => {
|
|
info!("success"); /* success! continue to proxy */
|
|
}
|
|
Validation::NotAllowed => return forbidden("failed validation"),
|
|
Validation::RequiresLogin => return unauthorized("not logged in"),
|
|
}
|
|
|
|
let proxy_uri: Uri = state.opt.proxy.parse().expect("proxy uri");
|
|
let mut new_uri = Uri::builder().authority(proxy_uri.authority().unwrap().clone());
|
|
new_uri = new_uri.scheme(proxy_uri.scheme_str().unwrap_or("http"));
|
|
|
|
if let Some(paq) = req_uri.path_and_query().cloned() {
|
|
new_uri = new_uri.path_and_query(paq);
|
|
}
|
|
|
|
*req.uri_mut() = new_uri.build().expect("uri");
|
|
|
|
let client = Client::new();
|
|
let mut error = None;
|
|
let response = match client.request(req).await {
|
|
Ok(response) => response,
|
|
Err(e) => {
|
|
error = Some(format!("{:?}", e));
|
|
Response::builder()
|
|
.status(503)
|
|
.body("503 Service Unavailable".into())
|
|
.expect("infallible response")
|
|
}
|
|
};
|
|
|
|
if let Some(e) = error {
|
|
warn!("{}", e);
|
|
}
|
|
|
|
Ok(response)
|
|
}
|
|
|
|
fn get_opt() -> Result<Opt, String> {
|
|
let env = |name| std::env::var(name).map_err(|e| format!("{}: {}", name, e));
|
|
|
|
Ok(Opt {
|
|
proxy: env("PROXY_HOST")?,
|
|
realm: env("AUTH_REALM")?,
|
|
gamma: env("GAMMA_HOST")?,
|
|
rules: std::env::vars()
|
|
.filter(|(name, _)| name.starts_with("AUTH_RULE_"))
|
|
.map(|(name, rule)| Rule::parse(&rule).map(|rule| (name, rule)))
|
|
.collect::<Result<_, _>>()?,
|
|
})
|
|
}
|
|
|
|
/// Outcome of checking a request against the access rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Validation {
    /// Validation succeeded; the client may proceed.
    Allowed,

    /// The client is not allowed to proceed, even if it provided credentials.
    NotAllowed,

    /// The client should provide credentials and try again.
    RequiresLogin,
}
|
|
|
|
fn validate(opt: &Opt, req: &Request<Body>, login: &Option<User>) -> Validation {
|
|
// filter irrelevant rules
|
|
let mut applicable_rules: Vec<_> = opt
|
|
.rules
|
|
.iter()
|
|
.filter(|(_, rule)| {
|
|
let req_path: PathBuf = req.uri().path().into();
|
|
req_path.starts_with(&rule.path)
|
|
})
|
|
.filter(|(_, rule)| match rule.method {
|
|
Method::Any => true,
|
|
Method::GET => req.method() == hyper::Method::GET,
|
|
Method::POST => req.method() == hyper::Method::POST,
|
|
Method::PUT => req.method() == hyper::Method::PUT,
|
|
Method::DELETE => req.method() == hyper::Method::DELETE,
|
|
Method::HEAD => req.method() == hyper::Method::HEAD,
|
|
Method::OPTIONS => req.method() == hyper::Method::OPTIONS,
|
|
Method::CONNECT => req.method() == hyper::Method::CONNECT,
|
|
Method::PATCH => req.method() == hyper::Method::PATCH,
|
|
Method::TRACE => req.method() == hyper::Method::TRACE,
|
|
})
|
|
.collect();
|
|
|
|
// find the most relevant rule
|
|
applicable_rules.sort_by(|&(a_name, a), &(b_name, b)| {
|
|
// example priorities:
|
|
// 1. /foo/bar/baz/ POST
|
|
// 2. /foo/bar/baz/ Any
|
|
// 3. /foo/bar Any
|
|
|
|
let a_len = a.path.components().count();
|
|
let b_len = b.path.components().count();
|
|
|
|
a_len.cmp(&b_len).then_with(|| match (a.method, b.method) {
|
|
(ma, mb) if ma == mb => panic!("conflicting rules, {} and {}", a_name, b_name),
|
|
(Method::Any, _) => Ordering::Less,
|
|
(_, Method::Any) => Ordering::Greater,
|
|
_ => unreachable!("Rules for different methods can't conflict"),
|
|
})
|
|
});
|
|
|
|
debug!("list of applicable rules: {:?}", applicable_rules);
|
|
|
|
let (name, rule) = match applicable_rules.last() {
|
|
Some(&last) => last,
|
|
|
|
// No rules exist, so we default to not allowed
|
|
None => return Validation::NotAllowed,
|
|
};
|
|
|
|
info!("checking connecting against rule {}", name);
|
|
|
|
fn check_logged_in(login: &Option<User>) -> Validation {
|
|
if login.is_some() {
|
|
Validation::Allowed
|
|
} else {
|
|
Validation::RequiresLogin
|
|
}
|
|
}
|
|
|
|
fn check_is_group(group: &str, login: &Option<User>) -> Validation {
|
|
match login.as_ref() {
|
|
Some(user) => {
|
|
info!("check_is_group", { user: user.username.as_str(), group: group });
|
|
let mut groups = user
|
|
.groups
|
|
.iter()
|
|
.flat_map(|group| [&group.name, &group.super_group.name]);
|
|
|
|
if groups.find(|&user_group| user_group == group).is_some() {
|
|
Validation::Allowed
|
|
} else {
|
|
Validation::NotAllowed
|
|
}
|
|
}
|
|
None => Validation::RequiresLogin,
|
|
}
|
|
}
|
|
|
|
fn check_permission(permission: &Permission, login: &Option<User>) -> Validation {
|
|
let recurse = |perm| check_permission(perm, login);
|
|
|
|
match permission {
|
|
Permission::AllowAll => Validation::Allowed,
|
|
Permission::AnyUser => check_logged_in(login),
|
|
Permission::Group(group) => check_is_group(&group, login),
|
|
Permission::And(a, b) => match (recurse(a), recurse(b)) {
|
|
(Validation::Allowed, Validation::Allowed) => Validation::Allowed,
|
|
(Validation::NotAllowed, _) | (_, Validation::NotAllowed) => Validation::NotAllowed,
|
|
(Validation::RequiresLogin, _) | (_, Validation::RequiresLogin) => {
|
|
Validation::RequiresLogin
|
|
}
|
|
},
|
|
Permission::Or(a, b) => match (recurse(a), recurse(b)) {
|
|
(Validation::NotAllowed, Validation::NotAllowed) => Validation::NotAllowed,
|
|
(Validation::Allowed, _) | (_, Validation::Allowed) => Validation::Allowed,
|
|
(Validation::RequiresLogin, _) | (_, Validation::RequiresLogin) => {
|
|
Validation::RequiresLogin
|
|
}
|
|
},
|
|
Permission::Not(rule) => match recurse(rule) {
|
|
Validation::Allowed => Validation::NotAllowed,
|
|
Validation::NotAllowed | Validation::RequiresLogin => Validation::Allowed,
|
|
},
|
|
}
|
|
}
|
|
|
|
// validate request against the rule
|
|
check_permission(&rule.permission, login)
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
//femme::start();
|
|
femme::with_level(femme::LevelFilter::Debug);
|
|
|
|
let opt: Opt = match get_opt() {
|
|
Ok(opt) => opt,
|
|
Err(e) => {
|
|
error!("{}", e);
|
|
std::process::exit(1);
|
|
}
|
|
};
|
|
info!("{:#?}", opt);
|
|
|
|
let cache = UserCache::new();
|
|
|
|
let state: &'static State = Box::leak(Box::new(State { opt, cache }));
|
|
|
|
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
|
info!("Listening on 0.0.0.0:3000");
|
|
|
|
let make_svc = make_service_fn(|_conn| async move {
|
|
Ok::<_, Error>(service_fn(move |r| proxy_pass(r, state)))
|
|
});
|
|
|
|
let server = Server::bind(&addr).serve(make_svc);
|
|
|
|
if let Err(e) = server.await {
|
|
error!("server error: {}", e);
|
|
}
|
|
}
|