325 lines
13 KiB
Rust
325 lines
13 KiB
Rust
use std::collections::HashSet;
|
|
use std::future::Future;
|
|
use std::ops::Deref;
|
|
use std::pin::Pin;
|
|
use std::rc::Rc;
|
|
|
|
use crate::constants::inner_constants::ENVIRONMENT_SERVICE;
|
|
use crate::models::user::User;
|
|
use crate::DbPool;
|
|
use actix::fut::ok;
|
|
use actix_web::body::{EitherBody, MessageBody};
|
|
use actix_web::error::{ErrorForbidden, ErrorUnauthorized};
|
|
use actix_web::{
|
|
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
|
|
web, Error, HttpMessage,
|
|
};
|
|
use base64::engine::general_purpose;
|
|
use base64::Engine;
|
|
|
|
use futures_util::future::{LocalBoxFuture, Ready};
|
|
use futures_util::FutureExt;
|
|
use jsonwebtoken::jwk::Jwk;
|
|
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
|
use log::info;
|
|
use serde_json::Value;
|
|
use sha256::digest;
|
|
|
|
pub struct AuthFilter {}
|
|
|
|
impl AuthFilter {
|
|
pub fn new() -> Self {
|
|
AuthFilter {}
|
|
}
|
|
}
|
|
|
|
#[derive(Default)]
|
|
pub struct AuthFilterMiddleware<S> {
|
|
service: Rc<S>,
|
|
}
|
|
|
|
impl<S, B> Transform<S, ServiceRequest> for AuthFilter
|
|
where
|
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
|
|
S::Future: 'static,
|
|
B: MessageBody + 'static,
|
|
{
|
|
type Response = ServiceResponse<EitherBody<B>>;
|
|
type Error = Error;
|
|
type Transform = AuthFilterMiddleware<S>;
|
|
type InitError = ();
|
|
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
|
|
|
fn new_transform(&self, service: S) -> Self::Future {
|
|
ok(AuthFilterMiddleware {
|
|
service: Rc::new(service),
|
|
})
|
|
}
|
|
}
|
|
|
|
impl<S, B> Service<ServiceRequest> for AuthFilterMiddleware<S>
|
|
where
|
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
|
|
S::Future: 'static,
|
|
B: MessageBody + 'static,
|
|
{
|
|
type Response = ServiceResponse<EitherBody<B>>;
|
|
type Error = Error;
|
|
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
|
|
|
forward_ready!(service);
|
|
|
|
fn call(&self, req: ServiceRequest) -> Self::Future {
|
|
if ENVIRONMENT_SERVICE.get().unwrap().http_basic {
|
|
self.handle_basic_auth(req)
|
|
} else if ENVIRONMENT_SERVICE.get().unwrap().oidc_configured {
|
|
self.handle_oidc_auth(req)
|
|
} else if ENVIRONMENT_SERVICE.get().unwrap().reverse_proxy {
|
|
self.handle_proxy_auth(req)
|
|
} else {
|
|
// It can only be no auth
|
|
self.handle_no_auth(req)
|
|
}
|
|
}
|
|
}
|
|
|
|
type MyFuture<B, Error> =
|
|
Pin<Box<dyn Future<Output = Result<ServiceResponse<EitherBody<B>>, Error>>>>;
|
|
|
|
impl<S, B> AuthFilterMiddleware<S>
|
|
where
|
|
B: 'static + MessageBody,
|
|
S: 'static + Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
|
S::Future: 'static,
|
|
{
|
|
fn handle_basic_auth(&self, req: ServiceRequest) -> MyFuture<B, Error> {
|
|
let env_service = ENVIRONMENT_SERVICE.get().unwrap();
|
|
let opt_auth_header = req.headers().get("Authorization");
|
|
|
|
match opt_auth_header {
|
|
Some(header) => match header.to_str() {
|
|
Ok(auth) => {
|
|
let (username, password) = AuthFilter::extract_basic_auth(auth);
|
|
let res = req.app_data::<web::Data<DbPool>>().unwrap();
|
|
let found_user =
|
|
User::find_by_username(username.as_str(), &mut res.get().unwrap());
|
|
|
|
if found_user.is_err() {
|
|
return Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body()));
|
|
}
|
|
let unwrapped_user = found_user.unwrap();
|
|
|
|
if let Some(admin_username) = env_service.username.clone() {
|
|
if unwrapped_user.username.clone() == admin_username {
|
|
return match env_service.password.is_some()
|
|
&& digest(password) == env_service.password.clone().unwrap()
|
|
{
|
|
true => {
|
|
req.extensions_mut().insert(unwrapped_user);
|
|
let service = Rc::clone(&self.service);
|
|
async move {
|
|
service.call(req).await.map(|res| res.map_into_left_body())
|
|
}
|
|
.boxed_local()
|
|
}
|
|
false => Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body())),
|
|
};
|
|
}
|
|
}
|
|
|
|
if unwrapped_user.password.clone().unwrap() == digest(password) {
|
|
req.extensions_mut().insert(unwrapped_user);
|
|
let service = Rc::clone(&self.service);
|
|
async move { service.call(req).await.map(|res| res.map_into_left_body()) }
|
|
.boxed_local()
|
|
} else {
|
|
Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body()))
|
|
}
|
|
}
|
|
Err(_) => Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body())),
|
|
},
|
|
None => Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body())),
|
|
}
|
|
}
|
|
|
|
fn handle_oidc_auth(&self, req: ServiceRequest) -> MyFuture<B, Error> {
|
|
let token_res = req.headers().get("Authorization").unwrap().to_str();
|
|
if token_res.is_err() {
|
|
return Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body()));
|
|
}
|
|
let token = token_res.unwrap().replace("Bearer ", "");
|
|
|
|
let jwk = req.app_data::<web::Data<Option<Jwk>>>().cloned().unwrap();
|
|
|
|
// Create a DecodingKey from a PEM-encoded RSA string
|
|
|
|
let key = DecodingKey::from_jwk(&jwk.as_ref().clone().unwrap()).unwrap();
|
|
let mut validation = Validation::new(Algorithm::RS256);
|
|
validation.aud = Some(
|
|
req.app_data::<web::Data<HashSet<String>>>()
|
|
.unwrap()
|
|
.clone()
|
|
.into_inner()
|
|
.deref()
|
|
.clone(),
|
|
);
|
|
|
|
return match decode::<Value>(&token, &key, &validation) {
|
|
Ok(decoded) => {
|
|
let username = decoded
|
|
.claims
|
|
.get("preferred_username")
|
|
.unwrap()
|
|
.as_str()
|
|
.unwrap();
|
|
let pool = req.app_data::<web::Data<DbPool>>().cloned().unwrap();
|
|
let found_user = User::find_by_username(username, &mut pool.get().unwrap());
|
|
let service = Rc::clone(&self.service);
|
|
|
|
match found_user {
|
|
Ok(user) => {
|
|
req.extensions_mut().insert(user);
|
|
async move { service.call(req).await.map(|res| res.map_into_left_body()) }
|
|
.boxed_local()
|
|
}
|
|
Err(_) => {
|
|
// User is authenticated so we can onboard him if he is new
|
|
let user = User::insert_user(
|
|
&mut User {
|
|
id: 0,
|
|
username: decoded
|
|
.claims
|
|
.get("preferred_username")
|
|
.unwrap()
|
|
.as_str()
|
|
.unwrap()
|
|
.to_string(),
|
|
role: "user".to_string(),
|
|
password: None,
|
|
explicit_consent: false,
|
|
created_at: chrono::Utc::now().naive_utc(),
|
|
api_key: None,
|
|
},
|
|
&mut pool.get().unwrap(),
|
|
)
|
|
.expect("Error inserting user");
|
|
req.extensions_mut().insert(user);
|
|
async move { service.call(req).await.map(|res| res.map_into_left_body()) }
|
|
.boxed_local()
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
info!("Error decoding token: {:?}", e);
|
|
Box::pin(ok(req
|
|
.error_response(ErrorForbidden("Forbidden"))
|
|
.map_into_right_body()))
|
|
}
|
|
};
|
|
}
|
|
|
|
fn handle_no_auth(&self, req: ServiceRequest) -> MyFuture<B, Error> {
|
|
let user = User::create_standard_admin_user();
|
|
req.extensions_mut().insert(user);
|
|
let service = Rc::clone(&self.service);
|
|
async move { service.call(req).await.map(|res| res.map_into_left_body()) }.boxed_local()
|
|
}
|
|
|
|
fn handle_proxy_auth(&self, req: ServiceRequest) -> MyFuture<B, Error> {
|
|
let config = ENVIRONMENT_SERVICE
|
|
.get()
|
|
.unwrap()
|
|
.reverse_proxy_config
|
|
.clone()
|
|
.unwrap();
|
|
|
|
let header_val = req.headers().get(config.header_name);
|
|
|
|
if let Some(header_val) = header_val {
|
|
let token_res = header_val.to_str();
|
|
return match token_res {
|
|
Ok(token) => {
|
|
let pool = req.app_data::<web::Data<DbPool>>().cloned().unwrap();
|
|
let found_user = User::find_by_username(token, &mut pool.get().unwrap());
|
|
let service = Rc::clone(&self.service);
|
|
|
|
return match found_user {
|
|
Ok(user) => {
|
|
req.extensions_mut().insert(user);
|
|
return async move {
|
|
service.call(req).await.map(|res| res.map_into_left_body())
|
|
}
|
|
.boxed_local();
|
|
}
|
|
Err(_) => {
|
|
if config.auto_sign_up {
|
|
let user = User::insert_user(
|
|
&mut User {
|
|
id: 0,
|
|
username: token.to_string(),
|
|
role: "user".to_string(),
|
|
password: None,
|
|
explicit_consent: false,
|
|
created_at: chrono::Utc::now().naive_utc(),
|
|
api_key: None,
|
|
},
|
|
&mut pool.get().unwrap(),
|
|
)
|
|
.expect("Error inserting user");
|
|
req.extensions_mut().insert(user);
|
|
return async move {
|
|
service.call(req).await.map(|res| res.map_into_left_body())
|
|
}
|
|
.boxed_local();
|
|
} else {
|
|
Box::pin(ok(req
|
|
.error_response(ErrorForbidden("Forbidden"))
|
|
.map_into_right_body()))
|
|
}
|
|
}
|
|
};
|
|
}
|
|
Err(_) => Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body())),
|
|
};
|
|
}
|
|
|
|
Box::pin(ok(req
|
|
.error_response(ErrorUnauthorized("Unauthorized"))
|
|
.map_into_right_body()))
|
|
}
|
|
}
|
|
|
|
impl AuthFilter {
|
|
pub fn extract_basic_auth(auth: &str) -> (String, String) {
|
|
let auth = auth.to_string();
|
|
let auth = auth.split(' ').collect::<Vec<&str>>();
|
|
let auth = auth[1];
|
|
let auth = general_purpose::STANDARD.decode(auth).unwrap();
|
|
let auth = String::from_utf8(auth).unwrap();
|
|
let auth = auth.split(':').collect::<Vec<&str>>();
|
|
let username = auth[0];
|
|
let password = auth[1];
|
|
(username.to_string(), password.to_string())
|
|
}
|
|
|
|
pub fn basic_auth_login(rq: String) -> (String, String) {
|
|
let (u, p) = Self::extract_basic_auth(rq.as_str());
|
|
|
|
(u.to_string(), p.to_string())
|
|
}
|
|
}
|