Initial Commit
This commit is contained in:
317
src/main.rs
Normal file
317
src/main.rs
Normal file
@ -0,0 +1,317 @@
|
||||
mod cache;
|
||||
mod gamma;
|
||||
mod rules;
|
||||
|
||||
use compound_error::CompoundError;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Client, Request, Response, Server, Uri};
|
||||
use kv_log_macro::{debug, error, info, warn};
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use cache::UserCache;
|
||||
use gamma::{Credentials, User};
|
||||
use rules::{Method, Permission, Rule};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Opt {
|
||||
/// Example: http://my-server:80/
|
||||
proxy: String,
|
||||
|
||||
/// Example: https://gamma.chalmers.it
|
||||
gamma: String,
|
||||
|
||||
/// Example: "My special place"
|
||||
realm: String,
|
||||
|
||||
rules: HashMap<String, Rule>,
|
||||
}
|
||||
|
||||
#[derive(Debug, CompoundError)]
|
||||
enum Error {
|
||||
Hyper(hyper::Error),
|
||||
IO(std::io::Error),
|
||||
}
|
||||
struct State {
|
||||
opt: Opt,
|
||||
cache: UserCache,
|
||||
}
|
||||
|
||||
async fn proxy_pass(mut req: Request<Body>, state: &State) -> Result<Response<Body>, Error> {
|
||||
let req_uri = req.uri().clone();
|
||||
|
||||
info!("{:#?}", req);
|
||||
|
||||
let unauthorized = |msg| {
|
||||
info!("respoinding with 401 Unauthorized due to {}", msg);
|
||||
Ok(Response::builder()
|
||||
.status(401)
|
||||
.header(
|
||||
"WWW-Authenticate",
|
||||
format!(r#"Basic realm="{}", charset="UTF-8""#, state.opt.realm),
|
||||
)
|
||||
.header("Docker-Distribution-Api-Version", "registry/2.0")
|
||||
.body(Body::from("Unauthorized"))
|
||||
.expect("infallible response"))
|
||||
};
|
||||
|
||||
let forbidden = |msg| {
|
||||
info!("respoinding with 403 Forbidden due to {}", msg);
|
||||
Ok(Response::builder()
|
||||
.status(403)
|
||||
.body(Body::from("Forbidden"))
|
||||
.expect("infallible response"))
|
||||
};
|
||||
|
||||
let login = match req.headers_mut().remove("Authorization") {
|
||||
None => {
|
||||
info!("request received", {
|
||||
uri: format!("{}", req_uri).as_str(),
|
||||
method: req.method().as_str(),
|
||||
provided_auth: false,
|
||||
});
|
||||
None
|
||||
}
|
||||
Some(authorization) => {
|
||||
info!("request received", {
|
||||
uri: format!("{}", req_uri).as_str(),
|
||||
method: req.method().as_str(),
|
||||
provided_auth: true,
|
||||
});
|
||||
|
||||
let credentials = authorization
|
||||
.to_str()
|
||||
.ok()
|
||||
.and_then(|s| s.strip_prefix("Basic "))
|
||||
.and_then(|s| base64::decode(s).ok())
|
||||
.and_then(|b| String::from_utf8(b).ok());
|
||||
|
||||
match credentials.as_ref().and_then(|s| s.split_once(":")) {
|
||||
Some((user, pass)) => {
|
||||
let credentials = Credentials {
|
||||
username: user.to_string(),
|
||||
password: pass.to_string(),
|
||||
};
|
||||
match state.cache.login(&state.opt, &credentials).await {
|
||||
Ok(user) => Some(user),
|
||||
Err(e) => {
|
||||
warn!("{}", e);
|
||||
return unauthorized("invalid login");
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("client did not provide valid a \"Authorization\" header");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match validate(&state.opt, &req, &login) {
|
||||
Validation::Allowed => {
|
||||
info!("success"); /* success! continue to proxy */
|
||||
}
|
||||
Validation::NotAllowed => return forbidden("failed validation"),
|
||||
Validation::RequiresLogin => return unauthorized("not logged in"),
|
||||
}
|
||||
|
||||
let proxy_uri: Uri = state.opt.proxy.parse().expect("proxy uri");
|
||||
let mut new_uri = Uri::builder().authority(proxy_uri.authority().unwrap().clone());
|
||||
new_uri = new_uri.scheme(proxy_uri.scheme_str().unwrap_or("http"));
|
||||
|
||||
if let Some(paq) = req_uri.path_and_query().cloned() {
|
||||
new_uri = new_uri.path_and_query(paq);
|
||||
}
|
||||
|
||||
*req.uri_mut() = new_uri.build().expect("uri");
|
||||
|
||||
let client = Client::new();
|
||||
let mut error = None;
|
||||
let response = match client.request(req).await {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
error = Some(format!("{:?}", e));
|
||||
Response::builder()
|
||||
.status(503)
|
||||
.body("503 Service Unavailable".into())
|
||||
.expect("infallible response")
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(e) = error {
|
||||
warn!("{}", e);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn get_opt() -> Result<Opt, String> {
|
||||
let env = |name| std::env::var(name).map_err(|e| format!("{}: {}", name, e));
|
||||
|
||||
Ok(Opt {
|
||||
proxy: env("PROXY_HOST")?,
|
||||
realm: env("AUTH_REALM")?,
|
||||
gamma: env("GAMMA_HOST")?,
|
||||
rules: std::env::vars()
|
||||
.filter(|(name, _)| name.starts_with("AUTH_RULE_"))
|
||||
.map(|(name, rule)| Rule::parse(&rule).map(|rule| (name, rule)))
|
||||
.collect::<Result<_, _>>()?,
|
||||
})
|
||||
}
|
||||
|
||||
enum Validation {
|
||||
/// Validataion succeeded, client pay proceed
|
||||
Allowed,
|
||||
|
||||
/// The client is not allowed to proceed, even if it provided credentials
|
||||
NotAllowed,
|
||||
|
||||
/// The client should provide credentials and try again
|
||||
RequiresLogin,
|
||||
}
|
||||
|
||||
fn validate(opt: &Opt, req: &Request<Body>, login: &Option<User>) -> Validation {
|
||||
// filter irrelevant rules
|
||||
let mut applicable_rules: Vec<_> = opt
|
||||
.rules
|
||||
.iter()
|
||||
.filter(|(_, rule)| {
|
||||
let req_path: PathBuf = req.uri().path().into();
|
||||
req_path.starts_with(&rule.path)
|
||||
})
|
||||
.filter(|(_, rule)| match rule.method {
|
||||
Method::Any => true,
|
||||
Method::GET => req.method() == hyper::Method::GET,
|
||||
Method::POST => req.method() == hyper::Method::POST,
|
||||
Method::PUT => req.method() == hyper::Method::PUT,
|
||||
Method::DELETE => req.method() == hyper::Method::DELETE,
|
||||
Method::HEAD => req.method() == hyper::Method::HEAD,
|
||||
Method::OPTIONS => req.method() == hyper::Method::OPTIONS,
|
||||
Method::CONNECT => req.method() == hyper::Method::CONNECT,
|
||||
Method::PATCH => req.method() == hyper::Method::PATCH,
|
||||
Method::TRACE => req.method() == hyper::Method::TRACE,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// find the most relevant rule
|
||||
applicable_rules.sort_by(|&(a_name, a), &(b_name, b)| {
|
||||
// example priorities:
|
||||
// 1. /foo/bar/baz/ POST
|
||||
// 2. /foo/bar/baz/ Any
|
||||
// 3. /foo/bar Any
|
||||
|
||||
let a_len = a.path.components().count();
|
||||
let b_len = b.path.components().count();
|
||||
|
||||
a_len.cmp(&b_len).then_with(|| match (a.method, b.method) {
|
||||
(ma, mb) if ma == mb => panic!("conflicting rules, {} and {}", a_name, b_name),
|
||||
(Method::Any, _) => Ordering::Less,
|
||||
(_, Method::Any) => Ordering::Greater,
|
||||
_ => unreachable!("Rules for different methods can't conflict"),
|
||||
})
|
||||
});
|
||||
|
||||
debug!("list of applicable rules: {:?}", applicable_rules);
|
||||
|
||||
let (name, rule) = match applicable_rules.last() {
|
||||
Some(&last) => last,
|
||||
|
||||
// No rules exist, so we default to not allowed
|
||||
None => return Validation::NotAllowed,
|
||||
};
|
||||
|
||||
info!("checking connecting against rule {}", name);
|
||||
|
||||
fn check_logged_in(login: &Option<User>) -> Validation {
|
||||
if login.is_some() {
|
||||
Validation::Allowed
|
||||
} else {
|
||||
Validation::RequiresLogin
|
||||
}
|
||||
}
|
||||
|
||||
fn check_is_group(group: &str, login: &Option<User>) -> Validation {
|
||||
match login.as_ref() {
|
||||
Some(user) => {
|
||||
info!("check_is_group", { user: user.username.as_str(), group: group });
|
||||
let mut groups = user
|
||||
.groups
|
||||
.iter()
|
||||
.flat_map(|group| [&group.name, &group.super_group.name]);
|
||||
|
||||
if groups.find(|&user_group| user_group == group).is_some() {
|
||||
Validation::Allowed
|
||||
} else {
|
||||
Validation::NotAllowed
|
||||
}
|
||||
}
|
||||
None => Validation::RequiresLogin,
|
||||
}
|
||||
}
|
||||
|
||||
fn check_permission(permission: &Permission, login: &Option<User>) -> Validation {
|
||||
let recurse = |perm| check_permission(perm, login);
|
||||
|
||||
match permission {
|
||||
Permission::AllowAll => Validation::Allowed,
|
||||
Permission::AnyUser => check_logged_in(login),
|
||||
Permission::Group(group) => check_is_group(&group, login),
|
||||
Permission::And(a, b) => match (recurse(a), recurse(b)) {
|
||||
(Validation::Allowed, Validation::Allowed) => Validation::Allowed,
|
||||
(Validation::NotAllowed, _) | (_, Validation::NotAllowed) => Validation::NotAllowed,
|
||||
(Validation::RequiresLogin, _) | (_, Validation::RequiresLogin) => {
|
||||
Validation::RequiresLogin
|
||||
}
|
||||
},
|
||||
Permission::Or(a, b) => match (recurse(a), recurse(b)) {
|
||||
(Validation::NotAllowed, Validation::NotAllowed) => Validation::NotAllowed,
|
||||
(Validation::Allowed, _) | (_, Validation::Allowed) => Validation::Allowed,
|
||||
(Validation::RequiresLogin, _) | (_, Validation::RequiresLogin) => {
|
||||
Validation::RequiresLogin
|
||||
}
|
||||
},
|
||||
Permission::Not(rule) => match recurse(rule) {
|
||||
Validation::Allowed => Validation::NotAllowed,
|
||||
Validation::NotAllowed | Validation::RequiresLogin => Validation::Allowed,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// validate request against the rule
|
||||
check_permission(&rule.permission, login)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
//femme::start();
|
||||
femme::with_level(femme::LevelFilter::Debug);
|
||||
|
||||
let opt: Opt = match get_opt() {
|
||||
Ok(opt) => opt,
|
||||
Err(e) => {
|
||||
error!("{}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
info!("{:#?}", opt);
|
||||
|
||||
let cache = UserCache::new();
|
||||
|
||||
let state: &'static State = Box::leak(Box::new(State { opt, cache }));
|
||||
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
||||
info!("Listening on 0.0.0.0:3000");
|
||||
|
||||
let make_svc = make_service_fn(|_conn| async move {
|
||||
Ok::<_, Error>(service_fn(move |r| proxy_pass(r, state)))
|
||||
});
|
||||
|
||||
let server = Server::bind(&addr).serve(make_svc);
|
||||
|
||||
if let Err(e) = server.await {
|
||||
error!("server error: {}", e);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user