Apply Black on all the code base

This commit is contained in:
decentral1se 2019-11-25 17:21:13 +01:00 committed by Alexandre Aubin
parent 6f5daa0c38
commit 54b8cab133
27 changed files with 1171 additions and 875 deletions

View file

@ -1,16 +1,17 @@
# -*- coding: utf-8 -*-
from moulinette.core import init_interface, MoulinetteError, MoulinetteSignals, Moulinette18n
from moulinette.core import (
init_interface,
MoulinetteError,
MoulinetteSignals,
Moulinette18n,
)
from moulinette.globals import init_moulinette_env
__title__ = 'moulinette'
__version__ = '0.1'
__author__ = ['Kload',
'jlebleu',
'titoko',
'beudbeud',
'npze']
__license__ = 'AGPL 3.0'
__title__ = "moulinette"
__version__ = "0.1"
__author__ = ["Kload", "jlebleu", "titoko", "beudbeud", "npze"]
__license__ = "AGPL 3.0"
__credits__ = """
Copyright (C) 2014 YUNOHOST.ORG
@ -28,8 +29,13 @@ __credits__ = """
along with this program; if not, see http://www.gnu.org/licenses
"""
__all__ = [
'init', 'api', 'cli', 'm18n', 'env',
'init_interface', 'MoulinetteError',
"init",
"api",
"cli",
"m18n",
"env",
"init_interface",
"MoulinetteError",
]
@ -40,6 +46,7 @@ m18n = Moulinette18n()
# Package functions
def init(logging_config=None, **kwargs):
"""Package initialization
@ -61,13 +68,15 @@ def init(logging_config=None, **kwargs):
configure_logging(logging_config)
# Add library directory to python path
sys.path.insert(0, init_moulinette_env()['LIB_DIR'])
sys.path.insert(0, init_moulinette_env()["LIB_DIR"])
# Easy access to interfaces
def api(namespaces, host='localhost', port=80, routes={},
use_websocket=True, use_cache=True):
def api(
namespaces, host="localhost", port=80, routes={}, use_websocket=True, use_cache=True
):
"""Web server (API) interface
Run a HTTP server with the moulinette for an API usage.
@ -84,29 +93,33 @@ def api(namespaces, host='localhost', port=80, routes={},
"""
try:
moulinette = init_interface('api',
kwargs={
'routes': routes,
'use_websocket': use_websocket
},
actionsmap={
'namespaces': namespaces,
'use_cache': use_cache
}
moulinette = init_interface(
"api",
kwargs={"routes": routes, "use_websocket": use_websocket},
actionsmap={"namespaces": namespaces, "use_cache": use_cache},
)
moulinette.run(host, port)
except MoulinetteError as e:
import logging
logging.getLogger(namespaces[0]).error(e.strerror)
return e.errno if hasattr(e, "errno") else 1
except KeyboardInterrupt:
import logging
logging.getLogger(namespaces[0]).info(m18n.g('operation_interrupted'))
logging.getLogger(namespaces[0]).info(m18n.g("operation_interrupted"))
return 0
def cli(namespaces, args, use_cache=True, output_as=None,
password=None, timeout=None, parser_kwargs={}):
def cli(
namespaces,
args,
use_cache=True,
output_as=None,
password=None,
timeout=None,
parser_kwargs={},
):
"""Command line interface
Execute an action with the moulinette from the CLI and print its
@ -125,16 +138,18 @@ def cli(namespaces, args, use_cache=True, output_as=None,
"""
try:
moulinette = init_interface('cli',
moulinette = init_interface(
"cli",
actionsmap={
'namespaces': namespaces,
'use_cache': use_cache,
'parser_kwargs': parser_kwargs,
"namespaces": namespaces,
"use_cache": use_cache,
"parser_kwargs": parser_kwargs,
},
)
moulinette.run(args, output_as=output_as, password=password, timeout=timeout)
except MoulinetteError as e:
import logging
logging.getLogger(namespaces[0]).error(e.strerror)
return 1
return 0

View file

@ -12,19 +12,18 @@ from importlib import import_module
from moulinette import m18n, msignals
from moulinette.cache import open_cachefile
from moulinette.globals import init_moulinette_env
from moulinette.core import (MoulinetteError, MoulinetteLock)
from moulinette.interfaces import (
BaseActionsMapParser, GLOBAL_SECTION, TO_RETURN_PROP
)
from moulinette.core import MoulinetteError, MoulinetteLock
from moulinette.interfaces import BaseActionsMapParser, GLOBAL_SECTION, TO_RETURN_PROP
from moulinette.utils.log import start_action_logging
logger = logging.getLogger('moulinette.actionsmap')
logger = logging.getLogger("moulinette.actionsmap")
# Extra parameters ----------------------------------------------------
# Extra parameters definition
class _ExtraParameter(object):
"""
@ -87,7 +86,7 @@ class _ExtraParameter(object):
class CommentParameter(_ExtraParameter):
name = "comment"
skipped_iface = ['api']
skipped_iface = ["api"]
def __call__(self, message, arg_name, arg_value):
return msignals.display(m18n.n(message))
@ -96,12 +95,15 @@ class CommentParameter(_ExtraParameter):
def validate(klass, value, arg_name):
# Deprecated boolean or empty string
if isinstance(value, bool) or (isinstance(value, str) and not value):
logger.warning("expecting a non-empty string for extra parameter '%s' of "
"argument '%s'", klass.name, arg_name)
logger.warning(
"expecting a non-empty string for extra parameter '%s' of "
"argument '%s'",
klass.name,
arg_name,
)
value = arg_name
elif not isinstance(value, str):
raise TypeError("parameter value must be a string, got %r"
% value)
raise TypeError("parameter value must be a string, got %r" % value)
return value
@ -114,8 +116,9 @@ class AskParameter(_ExtraParameter):
when asking the argument value.
"""
name = 'ask'
skipped_iface = ['api']
name = "ask"
skipped_iface = ["api"]
def __call__(self, message, arg_name, arg_value):
if arg_value:
@ -131,12 +134,15 @@ class AskParameter(_ExtraParameter):
def validate(klass, value, arg_name):
# Deprecated boolean or empty string
if isinstance(value, bool) or (isinstance(value, str) and not value):
logger.warning("expecting a non-empty string for extra parameter '%s' of "
"argument '%s'", klass.name, arg_name)
logger.warning(
"expecting a non-empty string for extra parameter '%s' of "
"argument '%s'",
klass.name,
arg_name,
)
value = arg_name
elif not isinstance(value, str):
raise TypeError("parameter value must be a string, got %r"
% value)
raise TypeError("parameter value must be a string, got %r" % value)
return value
@ -149,7 +155,8 @@ class PasswordParameter(AskParameter):
when asking the password.
"""
name = 'password'
name = "password"
def __call__(self, message, arg_name, arg_value):
if arg_value:
@ -171,40 +178,45 @@ class PatternParameter(_ExtraParameter):
the message to display if it doesn't match.
"""
name = 'pattern'
name = "pattern"
def __call__(self, arguments, arg_name, arg_value):
pattern, message = (arguments[0], arguments[1])
# Use temporarly utf-8 encoded value
try:
v = unicode(arg_value, 'utf-8')
v = unicode(arg_value, "utf-8")
except:
v = arg_value
if v and not re.match(pattern, v or '', re.UNICODE):
logger.debug("argument value '%s' for '%s' doesn't match pattern '%s'",
v, arg_name, pattern)
if v and not re.match(pattern, v or "", re.UNICODE):
logger.debug(
"argument value '%s' for '%s' doesn't match pattern '%s'",
v,
arg_name,
pattern,
)
# Attempt to retrieve message translation
msg = m18n.n(message)
if msg == message:
msg = m18n.g(message)
raise MoulinetteError('invalid_argument',
argument=arg_name, error=msg)
raise MoulinetteError("invalid_argument", argument=arg_name, error=msg)
return arg_value
@staticmethod
def validate(value, arg_name):
# Deprecated string type
if isinstance(value, str):
logger.warning("expecting a list as extra parameter 'pattern' of "
"argument '%s'", arg_name)
value = [value, 'pattern_not_match']
logger.warning(
"expecting a list as extra parameter 'pattern' of " "argument '%s'",
arg_name,
)
value = [value, "pattern_not_match"]
elif not isinstance(value, list) or len(value) != 2:
raise TypeError("parameter value must be a list, got %r"
% value)
raise TypeError("parameter value must be a list, got %r" % value)
return value
@ -216,21 +228,19 @@ class RequiredParameter(_ExtraParameter):
The value of this parameter must be a boolean which is set to False by
default.
"""
name = 'required'
name = "required"
def __call__(self, required, arg_name, arg_value):
if required and (arg_value is None or arg_value == ''):
logger.debug("argument '%s' is required",
arg_name)
raise MoulinetteError('argument_required',
argument=arg_name)
if required and (arg_value is None or arg_value == ""):
logger.debug("argument '%s' is required", arg_name)
raise MoulinetteError("argument_required", argument=arg_name)
return arg_value
@staticmethod
def validate(value, arg_name):
if not isinstance(value, bool):
raise TypeError("parameter value must be a list, got %r"
% value)
raise TypeError("parameter value must be a list, got %r" % value)
return value
@ -239,8 +249,13 @@ The list of available extra parameters classes. It will keep to this list
order on argument parsing.
"""
extraparameters_list = [CommentParameter, AskParameter, PasswordParameter,
RequiredParameter, PatternParameter]
extraparameters_list = [
CommentParameter,
AskParameter,
PasswordParameter,
RequiredParameter,
PatternParameter,
]
# Extra parameters argument Parser
@ -265,7 +280,7 @@ class ExtraArgumentParser(object):
if iface in klass.skipped_iface:
continue
self.extra[klass.name] = klass
logger.debug('extra parameter classes loaded: %s', self.extra.keys())
logger.debug("extra parameter classes loaded: %s", self.extra.keys())
def validate(self, arg_name, parameters):
"""
@ -287,9 +302,14 @@ class ExtraArgumentParser(object):
# Validate parameter value
parameters[p] = klass.validate(v, arg_name)
except Exception as e:
logger.error("unable to validate extra parameter '%s' "
"for argument '%s': %s", p, arg_name, e)
raise MoulinetteError('error_see_log')
logger.error(
"unable to validate extra parameter '%s' "
"for argument '%s': %s",
p,
arg_name,
e,
)
raise MoulinetteError("error_see_log")
return parameters
@ -354,12 +374,15 @@ class ExtraArgumentParser(object):
# Main class ----------------------------------------------------------
def ordered_yaml_load(stream):
class OrderedLoader(yaml.Loader):
pass
OrderedLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
lambda loader, node: OrderedDict(loader.construct_pairs(node)))
lambda loader, node: OrderedDict(loader.construct_pairs(node)),
)
return yaml.load(stream, OrderedLoader)
@ -387,16 +410,15 @@ class ActionsMap(object):
"""
def __init__(self, parser_class, namespaces=[], use_cache=True,
parser_kwargs={}):
def __init__(self, parser_class, namespaces=[], use_cache=True, parser_kwargs={}):
if not issubclass(parser_class, BaseActionsMapParser):
raise ValueError("Invalid parser class '%s'" % parser_class.__name__)
self.parser_class = parser_class
self.use_cache = use_cache
moulinette_env = init_moulinette_env()
DATA_DIR = moulinette_env['DATA_DIR']
CACHE_DIR = moulinette_env['CACHE_DIR']
DATA_DIR = moulinette_env["DATA_DIR"]
CACHE_DIR = moulinette_env["CACHE_DIR"]
if len(namespaces) == 0:
namespaces = self.get_namespaces()
@ -406,13 +428,13 @@ class ActionsMap(object):
for n in namespaces:
logger.debug("loading actions map namespace '%s'", n)
actionsmap_yml = '%s/actionsmap/%s.yml' % (DATA_DIR, n)
actionsmap_yml = "%s/actionsmap/%s.yml" % (DATA_DIR, n)
actionsmap_yml_stat = os.stat(actionsmap_yml)
actionsmap_pkl = '%s/actionsmap/%s-%d-%d.pkl' % (
actionsmap_pkl = "%s/actionsmap/%s-%d-%d.pkl" % (
CACHE_DIR,
n,
actionsmap_yml_stat.st_size,
actionsmap_yml_stat.st_mtime
actionsmap_yml_stat.st_mtime,
)
if use_cache and os.path.exists(actionsmap_pkl):
@ -447,16 +469,18 @@ class ActionsMap(object):
# Fetch the configuration for the authenticator module as defined in the actionmap
try:
auth_conf = self.parser.global_conf['authenticator'][auth_profile]
auth_conf = self.parser.global_conf["authenticator"][auth_profile]
except KeyError:
raise ValueError("Unknown authenticator profile '%s'" % auth_profile)
# Load and initialize the authenticator module
try:
mod = import_module('moulinette.authenticators.%s' % auth_conf["vendor"])
mod = import_module("moulinette.authenticators.%s" % auth_conf["vendor"])
except ImportError:
logger.exception("unable to load authenticator vendor '%s'", auth_conf["vendor"])
raise MoulinetteError('error_see_log')
logger.exception(
"unable to load authenticator vendor '%s'", auth_conf["vendor"]
)
raise MoulinetteError("error_see_log")
else:
return mod.Authenticator(**auth_conf)
@ -471,7 +495,7 @@ class ActionsMap(object):
auth = msignals.authenticate(authenticator)
if not auth.is_authenticated:
raise MoulinetteError('authentication_required_long')
raise MoulinetteError("authentication_required_long")
def process(self, args, timeout=None, **kwargs):
"""
@ -492,7 +516,7 @@ class ActionsMap(object):
arguments = vars(self.parser.parse_args(args, **kwargs))
# Retrieve tid and parse arguments with extra parameters
tid = arguments.pop('_tid')
tid = arguments.pop("_tid")
arguments = self.extraparser.parse_args(tid, arguments)
# Return immediately if a value is defined
@ -502,38 +526,55 @@ class ActionsMap(object):
# Retrieve action information
if len(tid) == 4:
namespace, category, subcategory, action = tid
func_name = '%s_%s_%s' % (category, subcategory.replace('-', '_'), action.replace('-', '_'))
full_action_name = "%s.%s.%s.%s" % (namespace, category, subcategory, action)
func_name = "%s_%s_%s" % (
category,
subcategory.replace("-", "_"),
action.replace("-", "_"),
)
full_action_name = "%s.%s.%s.%s" % (
namespace,
category,
subcategory,
action,
)
else:
assert len(tid) == 3
namespace, category, action = tid
subcategory = None
func_name = '%s_%s' % (category, action.replace('-', '_'))
func_name = "%s_%s" % (category, action.replace("-", "_"))
full_action_name = "%s.%s.%s" % (namespace, category, action)
# Lock the moulinette for the namespace
with MoulinetteLock(namespace, timeout):
start = time()
try:
mod = __import__('%s.%s' % (namespace, category),
globals=globals(), level=0,
fromlist=[func_name])
logger.debug('loading python module %s took %.3fs',
'%s.%s' % (namespace, category), time() - start)
mod = __import__(
"%s.%s" % (namespace, category),
globals=globals(),
level=0,
fromlist=[func_name],
)
logger.debug(
"loading python module %s took %.3fs",
"%s.%s" % (namespace, category),
time() - start,
)
func = getattr(mod, func_name)
except (AttributeError, ImportError):
logger.exception("unable to load function %s.%s",
namespace, func_name)
raise MoulinetteError('error_see_log')
logger.exception("unable to load function %s.%s", namespace, func_name)
raise MoulinetteError("error_see_log")
else:
log_id = start_action_logging()
if logger.isEnabledFor(logging.DEBUG):
# Log arguments in debug mode only for safety reasons
logger.info('processing action [%s]: %s with args=%s',
log_id, full_action_name, arguments)
logger.info(
"processing action [%s]: %s with args=%s",
log_id,
full_action_name,
arguments,
)
else:
logger.info('processing action [%s]: %s',
log_id, full_action_name)
logger.info("processing action [%s]: %s", log_id, full_action_name)
# Load translation and process the action
m18n.load_namespace(namespace)
@ -542,8 +583,7 @@ class ActionsMap(object):
return func(**arguments)
finally:
stop = time()
logger.debug('action [%s] executed in %.3fs',
log_id, stop - start)
logger.debug("action [%s] executed in %.3fs", log_id, stop - start)
@staticmethod
def get_namespaces():
@ -557,10 +597,10 @@ class ActionsMap(object):
namespaces = []
moulinette_env = init_moulinette_env()
DATA_DIR = moulinette_env['DATA_DIR']
DATA_DIR = moulinette_env["DATA_DIR"]
for f in os.listdir('%s/actionsmap' % DATA_DIR):
if f.endswith('.yml'):
for f in os.listdir("%s/actionsmap" % DATA_DIR):
if f.endswith(".yml"):
namespaces.append(f[:-4])
return namespaces
@ -577,8 +617,8 @@ class ActionsMap(object):
"""
moulinette_env = init_moulinette_env()
CACHE_DIR = moulinette_env['CACHE_DIR']
DATA_DIR = moulinette_env['DATA_DIR']
CACHE_DIR = moulinette_env["CACHE_DIR"]
DATA_DIR = moulinette_env["DATA_DIR"]
actionsmaps = {}
if not namespaces:
@ -589,23 +629,23 @@ class ActionsMap(object):
logger.debug("generating cache for actions map namespace '%s'", n)
# Read actions map from yaml file
am_file = '%s/actionsmap/%s.yml' % (DATA_DIR, n)
with open(am_file, 'r') as f:
am_file = "%s/actionsmap/%s.yml" % (DATA_DIR, n)
with open(am_file, "r") as f:
actionsmaps[n] = ordered_yaml_load(f)
# at installation, cachedir might not exists
if os.path.exists('%s/actionsmap/' % CACHE_DIR):
if os.path.exists("%s/actionsmap/" % CACHE_DIR):
# clean old cached files
for i in os.listdir('%s/actionsmap/' % CACHE_DIR):
for i in os.listdir("%s/actionsmap/" % CACHE_DIR):
if i.endswith(".pkl"):
os.remove('%s/actionsmap/%s' % (CACHE_DIR, i))
os.remove("%s/actionsmap/%s" % (CACHE_DIR, i))
# Cache actions map into pickle file
am_file_stat = os.stat(am_file)
pkl = '%s-%d-%d.pkl' % (n, am_file_stat.st_size, am_file_stat.st_mtime)
pkl = "%s-%d-%d.pkl" % (n, am_file_stat.st_size, am_file_stat.st_mtime)
with open_cachefile(pkl, 'w', subdir='actionsmap') as f:
with open_cachefile(pkl, "w", subdir="actionsmap") as f:
pickle.dump(actionsmaps[n], f)
return actionsmaps
@ -646,86 +686,97 @@ class ActionsMap(object):
# * actionsmap is the actual actionsmap that we care about
for namespace, actionsmap in actionsmaps.items():
# Retrieve global parameters
_global = actionsmap.pop('_global', {})
_global = actionsmap.pop("_global", {})
# Set the global configuration to use for the parser.
top_parser.set_global_conf(_global['configuration'])
top_parser.set_global_conf(_global["configuration"])
if top_parser.has_global_parser():
top_parser.add_global_arguments(_global['arguments'])
top_parser.add_global_arguments(_global["arguments"])
# category_name is stuff like "user", "domain", "hooks"...
# category_values is the values of this category (like actions)
for category_name, category_values in actionsmap.items():
if "actions" in category_values:
actions = category_values.pop('actions')
actions = category_values.pop("actions")
else:
actions = {}
if "subcategories" in category_values:
subcategories = category_values.pop('subcategories')
subcategories = category_values.pop("subcategories")
else:
subcategories = {}
# Get category parser
category_parser = top_parser.add_category_parser(category_name,
**category_values)
category_parser = top_parser.add_category_parser(
category_name, **category_values
)
# action_name is like "list" of "domain list"
# action_options are the values
for action_name, action_options in actions.items():
arguments = action_options.pop('arguments', {})
arguments = action_options.pop("arguments", {})
tid = (namespace, category_name, action_name)
# Get action parser
action_parser = category_parser.add_action_parser(action_name,
tid,
**action_options)
action_parser = category_parser.add_action_parser(
action_name, tid, **action_options
)
if action_parser is None: # No parser for the action
continue
# Store action identifier and add arguments
action_parser.set_defaults(_tid=tid)
action_parser.add_arguments(arguments,
extraparser=self.extraparser,
format_arg_names=top_parser.format_arg_names,
validate_extra=validate_extra)
action_parser.add_arguments(
arguments,
extraparser=self.extraparser,
format_arg_names=top_parser.format_arg_names,
validate_extra=validate_extra,
)
if 'configuration' in action_options:
category_parser.set_conf(tid, action_options['configuration'])
if "configuration" in action_options:
category_parser.set_conf(tid, action_options["configuration"])
# subcategory_name is like "cert" in "domain cert status"
# subcategory_values is the values of this subcategory (like actions)
for subcategory_name, subcategory_values in subcategories.items():
actions = subcategory_values.pop('actions')
actions = subcategory_values.pop("actions")
# Get subcategory parser
subcategory_parser = category_parser.add_subcategory_parser(subcategory_name, **subcategory_values)
subcategory_parser = category_parser.add_subcategory_parser(
subcategory_name, **subcategory_values
)
# action_name is like "status" of "domain cert status"
# action_options are the values
for action_name, action_options in actions.items():
arguments = action_options.pop('arguments', {})
arguments = action_options.pop("arguments", {})
tid = (namespace, category_name, subcategory_name, action_name)
try:
# Get action parser
action_parser = subcategory_parser.add_action_parser(action_name, tid, **action_options)
action_parser = subcategory_parser.add_action_parser(
action_name, tid, **action_options
)
except AttributeError:
# No parser for the action
continue
# Store action identifier and add arguments
action_parser.set_defaults(_tid=tid)
action_parser.add_arguments(arguments,
extraparser=self.extraparser,
format_arg_names=top_parser.format_arg_names,
validate_extra=validate_extra)
action_parser.add_arguments(
arguments,
extraparser=self.extraparser,
format_arg_names=top_parser.format_arg_names,
validate_extra=validate_extra,
)
if 'configuration' in action_options:
category_parser.set_conf(tid, action_options['configuration'])
if "configuration" in action_options:
category_parser.set_conf(
tid, action_options["configuration"]
)
return top_parser

View file

@ -8,11 +8,12 @@ import hmac
from moulinette.cache import open_cachefile, get_cachedir
from moulinette.core import MoulinetteError
logger = logging.getLogger('moulinette.authenticator')
logger = logging.getLogger("moulinette.authenticator")
# Base Class -----------------------------------------------------------
class BaseAuthenticator(object):
"""Authenticator base representation
@ -59,8 +60,9 @@ class BaseAuthenticator(object):
- password -- A clear text password
"""
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
# Authentication methods
@ -95,9 +97,13 @@ class BaseAuthenticator(object):
except MoulinetteError:
raise
except Exception as e:
logger.exception("authentication (name: '%s', vendor: '%s') fails because '%s'",
self.name, self.vendor, e)
raise MoulinetteError('unable_authenticate')
logger.exception(
"authentication (name: '%s', vendor: '%s') fails because '%s'",
self.name,
self.vendor,
e,
)
raise MoulinetteError("unable_authenticate")
self.is_authenticated = True
@ -108,6 +114,7 @@ class BaseAuthenticator(object):
self._store_session(s_id, s_token)
except Exception as e:
import traceback
traceback.print_exc()
logger.exception("unable to store session because %s", e)
else:
@ -124,9 +131,13 @@ class BaseAuthenticator(object):
except MoulinetteError as e:
raise
except Exception as e:
logger.exception("authentication (name: '%s', vendor: '%s') fails because '%s'",
self.name, self.vendor, e)
raise MoulinetteError('unable_authenticate')
logger.exception(
"authentication (name: '%s', vendor: '%s') fails because '%s'",
self.name,
self.vendor,
e,
)
raise MoulinetteError("unable_authenticate")
else:
self.is_authenticated = True
@ -134,16 +145,17 @@ class BaseAuthenticator(object):
# No credentials given, can't authenticate
#
else:
raise MoulinetteError('unable_authenticate')
raise MoulinetteError("unable_authenticate")
return self
# Private methods
def _open_sessionfile(self, session_id, mode='r'):
def _open_sessionfile(self, session_id, mode="r"):
"""Open a session file for this instance in given mode"""
return open_cachefile('%s.asc' % session_id, mode,
subdir='session/%s' % self.name)
return open_cachefile(
"%s.asc" % session_id, mode, subdir="session/%s" % self.name
)
def _store_session(self, session_id, session_token):
"""Store a session to be able to use it later to reauthenticate"""
@ -151,7 +163,7 @@ class BaseAuthenticator(object):
# We store a hash of the session_id and the session_token (the token is assumed to be secret)
to_hash = "{id}:{token}".format(id=session_id, token=session_token)
hash_ = hashlib.sha256(to_hash).hexdigest()
with self._open_sessionfile(session_id, 'w') as f:
with self._open_sessionfile(session_id, "w") as f:
f.write(hash_)
def _authenticate_session(self, session_id, session_token):
@ -160,11 +172,11 @@ class BaseAuthenticator(object):
# FIXME : shouldn't we also add a check that this session file
# is not too old ? e.g. not older than 24 hours ? idk...
with self._open_sessionfile(session_id, 'r') as f:
with self._open_sessionfile(session_id, "r") as f:
stored_hash = f.read()
except IOError as e:
logger.debug("unable to retrieve session", exc_info=1)
raise MoulinetteError('unable_retrieve_session', exception=e)
raise MoulinetteError("unable_retrieve_session", exception=e)
else:
#
# session_id (or just id) : This is unique id for the current session from the user. Not too important
@ -186,7 +198,7 @@ class BaseAuthenticator(object):
hash_ = hashlib.sha256(to_hash).hexdigest()
if not hmac.compare_digest(hash_, stored_hash):
raise MoulinetteError('invalid_token')
raise MoulinetteError("invalid_token")
else:
return
@ -198,9 +210,9 @@ class BaseAuthenticator(object):
Keyword arguments:
- session_id -- The session id to clean
"""
sessiondir = get_cachedir('session')
sessiondir = get_cachedir("session")
try:
os.remove(os.path.join(sessiondir, self.name, '%s.asc' % session_id))
os.remove(os.path.join(sessiondir, self.name, "%s.asc" % session_id))
except OSError:
pass

View file

@ -4,7 +4,7 @@ import logging
from moulinette.core import MoulinetteError
from moulinette.authenticators import BaseAuthenticator
logger = logging.getLogger('moulinette.authenticator.dummy')
logger = logging.getLogger("moulinette.authenticator.dummy")
# Dummy authenticator implementation
@ -14,7 +14,7 @@ class Authenticator(BaseAuthenticator):
"""Dummy authenticator used for tests
"""
vendor = 'dummy'
vendor = "dummy"
def __init__(self, name, vendor, parameters, extra):
logger.debug("initialize authenticator dummy")

View file

@ -13,11 +13,12 @@ import ldap.modlist as modlist
from moulinette.core import MoulinetteError
from moulinette.authenticators import BaseAuthenticator
logger = logging.getLogger('moulinette.authenticator.ldap')
logger = logging.getLogger("moulinette.authenticator.ldap")
# LDAP Class Implementation --------------------------------------------
class Authenticator(BaseAuthenticator):
"""LDAP Authenticator
@ -38,12 +39,18 @@ class Authenticator(BaseAuthenticator):
self.basedn = parameters["base_dn"]
self.userdn = parameters["user_rdn"]
self.extra = extra
logger.debug("initialize authenticator '%s' with: uri='%s', "
"base_dn='%s', user_rdn='%s'", name, self.uri, self.basedn, self.userdn)
logger.debug(
"initialize authenticator '%s' with: uri='%s', "
"base_dn='%s', user_rdn='%s'",
name,
self.uri,
self.basedn,
self.userdn,
)
super(Authenticator, self).__init__(name)
if self.userdn:
if 'cn=external,cn=auth' in self.userdn:
if "cn=external,cn=auth" in self.userdn:
self.authenticate(None)
else:
self.con = None
@ -55,25 +62,27 @@ class Authenticator(BaseAuthenticator):
# Implement virtual properties
vendor = 'ldap'
vendor = "ldap"
# Implement virtual methods
def authenticate(self, password):
try:
con = ldap.ldapobject.ReconnectLDAPObject(self.uri, retry_max=10, retry_delay=0.5)
con = ldap.ldapobject.ReconnectLDAPObject(
self.uri, retry_max=10, retry_delay=0.5
)
if self.userdn:
if 'cn=external,cn=auth' in self.userdn:
con.sasl_non_interactive_bind_s('EXTERNAL')
if "cn=external,cn=auth" in self.userdn:
con.sasl_non_interactive_bind_s("EXTERNAL")
else:
con.simple_bind_s(self.userdn, password)
else:
con.simple_bind_s()
except ldap.INVALID_CREDENTIALS:
raise MoulinetteError('invalid_password')
raise MoulinetteError("invalid_password")
except ldap.SERVER_DOWN:
logger.exception('unable to reach the server to authenticate')
raise MoulinetteError('ldap_server_down')
logger.exception("unable to reach the server to authenticate")
raise MoulinetteError("ldap_server_down")
# Check that we are indeed logged in with the right identity
try:
@ -91,13 +100,16 @@ class Authenticator(BaseAuthenticator):
def _ensure_password_uses_strong_hash(self, password):
# XXX this has been copy pasted from YunoHost, should we put that into moulinette?
def _hash_user_password(password):
char_set = string.ascii_uppercase + string.ascii_lowercase + string.digits + "./"
salt = ''.join([random.SystemRandom().choice(char_set) for x in range(16)])
salt = '$6$' + salt + '$'
return '{CRYPT}' + crypt.crypt(str(password), salt)
char_set = (
string.ascii_uppercase + string.ascii_lowercase + string.digits + "./"
)
salt = "".join([random.SystemRandom().choice(char_set) for x in range(16)])
salt = "$6$" + salt + "$"
return "{CRYPT}" + crypt.crypt(str(password), salt)
hashed_password = self.search("cn=admin,dc=yunohost,dc=org",
attrs=["userPassword"])[0]
hashed_password = self.search(
"cn=admin,dc=yunohost,dc=org", attrs=["userPassword"]
)[0]
# post-install situation, password is not already set
if "userPassword" not in hashed_password or not hashed_password["userPassword"]:
@ -105,14 +117,12 @@ class Authenticator(BaseAuthenticator):
# we aren't using sha-512 but something else that is weaker, proceed to upgrade
if not hashed_password["userPassword"][0].startswith("{CRYPT}$6$"):
self.update("cn=admin", {
"userPassword": _hash_user_password(password),
})
self.update("cn=admin", {"userPassword": _hash_user_password(password),})
# Additional LDAP methods
# TODO: Review these methods
def search(self, base=None, filter='(objectClass=*)', attrs=['dn']):
def search(self, base=None, filter="(objectClass=*)", attrs=["dn"]):
"""Search in LDAP base
Perform an LDAP search operation with given arguments and return
@ -133,16 +143,22 @@ class Authenticator(BaseAuthenticator):
try:
result = self.con.search_s(base, ldap.SCOPE_SUBTREE, filter, attrs)
except Exception as e:
logger.exception("error during LDAP search operation with: base='%s', "
"filter='%s', attrs=%s and exception %s", base, filter, attrs, e)
raise MoulinetteError('ldap_operation_error')
logger.exception(
"error during LDAP search operation with: base='%s', "
"filter='%s', attrs=%s and exception %s",
base,
filter,
attrs,
e,
)
raise MoulinetteError("ldap_operation_error")
result_list = []
if not attrs or 'dn' not in attrs:
if not attrs or "dn" not in attrs:
result_list = [entry for dn, entry in result]
else:
for dn, entry in result:
entry['dn'] = [dn]
entry["dn"] = [dn]
result_list.append(entry)
return result_list
@ -158,15 +174,20 @@ class Authenticator(BaseAuthenticator):
Boolean | MoulinetteError
"""
dn = rdn + ',' + self.basedn
dn = rdn + "," + self.basedn
ldif = modlist.addModlist(attr_dict)
try:
self.con.add_s(dn, ldif)
except Exception as e:
logger.exception("error during LDAP add operation with: rdn='%s', "
"attr_dict=%s and exception %s", rdn, attr_dict, e)
raise MoulinetteError('ldap_operation_error')
logger.exception(
"error during LDAP add operation with: rdn='%s', "
"attr_dict=%s and exception %s",
rdn,
attr_dict,
e,
)
raise MoulinetteError("ldap_operation_error")
else:
return True
@ -181,12 +202,16 @@ class Authenticator(BaseAuthenticator):
Boolean | MoulinetteError
"""
dn = rdn + ',' + self.basedn
dn = rdn + "," + self.basedn
try:
self.con.delete_s(dn)
except Exception as e:
logger.exception("error during LDAP delete operation with: rdn='%s' and exception %s", rdn, e)
raise MoulinetteError('ldap_operation_error')
logger.exception(
"error during LDAP delete operation with: rdn='%s' and exception %s",
rdn,
e,
)
raise MoulinetteError("ldap_operation_error")
else:
return True
@ -203,21 +228,26 @@ class Authenticator(BaseAuthenticator):
Boolean | MoulinetteError
"""
dn = rdn + ',' + self.basedn
dn = rdn + "," + self.basedn
actual_entry = self.search(base=dn, attrs=None)
ldif = modlist.modifyModlist(actual_entry[0], attr_dict, ignore_oldexistent=1)
try:
if new_rdn:
self.con.rename_s(dn, new_rdn)
dn = new_rdn + ',' + self.basedn
dn = new_rdn + "," + self.basedn
self.con.modify_ext_s(dn, ldif)
except Exception as e:
logger.exception("error during LDAP update operation with: rdn='%s', "
"attr_dict=%s, new_rdn=%s and exception: %s", rdn, attr_dict,
new_rdn, e)
raise MoulinetteError('ldap_operation_error')
logger.exception(
"error during LDAP update operation with: rdn='%s', "
"attr_dict=%s, new_rdn=%s and exception: %s",
rdn,
attr_dict,
new_rdn,
e,
)
raise MoulinetteError("ldap_operation_error")
else:
return True
@ -234,11 +264,16 @@ class Authenticator(BaseAuthenticator):
"""
attr_found = self.get_conflict(value_dict)
if attr_found:
logger.info("attribute '%s' with value '%s' is not unique",
attr_found[0], attr_found[1])
raise MoulinetteError('ldap_attribute_already_exists',
attribute=attr_found[0],
value=attr_found[1])
logger.info(
"attribute '%s' with value '%s' is not unique",
attr_found[0],
attr_found[1],
)
raise MoulinetteError(
"ldap_attribute_already_exists",
attribute=attr_found[0],
value=attr_found[1],
)
return True
def get_conflict(self, value_dict, base_dn=None):
@ -253,7 +288,7 @@ class Authenticator(BaseAuthenticator):
"""
for attr, value in value_dict.items():
if not self.search(base=base_dn, filter=attr + '=' + value):
if not self.search(base=base_dn, filter=attr + "=" + value):
continue
else:
return (attr, value)

View file

@ -5,7 +5,7 @@ import os
from moulinette.globals import init_moulinette_env
def get_cachedir(subdir='', make_dir=True):
def get_cachedir(subdir="", make_dir=True):
"""Get the path to a cache directory
Return the path to the cache directory from an optional
@ -16,7 +16,7 @@ def get_cachedir(subdir='', make_dir=True):
- make_dir -- False to not make directory if it not exists
"""
CACHE_DIR = init_moulinette_env()['CACHE_DIR']
CACHE_DIR = init_moulinette_env()["CACHE_DIR"]
path = os.path.join(CACHE_DIR, subdir)
@ -25,7 +25,7 @@ def get_cachedir(subdir='', make_dir=True):
return path
def open_cachefile(filename, mode='r', subdir=''):
def open_cachefile(filename, mode="r", subdir=""):
"""Open a cache file and return a stream
Attempt to open in 'mode' the cache file 'filename' from the
@ -39,6 +39,6 @@ def open_cachefile(filename, mode='r', subdir=''):
- **kwargs -- Optional arguments for get_cachedir
"""
cache_dir = get_cachedir(subdir, make_dir=True if mode[0] == 'w' else False)
cache_dir = get_cachedir(subdir, make_dir=True if mode[0] == "w" else False)
file_path = os.path.join(cache_dir, filename)
return open(file_path, mode)

View file

@ -11,7 +11,7 @@ import moulinette
from moulinette.globals import init_moulinette_env
logger = logging.getLogger('moulinette.core')
logger = logging.getLogger("moulinette.core")
def during_unittests_run():
@ -20,6 +20,7 @@ def during_unittests_run():
# Internationalization -------------------------------------------------
class Translator(object):
"""Internationalization class
@ -33,15 +34,16 @@ class Translator(object):
"""
def __init__(self, locale_dir, default_locale='en'):
def __init__(self, locale_dir, default_locale="en"):
self.locale_dir = locale_dir
self.locale = default_locale
self._translations = {}
# Attempt to load default translations
if not self._load_translations(default_locale):
logger.error("unable to load locale '%s' from '%s'",
default_locale, locale_dir)
logger.error(
"unable to load locale '%s' from '%s'", default_locale, locale_dir
)
self.default_locale = default_locale
def get_locales(self):
@ -49,7 +51,7 @@ class Translator(object):
locales = []
for f in os.listdir(self.locale_dir):
if f.endswith('.json'):
if f.endswith(".json"):
# TODO: Validate locale
locales.append(f[:-5])
return locales
@ -69,8 +71,11 @@ class Translator(object):
"""
if locale not in self._translations:
if not self._load_translations(locale):
logger.debug("unable to load locale '%s' from '%s'",
self.default_locale, self.locale_dir)
logger.debug(
"unable to load locale '%s' from '%s'",
self.default_locale,
self.locale_dir,
)
# Revert to default locale
self.locale = self.default_locale
@ -93,11 +98,18 @@ class Translator(object):
failed_to_format = False
if key in self._translations.get(self.locale, {}):
try:
return self._translations[self.locale][key].encode('utf-8').format(*args, **kwargs)
return (
self._translations[self.locale][key]
.encode("utf-8")
.format(*args, **kwargs)
)
except KeyError as e:
unformatted_string = self._translations[self.locale][key].encode('utf-8')
error_message = "Failed to format translated string '%s': '%s' with arguments '%s' and '%s, raising error: %s(%s) (don't panic this is just a warning)" % (
key, unformatted_string, args, kwargs, e.__class__.__name__, e
unformatted_string = self._translations[self.locale][key].encode(
"utf-8"
)
error_message = (
"Failed to format translated string '%s': '%s' with arguments '%s' and '%s, raising error: %s(%s) (don't panic this is just a warning)"
% (key, unformatted_string, args, kwargs, e.__class__.__name__, e)
)
if not during_unittests_run():
@ -107,25 +119,37 @@ class Translator(object):
failed_to_format = True
if failed_to_format or (self.default_locale != self.locale and key in self._translations.get(self.default_locale, {})):
logger.info("untranslated key '%s' for locale '%s'",
key, self.locale)
if failed_to_format or (
self.default_locale != self.locale
and key in self._translations.get(self.default_locale, {})
):
logger.info("untranslated key '%s' for locale '%s'", key, self.locale)
try:
return self._translations[self.default_locale][key].encode('utf-8').format(*args, **kwargs)
return (
self._translations[self.default_locale][key]
.encode("utf-8")
.format(*args, **kwargs)
)
except KeyError as e:
unformatted_string = self._translations[self.default_locale][key].encode('utf-8')
error_message = "Failed to format translatable string '%s': '%s' with arguments '%s' and '%s', raising error: %s(%s) (don't panic this is just a warning)" % (
key, unformatted_string, args, kwargs, e.__class__.__name__, e
unformatted_string = self._translations[self.default_locale][
key
].encode("utf-8")
error_message = (
"Failed to format translatable string '%s': '%s' with arguments '%s' and '%s', raising error: %s(%s) (don't panic this is just a warning)"
% (key, unformatted_string, args, kwargs, e.__class__.__name__, e)
)
if not during_unittests_run():
logger.exception(error_message)
else:
raise Exception(error_message)
return self._translations[self.default_locale][key].encode('utf-8')
return self._translations[self.default_locale][key].encode("utf-8")
error_message = "unable to retrieve string to translate with key '%s' for default locale 'locales/%s.json' file (don't panic this is just a warning)" % (key, self.default_locale)
error_message = (
"unable to retrieve string to translate with key '%s' for default locale 'locales/%s.json' file (don't panic this is just a warning)"
% (key, self.default_locale)
)
if not during_unittests_run():
logger.exception(error_message)
@ -152,8 +176,8 @@ class Translator(object):
return True
try:
with open('%s/%s.json' % (self.locale_dir, locale), 'r') as f:
j = json.load(f, 'utf-8')
with open("%s/%s.json" % (self.locale_dir, locale), "r") as f:
j = json.load(f, "utf-8")
except IOError:
return False
else:
@ -174,12 +198,12 @@ class Moulinette18n(object):
"""
def __init__(self, default_locale='en'):
def __init__(self, default_locale="en"):
self.default_locale = default_locale
self.locale = default_locale
moulinette_env = init_moulinette_env()
self.locales_dir = moulinette_env['LOCALES_DIR']
self.locales_dir = moulinette_env["LOCALES_DIR"]
# Init global translator
self._global = Translator(self.locales_dir, default_locale)
@ -201,8 +225,9 @@ class Moulinette18n(object):
if namespace not in self._namespaces:
# Create new Translator object
lib_dir = init_moulinette_env()["LIB_DIR"]
translator = Translator('%s/%s/locales' % (lib_dir, namespace),
self.default_locale)
translator = Translator(
"%s/%s/locales" % (lib_dir, namespace), self.default_locale
)
translator.set_locale(self.locale)
self._namespaces[namespace] = translator
@ -272,19 +297,19 @@ class MoulinetteSignals(object):
if signal not in self.signals:
logger.error("unknown signal '%s'", signal)
return
setattr(self, '_%s' % signal, handler)
setattr(self, "_%s" % signal, handler)
def clear_handler(self, signal):
"""Clear the handler of a signal"""
if signal not in self.signals:
logger.error("unknown signal '%s'", signal)
return
setattr(self, '_%s' % signal, self._notimplemented)
setattr(self, "_%s" % signal, self._notimplemented)
# Signals definitions
"""The list of available signals"""
signals = {'authenticate', 'prompt', 'display'}
signals = {"authenticate", "prompt", "display"}
def authenticate(self, authenticator):
"""Process the authentication
@ -305,7 +330,7 @@ class MoulinetteSignals(object):
return authenticator
return self._authenticate(authenticator)
def prompt(self, message, is_password=False, confirm=False, color='blue'):
def prompt(self, message, is_password=False, confirm=False, color="blue"):
"""Prompt for a value
Prompt the interface for a parameter value which is a password
@ -326,7 +351,7 @@ class MoulinetteSignals(object):
"""
return self._prompt(message, is_password, confirm, color=color)
def display(self, message, style='info'):
def display(self, message, style="info"):
"""Display a message
Display a message with a given style to the user.
@ -352,6 +377,7 @@ class MoulinetteSignals(object):
# Interfaces & Authenticators management -------------------------------
def init_interface(name, kwargs={}, actionsmap={}):
"""Return a new interface instance
@ -371,10 +397,10 @@ def init_interface(name, kwargs={}, actionsmap={}):
from moulinette.actionsmap import ActionsMap
try:
mod = import_module('moulinette.interfaces.%s' % name)
mod = import_module("moulinette.interfaces.%s" % name)
except ImportError as e:
logger.exception("unable to load interface '%s' : %s", name, e)
raise MoulinetteError('error_see_log')
raise MoulinetteError("error_see_log")
else:
try:
# Retrieve interface classes
@ -382,22 +408,23 @@ def init_interface(name, kwargs={}, actionsmap={}):
interface = mod.Interface
except AttributeError:
logger.exception("unable to retrieve classes of interface '%s'", name)
raise MoulinetteError('error_see_log')
raise MoulinetteError("error_see_log")
# Instantiate or retrieve ActionsMap
if isinstance(actionsmap, dict):
amap = ActionsMap(actionsmap.pop('parser', parser), **actionsmap)
amap = ActionsMap(actionsmap.pop("parser", parser), **actionsmap)
elif isinstance(actionsmap, ActionsMap):
amap = actionsmap
else:
logger.error("invalid actionsmap value %r", actionsmap)
raise MoulinetteError('error_see_log')
raise MoulinetteError("error_see_log")
return interface(amap, **kwargs)
# Moulinette core classes ----------------------------------------------
class MoulinetteError(Exception):
"""Moulinette base exception"""
@ -427,12 +454,12 @@ class MoulinetteLock(object):
"""
def __init__(self, namespace, timeout=None, interval=.5):
def __init__(self, namespace, timeout=None, interval=0.5):
self.namespace = namespace
self.timeout = timeout
self.interval = interval
self._lockfile = '/var/run/moulinette_%s.lock' % namespace
self._lockfile = "/var/run/moulinette_%s.lock" % namespace
self._stale_checked = False
self._locked = False
@ -453,7 +480,7 @@ class MoulinetteLock(object):
# after 15*4 seconds, then 15*4*4 seconds...
warning_treshold = 15
logger.debug('acquiring lock...')
logger.debug("acquiring lock...")
while True:
@ -470,20 +497,24 @@ class MoulinetteLock(object):
# Check locked process still exist and take lock if it doesnt
# FIXME : what do in the context of multiple locks :|
first_lock = lock_pids[0]
if not os.path.exists(os.path.join('/proc', str(first_lock), 'exe')):
logger.debug('stale lock file found')
if not os.path.exists(os.path.join("/proc", str(first_lock), "exe")):
logger.debug("stale lock file found")
self._lock()
break
if self.timeout is not None and (time.time() - start_time) > self.timeout:
raise MoulinetteError('instance_already_running')
raise MoulinetteError("instance_already_running")
# warn the user if it's been too much time since they are waiting
if (time.time() - start_time) > warning_treshold:
if warning_treshold == 15:
logger.warning(moulinette.m18n.g('warn_the_user_about_waiting_lock'))
logger.warning(
moulinette.m18n.g("warn_the_user_about_waiting_lock")
)
else:
logger.warning(moulinette.m18n.g('warn_the_user_about_waiting_lock_again'))
logger.warning(
moulinette.m18n.g("warn_the_user_about_waiting_lock_again")
)
warning_treshold *= 4
# Wait before checking again
@ -492,8 +523,8 @@ class MoulinetteLock(object):
# we have warned the user that we were waiting, for better UX also them
# that we have stop waiting and that the command is processing now
if warning_treshold != 15:
logger.warning(moulinette.m18n.g('warn_the_user_that_lock_is_acquired'))
logger.debug('lock has been acquired')
logger.warning(moulinette.m18n.g("warn_the_user_that_lock_is_acquired"))
logger.debug("lock has been acquired")
self._locked = True
def release(self):
@ -506,16 +537,18 @@ class MoulinetteLock(object):
if os.path.exists(self._lockfile):
os.unlink(self._lockfile)
else:
logger.warning("Uhoh, somehow the lock %s did not exist ..." % self._lockfile)
logger.debug('lock has been released')
logger.warning(
"Uhoh, somehow the lock %s did not exist ..." % self._lockfile
)
logger.debug("lock has been released")
self._locked = False
def _lock(self):
try:
with open(self._lockfile, 'w') as f:
with open(self._lockfile, "w") as f:
f.write(str(os.getpid()))
except IOError:
raise MoulinetteError('root_required')
raise MoulinetteError("root_required")
def _lock_PIDs(self):
@ -523,10 +556,10 @@ class MoulinetteLock(object):
return []
with open(self._lockfile) as f:
lock_pids = f.read().strip().split('\n')
lock_pids = f.read().strip().split("\n")
# Make sure to convert those pids to integers
lock_pids = [int(pid) for pid in lock_pids if pid.strip() != '']
lock_pids = [int(pid) for pid in lock_pids if pid.strip() != ""]
return lock_pids

View file

@ -5,8 +5,10 @@ from os import environ
def init_moulinette_env():
return {
'DATA_DIR': environ.get('MOULINETTE_DATA_DIR', '/usr/share/moulinette'),
'LIB_DIR': environ.get('MOULINETTE_LIB_DIR', '/usr/lib/moulinette'),
'LOCALES_DIR': environ.get('MOULINETTE_LOCALES_DIR', '/usr/share/moulinette/locale'),
'CACHE_DIR': environ.get('MOULINETTE_CACHE_DIR', '/var/cache/moulinette'),
"DATA_DIR": environ.get("MOULINETTE_DATA_DIR", "/usr/share/moulinette"),
"LIB_DIR": environ.get("MOULINETTE_LIB_DIR", "/usr/lib/moulinette"),
"LOCALES_DIR": environ.get(
"MOULINETTE_LOCALES_DIR", "/usr/share/moulinette/locale"
),
"CACHE_DIR": environ.get("MOULINETTE_CACHE_DIR", "/var/cache/moulinette"),
}

View file

@ -9,15 +9,16 @@ from collections import deque, OrderedDict
from moulinette import msettings, m18n
from moulinette.core import MoulinetteError
logger = logging.getLogger('moulinette.interface')
logger = logging.getLogger("moulinette.interface")
GLOBAL_SECTION = '_global'
TO_RETURN_PROP = '_to_return'
CALLBACKS_PROP = '_callbacks'
GLOBAL_SECTION = "_global"
TO_RETURN_PROP = "_to_return"
CALLBACKS_PROP = "_callbacks"
# Base Class -----------------------------------------------------------
class BaseActionsMapParser(object):
"""Actions map's base Parser
@ -37,9 +38,8 @@ class BaseActionsMapParser(object):
if parent:
self._o = parent
else:
logger.debug('initializing base actions map parser for %s',
self.interface)
msettings['interface'] = self.interface
logger.debug("initializing base actions map parser for %s", self.interface)
msettings["interface"] = self.interface
self._o = self
self._global_conf = {}
@ -70,8 +70,9 @@ class BaseActionsMapParser(object):
A list of option strings
"""
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
def has_global_parser(self):
return False
@ -85,8 +86,9 @@ class BaseActionsMapParser(object):
An ArgumentParser based object
"""
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
def add_category_parser(self, name, **kwargs):
"""Add a parser for a category
@ -100,8 +102,9 @@ class BaseActionsMapParser(object):
A BaseParser based object
"""
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
def add_action_parser(self, name, tid, **kwargs):
"""Add a parser for an action
@ -116,8 +119,9 @@ class BaseActionsMapParser(object):
An ArgumentParser based object
"""
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
def auth_required(self, args, **kwargs):
"""Check if authentication is required to run the requested action
@ -129,8 +133,9 @@ class BaseActionsMapParser(object):
False, or the authentication profile required
"""
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
def parse_args(self, args, **kwargs):
"""Parse arguments
@ -145,17 +150,19 @@ class BaseActionsMapParser(object):
The populated namespace
"""
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
# Arguments helpers
def prepare_action_namespace(self, tid, namespace=None):
"""Prepare the namespace for a given action"""
# Validate tid and namespace
if not isinstance(tid, tuple) and \
(namespace is None or not hasattr(namespace, TO_RETURN_PROP)):
raise MoulinetteError('invalid_usage')
if not isinstance(tid, tuple) and (
namespace is None or not hasattr(namespace, TO_RETURN_PROP)
):
raise MoulinetteError("invalid_usage")
elif not tid:
tid = GLOBAL_SECTION
@ -229,52 +236,65 @@ class BaseActionsMapParser(object):
# -- 'authenficate'
try:
ifaces = configuration['authenticate']
ifaces = configuration["authenticate"]
except KeyError:
pass
else:
if ifaces == 'all':
conf['authenticate'] = ifaces
if ifaces == "all":
conf["authenticate"] = ifaces
elif ifaces is False:
conf['authenticate'] = False
conf["authenticate"] = False
elif isinstance(ifaces, list):
# Store only if authentication is needed
conf['authenticate'] = True if self.interface in ifaces else False
conf["authenticate"] = True if self.interface in ifaces else False
else:
logger.error("expecting 'all', 'False' or a list for "
"configuration 'authenticate', got %r", ifaces)
raise MoulinetteError('error_see_log')
logger.error(
"expecting 'all', 'False' or a list for "
"configuration 'authenticate', got %r",
ifaces,
)
raise MoulinetteError("error_see_log")
# -- 'authenticator'
try:
auth = configuration['authenticator']
auth = configuration["authenticator"]
except KeyError:
pass
else:
if not is_global and isinstance(auth, str):
try:
# Store needed authenticator profile
conf['authenticator'] = self.global_conf['authenticator'][auth]
conf["authenticator"] = self.global_conf["authenticator"][auth]
except KeyError:
logger.error("requesting profile '%s' which is undefined in "
"global configuration of 'authenticator'", auth)
raise MoulinetteError('error_see_log')
logger.error(
"requesting profile '%s' which is undefined in "
"global configuration of 'authenticator'",
auth,
)
raise MoulinetteError("error_see_log")
elif is_global and isinstance(auth, dict):
if len(auth) == 0:
logger.warning('no profile defined in global configuration '
"for 'authenticator'")
logger.warning(
"no profile defined in global configuration "
"for 'authenticator'"
)
else:
auths = {}
for auth_name, auth_conf in auth.items():
auths[auth_name] = {'name': auth_name,
'vendor': auth_conf.get('vendor'),
'parameters': auth_conf.get('parameters', {}),
'extra': {'help': auth_conf.get('help', None)}}
conf['authenticator'] = auths
auths[auth_name] = {
"name": auth_name,
"vendor": auth_conf.get("vendor"),
"parameters": auth_conf.get("parameters", {}),
"extra": {"help": auth_conf.get("help", None)},
}
conf["authenticator"] = auths
else:
logger.error("expecting a dict of profile(s) or a profile name "
"for configuration 'authenticator', got %r", auth)
raise MoulinetteError('error_see_log')
logger.error(
"expecting a dict of profile(s) or a profile name "
"for configuration 'authenticator', got %r",
auth,
)
raise MoulinetteError("error_see_log")
return conf
@ -291,55 +311,60 @@ class BaseInterface(object):
- actionsmap -- The ActionsMap instance to connect to
"""
# TODO: Add common interface methods and try to standardize default ones
def __init__(self, actionsmap):
raise NotImplementedError("derived class '%s' must override this method" %
self.__class__.__name__)
raise NotImplementedError(
"derived class '%s' must override this method" % self.__class__.__name__
)
# Argument parser ------------------------------------------------------
class _CallbackAction(argparse.Action):
def __init__(self,
option_strings,
dest,
nargs=0,
callback={},
default=argparse.SUPPRESS,
help=None):
if not callback or 'method' not in callback:
raise ValueError('callback must be provided with at least '
'a method key')
class _CallbackAction(argparse.Action):
def __init__(
self,
option_strings,
dest,
nargs=0,
callback={},
default=argparse.SUPPRESS,
help=None,
):
if not callback or "method" not in callback:
raise ValueError("callback must be provided with at least " "a method key")
super(_CallbackAction, self).__init__(
option_strings=option_strings,
dest=dest,
nargs=nargs,
default=default,
help=help)
self.callback_method = callback.get('method')
self.callback_kwargs = callback.get('kwargs', {})
self.callback_return = callback.get('return', False)
logger.debug("registering new callback action '{0}' to {1}".format(
self.callback_method, option_strings))
help=help,
)
self.callback_method = callback.get("method")
self.callback_kwargs = callback.get("kwargs", {})
self.callback_return = callback.get("return", False)
logger.debug(
"registering new callback action '{0}' to {1}".format(
self.callback_method, option_strings
)
)
@property
def callback(self):
if not hasattr(self, '_callback'):
if not hasattr(self, "_callback"):
self._retrieve_callback()
return self._callback
def _retrieve_callback(self):
# Attempt to retrieve callback method
mod_name, func_name = (self.callback_method).rsplit('.', 1)
mod_name, func_name = (self.callback_method).rsplit(".", 1)
try:
mod = __import__(mod_name, globals=globals(), level=0,
fromlist=[func_name])
mod = __import__(mod_name, globals=globals(), level=0, fromlist=[func_name])
func = getattr(mod, func_name)
except (AttributeError, ImportError):
raise ValueError('unable to import method {0}'.format(
self.callback_method))
raise ValueError("unable to import method {0}".format(self.callback_method))
self._callback = func
def __call__(self, parser, namespace, values, option_string=None):
@ -352,9 +377,11 @@ class _CallbackAction(argparse.Action):
# Execute callback and get returned value
value = self.callback(namespace, values, **self.callback_kwargs)
except:
logger.exception("cannot get value from callback method "
"'{0}'".format(self.callback_method))
raise MoulinetteError('error_see_log')
logger.exception(
"cannot get value from callback method "
"'{0}'".format(self.callback_method)
)
raise MoulinetteError("error_see_log")
else:
if value:
if self.callback_return:
@ -379,23 +406,22 @@ class _ExtendedSubParsersAction(argparse._SubParsersAction):
"""
def __init__(self, *args, **kwargs):
required = kwargs.pop('required', False)
required = kwargs.pop("required", False)
super(_ExtendedSubParsersAction, self).__init__(*args, **kwargs)
self.required = required
self._deprecated_command_map = {}
def add_parser(self, name, type_=None, **kwargs):
deprecated = kwargs.pop('deprecated', False)
deprecated_alias = kwargs.pop('deprecated_alias', [])
deprecated = kwargs.pop("deprecated", False)
deprecated_alias = kwargs.pop("deprecated_alias", [])
if deprecated:
self._deprecated_command_map[name] = None
if 'help' in kwargs:
del kwargs['help']
if "help" in kwargs:
del kwargs["help"]
parser = super(_ExtendedSubParsersAction, self).add_parser(
name, **kwargs)
parser = super(_ExtendedSubParsersAction, self).add_parser(name, **kwargs)
# Append each deprecated command alias name
for command in deprecated_alias:
@ -417,27 +443,34 @@ class _ExtendedSubParsersAction(argparse._SubParsersAction):
else:
# Warn the user about deprecated command
if correct_name is None:
logger.warning(m18n.g('deprecated_command', prog=parser.prog,
command=parser_name))
logger.warning(
m18n.g("deprecated_command", prog=parser.prog, command=parser_name)
)
else:
logger.warning(m18n.g('deprecated_command_alias',
old=parser_name, new=correct_name,
prog=parser.prog))
logger.warning(
m18n.g(
"deprecated_command_alias",
old=parser_name,
new=correct_name,
prog=parser.prog,
)
)
values[0] = correct_name
return super(_ExtendedSubParsersAction, self).__call__(
parser, namespace, values, option_string)
parser, namespace, values, option_string
)
class ExtendedArgumentParser(argparse.ArgumentParser):
def __init__(self, *args, **kwargs):
super(ExtendedArgumentParser, self).__init__(formatter_class=PositionalsFirstHelpFormatter,
*args, **kwargs)
super(ExtendedArgumentParser, self).__init__(
formatter_class=PositionalsFirstHelpFormatter, *args, **kwargs
)
# Register additional actions
self.register('action', 'callback', _CallbackAction)
self.register('action', 'parsers', _ExtendedSubParsersAction)
self.register("action", "callback", _CallbackAction)
self.register("action", "parsers", _ExtendedSubParsersAction)
def enqueue_callback(self, namespace, callback, values):
queue = self._get_callbacks_queue(namespace)
@ -465,30 +498,33 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
queue = list()
return queue
def add_arguments(self, arguments, extraparser, format_arg_names=None, validate_extra=True):
def add_arguments(
self, arguments, extraparser, format_arg_names=None, validate_extra=True
):
for argument_name, argument_options in arguments.items():
# will adapt arguments name for cli or api context
names = format_arg_names(str(argument_name),
argument_options.pop('full', None))
names = format_arg_names(
str(argument_name), argument_options.pop("full", None)
)
if "type" in argument_options:
argument_options['type'] = eval(argument_options['type'])
argument_options["type"] = eval(argument_options["type"])
if "extra" in argument_options:
extra = argument_options.pop('extra')
extra = argument_options.pop("extra")
argument_dest = self.add_argument(*names, **argument_options).dest
extraparser.add_argument(self.get_default("_tid"),
argument_dest, extra, validate_extra)
extraparser.add_argument(
self.get_default("_tid"), argument_dest, extra, validate_extra
)
continue
self.add_argument(*names, **argument_options)
def _get_nargs_pattern(self, action):
if action.nargs == argparse.PARSER and not action.required:
return '([-AO]*)'
return "([-AO]*)"
else:
return super(ExtendedArgumentParser, self)._get_nargs_pattern(
action)
return super(ExtendedArgumentParser, self)._get_nargs_pattern(action)
def _get_values(self, action, arg_strings):
if action.nargs == argparse.PARSER and not action.required:
@ -498,8 +534,7 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
else:
value = argparse.SUPPRESS
else:
value = super(ExtendedArgumentParser, self)._get_values(
action, arg_strings)
value = super(ExtendedArgumentParser, self)._get_values(action, arg_strings)
return value
# Adapted from :
@ -508,8 +543,7 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
formatter = self._get_formatter()
# usage
formatter.add_usage(self.usage, self._actions,
self._mutually_exclusive_groups)
formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)
# description
formatter.add_text(self.description)
@ -527,14 +561,30 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
subcategories_subparser = copy.copy(action_group._group_actions[0])
# Filter "action"-type and "subcategory"-type commands
actions_subparser.choices = OrderedDict([(k, v) for k, v in actions_subparser.choices.items() if v.type == "action"])
subcategories_subparser.choices = OrderedDict([(k, v) for k, v in subcategories_subparser.choices.items() if v.type == "subcategory"])
actions_subparser.choices = OrderedDict(
[
(k, v)
for k, v in actions_subparser.choices.items()
if v.type == "action"
]
)
subcategories_subparser.choices = OrderedDict(
[
(k, v)
for k, v in subcategories_subparser.choices.items()
if v.type == "subcategory"
]
)
actions_choices = actions_subparser.choices.keys()
subcategories_choices = subcategories_subparser.choices.keys()
actions_subparser._choices_actions = [c for c in choice_actions if c.dest in actions_choices]
subcategories_subparser._choices_actions = [c for c in choice_actions if c.dest in subcategories_choices]
actions_subparser._choices_actions = [
c for c in choice_actions if c.dest in actions_choices
]
subcategories_subparser._choices_actions = [
c for c in choice_actions if c.dest in subcategories_choices
]
# Display each section (actions and subcategories)
if actions_choices != []:
@ -569,11 +619,10 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
# and fix is inspired from here :
# https://stackoverflow.com/questions/26985650/argparse-do-not-catch-positional-arguments-with-nargs/26986546#26986546
class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
def _format_usage(self, usage, actions, groups, prefix):
if prefix is None:
# TWEAK : not using gettext here...
prefix = 'usage: '
prefix = "usage: "
# if usage is specified, use that
if usage is not None:
@ -581,11 +630,11 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
# if no optionals or positionals are available, usage is just prog
elif usage is None and not actions:
usage = '%(prog)s' % dict(prog=self._prog)
usage = "%(prog)s" % dict(prog=self._prog)
# if optionals and positionals are available, calculate usage
elif usage is None:
prog = '%(prog)s' % dict(prog=self._prog)
prog = "%(prog)s" % dict(prog=self._prog)
# split optionals from positionals
optionals = []
@ -600,20 +649,20 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
format = self._format_actions_usage
# TWEAK here : positionals first
action_usage = format(positionals + optionals, groups)
usage = ' '.join([s for s in [prog, action_usage] if s])
usage = " ".join([s for s in [prog, action_usage] if s])
# wrap the usage parts if it's too long
text_width = self._width - self._current_indent
if len(prefix) + len(usage) > text_width:
# break usage into wrappable parts
part_regexp = r'\(.*?\)+|\[.*?\]+|\S+'
part_regexp = r"\(.*?\)+|\[.*?\]+|\S+"
opt_usage = format(optionals, groups)
pos_usage = format(positionals, groups)
opt_parts = re.findall(part_regexp, opt_usage)
pos_parts = re.findall(part_regexp, pos_usage)
assert ' '.join(opt_parts) == opt_usage
assert ' '.join(pos_parts) == pos_usage
assert " ".join(opt_parts) == opt_usage
assert " ".join(pos_parts) == pos_usage
# helper for wrapping lines
def get_lines(parts, indent, prefix=None):
@ -625,20 +674,20 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
line_len = len(indent) - 1
for part in parts:
if line_len + 1 + len(part) > text_width:
lines.append(indent + ' '.join(line))
lines.append(indent + " ".join(line))
line = []
line_len = len(indent) - 1
line.append(part)
line_len += len(part) + 1
if line:
lines.append(indent + ' '.join(line))
lines.append(indent + " ".join(line))
if prefix is not None:
lines[0] = lines[0][len(indent):]
lines[0] = lines[0][len(indent) :]
return lines
# if prog is short, follow it with optionals or positionals
if len(prefix) + len(prog) <= 0.75 * text_width:
indent = ' ' * (len(prefix) + len(prog) + 1)
indent = " " * (len(prefix) + len(prog) + 1)
# START TWEAK : pos_parts first, then opt_parts
if pos_parts:
lines = get_lines([prog] + pos_parts, indent, prefix)
@ -651,7 +700,7 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
# if prog is long, put it on its own line
else:
indent = ' ' * len(prefix)
indent = " " * len(prefix)
parts = pos_parts + opt_parts
lines = get_lines(parts, indent)
if len(lines) > 1:
@ -662,7 +711,7 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
lines = [prog] + lines
# join lines into usage
usage = '\n'.join(lines)
usage = "\n".join(lines)
# prefix with 'usage:'
return '%s%s\n\n' % (prefix, usage)
return "%s%s\n\n" % (prefix, usage)

View file

@ -16,20 +16,22 @@ from bottle import abort
from moulinette import msignals, m18n, env
from moulinette.core import MoulinetteError
from moulinette.interfaces import (
BaseActionsMapParser, BaseInterface, ExtendedArgumentParser,
BaseActionsMapParser,
BaseInterface,
ExtendedArgumentParser,
)
from moulinette.utils import log
from moulinette.utils.serialize import JSONExtendedEncoder
from moulinette.utils.text import random_ascii
logger = log.getLogger('moulinette.interface.api')
logger = log.getLogger("moulinette.interface.api")
# API helpers ----------------------------------------------------------
CSRF_TYPES = set(["text/plain",
"application/x-www-form-urlencoded",
"multipart/form-data"])
CSRF_TYPES = set(
["text/plain", "application/x-www-form-urlencoded", "multipart/form-data"]
)
def is_csrf():
@ -39,7 +41,7 @@ def is_csrf():
return False
if request.content_type is None:
return True
content_type = request.content_type.lower().split(';')[0]
content_type = request.content_type.lower().split(";")[0]
if content_type not in CSRF_TYPES:
return False
@ -53,12 +55,14 @@ def filter_csrf(callback):
abort(403, "CSRF protection")
else:
return callback(*args, **kwargs)
return wrapper
class LogQueues(dict):
"""Map of session id to queue."""
pass
@ -74,7 +78,7 @@ class APIQueueHandler(logging.Handler):
self.queues = LogQueues()
def emit(self, record):
sid = request.get_cookie('session.id')
sid = request.get_cookie("session.id")
try:
queue = self.queues[sid]
except KeyError:
@ -99,13 +103,13 @@ class _HTTPArgumentParser(object):
def __init__(self):
# Initialize the ArgumentParser object
self._parser = ExtendedArgumentParser(usage='',
prefix_chars='@',
add_help=False)
self._parser = ExtendedArgumentParser(
usage="", prefix_chars="@", add_help=False
)
self._parser.error = self._error
self._positional = [] # list(arg_name)
self._optional = {} # dict({arg_name: option_strings})
self._positional = [] # list(arg_name)
self._optional = {} # dict({arg_name: option_strings})
def set_defaults(self, **kwargs):
return self._parser.set_defaults(**kwargs)
@ -113,20 +117,24 @@ class _HTTPArgumentParser(object):
def get_default(self, dest):
return self._parser.get_default(dest)
def add_arguments(self, arguments, extraparser, format_arg_names=None, validate_extra=True):
def add_arguments(
self, arguments, extraparser, format_arg_names=None, validate_extra=True
):
for argument_name, argument_options in arguments.items():
# will adapt arguments name for cli or api context
names = format_arg_names(str(argument_name),
argument_options.pop('full', None))
names = format_arg_names(
str(argument_name), argument_options.pop("full", None)
)
if "type" in argument_options:
argument_options['type'] = eval(argument_options['type'])
argument_options["type"] = eval(argument_options["type"])
if "extra" in argument_options:
extra = argument_options.pop('extra')
extra = argument_options.pop("extra")
argument_dest = self.add_argument(*names, **argument_options).dest
extraparser.add_argument(self.get_default("_tid"),
argument_dest, extra, validate_extra)
extraparser.add_argument(
self.get_default("_tid"), argument_dest, extra, validate_extra
)
continue
self.add_argument(*names, **argument_options)
@ -166,12 +174,19 @@ class _HTTPArgumentParser(object):
if isinstance(v, str):
arg_strings.append(v)
else:
logger.warning("unsupported argument value type %r "
"in %s for option string %s", v, value,
option_string)
logger.warning(
"unsupported argument value type %r "
"in %s for option string %s",
v,
value,
option_string,
)
else:
logger.warning("unsupported argument type %r for option "
"string %s", value, option_string)
logger.warning(
"unsupported argument type %r for option " "string %s",
value,
option_string,
)
return arg_strings
@ -208,14 +223,15 @@ class _ActionsMapPlugin(object):
to serve messages coming from the 'display' signal
"""
name = 'actionsmap'
name = "actionsmap"
api = 2
def __init__(self, actionsmap, use_websocket, log_queues={}):
# Connect signals to handlers
msignals.set_handler('authenticate', self._do_authenticate)
msignals.set_handler("authenticate", self._do_authenticate)
if use_websocket:
msignals.set_handler('display', self._do_display)
msignals.set_handler("display", self._do_display)
self.actionsmap = actionsmap
self.use_websocket = use_websocket
@ -237,34 +253,52 @@ class _ActionsMapPlugin(object):
def wrapper():
kwargs = {}
try:
kwargs['password'] = request.POST['password']
kwargs["password"] = request.POST["password"]
except KeyError:
raise HTTPBadRequestResponse("Missing password parameter")
try:
kwargs['profile'] = request.POST['profile']
kwargs["profile"] = request.POST["profile"]
except KeyError:
pass
return callback(**kwargs)
return wrapper
# Logout wrapper
def _logout(callback):
def wrapper():
kwargs = {}
kwargs['profile'] = request.POST.get('profile', "default")
kwargs["profile"] = request.POST.get("profile", "default")
return callback(**kwargs)
return wrapper
# Append authentication routes
app.route('/login', name='login', method='POST',
callback=self.login, skip=['actionsmap'], apply=_login)
app.route('/logout', name='logout', method='GET',
callback=self.logout, skip=['actionsmap'], apply=_logout)
app.route(
"/login",
name="login",
method="POST",
callback=self.login,
skip=["actionsmap"],
apply=_login,
)
app.route(
"/logout",
name="logout",
method="GET",
callback=self.logout,
skip=["actionsmap"],
apply=_logout,
)
# Append messages route
if self.use_websocket:
app.route('/messages', name='messages',
callback=self.messages, skip=['actionsmap'])
app.route(
"/messages",
name="messages",
callback=self.messages,
skip=["actionsmap"],
)
# Append routes from the actions map
for (m, p) in self.actionsmap.parser.routes:
@ -281,6 +315,7 @@ class _ActionsMapPlugin(object):
context -- An instance of Route
"""
def _format(value):
if isinstance(value, list) and len(value) == 1:
return value[0]
@ -311,11 +346,12 @@ class _ActionsMapPlugin(object):
# Process the action
return callback((request.method, context.rule), params)
return wrapper
# Routes callbacks
def login(self, password, profile='default'):
def login(self, password, profile="default"):
"""Log in to an authenticator profile
Attempt to authenticate to a given authenticator profile and
@ -328,14 +364,13 @@ class _ActionsMapPlugin(object):
"""
# Retrieve session values
s_id = request.get_cookie('session.id') or random_ascii()
s_id = request.get_cookie("session.id") or random_ascii()
try:
s_secret = self.secrets[s_id]
except KeyError:
s_tokens = {}
else:
s_tokens = request.get_cookie('session.tokens',
secret=s_secret) or {}
s_tokens = request.get_cookie("session.tokens", secret=s_secret) or {}
s_new_token = random_ascii()
try:
@ -354,10 +389,11 @@ class _ActionsMapPlugin(object):
s_tokens[profile] = s_new_token
self.secrets[s_id] = s_secret = random_ascii()
response.set_cookie('session.id', s_id, secure=True)
response.set_cookie('session.tokens', s_tokens, secure=True,
secret=s_secret)
return m18n.g('logged_in')
response.set_cookie("session.id", s_id, secure=True)
response.set_cookie(
"session.tokens", s_tokens, secure=True, secret=s_secret
)
return m18n.g("logged_in")
def logout(self, profile):
"""Log out from an authenticator profile
@ -369,24 +405,23 @@ class _ActionsMapPlugin(object):
- profile -- The authenticator profile name to log out
"""
s_id = request.get_cookie('session.id')
s_id = request.get_cookie("session.id")
try:
# We check that there's a (signed) session.hash available
# for additional security ?
# (An attacker could not craft such signed hashed ? (FIXME : need to make sure of this))
s_secret = self.secrets[s_id]
request.get_cookie('session.tokens',
secret=s_secret, default={})[profile]
request.get_cookie("session.tokens", secret=s_secret, default={})[profile]
except KeyError:
raise HTTPUnauthorizedResponse(m18n.g('not_logged_in'))
raise HTTPUnauthorizedResponse(m18n.g("not_logged_in"))
else:
del self.secrets[s_id]
authenticator = self.actionsmap.get_authenticator_for_profile(profile)
authenticator._clean_session(s_id)
# TODO: Clean the session for profile only
# Delete cookie and clean the session
response.set_cookie('session.tokens', '', max_age=-1)
return m18n.g('logged_out')
response.set_cookie("session.tokens", "", max_age=-1)
return m18n.g("logged_out")
def messages(self):
"""Listen to the messages WebSocket stream
@ -396,7 +431,7 @@ class _ActionsMapPlugin(object):
dict { style: message }.
"""
s_id = request.get_cookie('session.id')
s_id = request.get_cookie("session.id")
try:
queue = self.log_queues[s_id]
except KeyError:
@ -404,9 +439,9 @@ class _ActionsMapPlugin(object):
queue = Queue()
self.log_queues[s_id] = queue
wsock = request.environ.get('wsgi.websocket')
wsock = request.environ.get("wsgi.websocket")
if not wsock:
raise HTTPErrorResponse(m18n.g('websocket_request_expected'))
raise HTTPErrorResponse(m18n.g("websocket_request_expected"))
while True:
item = queue.get()
@ -447,17 +482,16 @@ class _ActionsMapPlugin(object):
if isinstance(e, HTTPResponse):
raise e
import traceback
tb = traceback.format_exc()
logs = {"route": _route,
"arguments": arguments,
"traceback": tb}
logs = {"route": _route, "arguments": arguments, "traceback": tb}
return HTTPErrorResponse(json_encode(logs))
else:
return format_for_response(ret)
finally:
# Close opened WebSocket by putting StopIteration in the queue
try:
queue = self.log_queues[request.get_cookie('session.id')]
queue = self.log_queues[request.get_cookie("session.id")]
except KeyError:
pass
else:
@ -471,13 +505,14 @@ class _ActionsMapPlugin(object):
Handle the core.MoulinetteSignals.authenticate signal.
"""
s_id = request.get_cookie('session.id')
s_id = request.get_cookie("session.id")
try:
s_secret = self.secrets[s_id]
s_token = request.get_cookie('session.tokens',
secret=s_secret, default={})[authenticator.name]
s_token = request.get_cookie("session.tokens", secret=s_secret, default={})[
authenticator.name
]
except KeyError:
msg = m18n.g('authentication_required')
msg = m18n.g("authentication_required")
raise HTTPUnauthorizedResponse(msg)
else:
return authenticator(token=(s_id, s_token))
@ -488,7 +523,7 @@ class _ActionsMapPlugin(object):
Handle the core.MoulinetteSignals.display signal.
"""
s_id = request.get_cookie('session.id')
s_id = request.get_cookie("session.id")
try:
queue = self.log_queues[s_id]
except KeyError:
@ -504,50 +539,48 @@ class _ActionsMapPlugin(object):
# HTTP Responses -------------------------------------------------------
class HTTPOKResponse(HTTPResponse):
def __init__(self, output=''):
class HTTPOKResponse(HTTPResponse):
def __init__(self, output=""):
super(HTTPOKResponse, self).__init__(output, 200)
class HTTPBadRequestResponse(HTTPResponse):
def __init__(self, output=''):
def __init__(self, output=""):
super(HTTPBadRequestResponse, self).__init__(output, 400)
class HTTPUnauthorizedResponse(HTTPResponse):
def __init__(self, output=''):
def __init__(self, output=""):
super(HTTPUnauthorizedResponse, self).__init__(output, 401)
class HTTPErrorResponse(HTTPResponse):
def __init__(self, output=''):
def __init__(self, output=""):
super(HTTPErrorResponse, self).__init__(output, 500)
def format_for_response(content):
"""Format the resulted content of a request for the HTTP response."""
if request.method == 'POST':
if request.method == "POST":
response.status = 201 # Created
elif request.method == 'GET':
elif request.method == "GET":
response.status = 200 # Ok
else:
# Return empty string if no content
if content is None or len(content) == 0:
response.status = 204 # No Content
return ''
return ""
response.status = 200
# Return JSON-style response
response.content_type = 'application/json'
response.content_type = "application/json"
return json_encode(content, cls=JSONExtendedEncoder)
# API Classes Implementation -------------------------------------------
class ActionsMapParser(BaseActionsMapParser):
"""Actions map's Parser for the API
@ -561,7 +594,7 @@ class ActionsMapParser(BaseActionsMapParser):
super(ActionsMapParser, self).__init__(parent)
self._parsers = {} # dict({(method, path): _HTTPArgumentParser})
self._route_re = re.compile(r'(GET|POST|PUT|DELETE) (/\S+)')
self._route_re = re.compile(r"(GET|POST|PUT|DELETE) (/\S+)")
@property
def routes(self):
@ -570,19 +603,19 @@ class ActionsMapParser(BaseActionsMapParser):
# Implement virtual properties
interface = 'api'
interface = "api"
# Implement virtual methods
@staticmethod
def format_arg_names(name, full):
if name[0] != '-':
if name[0] != "-":
return [name]
if full:
return [full.replace('--', '@', 1)]
if name.startswith('--'):
return [name.replace('--', '@', 1)]
return [name.replace('-', '@', 1)]
return [full.replace("--", "@", 1)]
if name.startswith("--"):
return [name.replace("--", "@", 1)]
return [name.replace("-", "@", 1)]
def add_category_parser(self, name, **kwargs):
return self
@ -611,8 +644,9 @@ class ActionsMapParser(BaseActionsMapParser):
try:
keys.append(self._extract_route(r))
except ValueError as e:
logger.warning("cannot add api route '%s' for "
"action %s: %s", r, tid, e)
logger.warning(
"cannot add api route '%s' for " "action %s: %s", r, tid, e
)
continue
if len(keys) == 0:
raise ValueError("no valid api route found")
@ -631,7 +665,7 @@ class ActionsMapParser(BaseActionsMapParser):
try:
# Retrieve the tid for the route
tid, _ = self._parsers[route]
if not self.get_conf(tid, 'authenticate'):
if not self.get_conf(tid, "authenticate"):
return False
else:
# TODO: In the future, we could make the authentication
@ -640,10 +674,10 @@ class ActionsMapParser(BaseActionsMapParser):
# auth with some custom auth system to access some
# data with something like :
# return self.get_conf(tid, 'authenticator')
return 'default'
return "default"
except KeyError:
logger.error("no argument parser found for route '%s'", route)
raise MoulinetteError('error_see_log')
raise MoulinetteError("error_see_log")
def parse_args(self, args, route, **kwargs):
"""Parse arguments
@ -657,7 +691,7 @@ class ActionsMapParser(BaseActionsMapParser):
_, parser = self._parsers[route]
except KeyError:
logger.error("no argument parser found for route '%s'", route)
raise MoulinetteError('error_see_log')
raise MoulinetteError("error_see_log")
ret = argparse.Namespace()
# TODO: Catch errors?
@ -705,8 +739,7 @@ class Interface(BaseInterface):
"""
def __init__(self, actionsmap, routes={}, use_websocket=True,
log_queues=None):
def __init__(self, actionsmap, routes={}, use_websocket=True, log_queues=None):
self.use_websocket = use_websocket
# Attempt to retrieve log queues from an APIQueueHandler
@ -721,14 +754,15 @@ class Interface(BaseInterface):
# Wrapper which sets proper header
def apiheader(callback):
def wrapper(*args, **kwargs):
response.set_header('Access-Control-Allow-Origin', '*')
response.set_header("Access-Control-Allow-Origin", "*")
return callback(*args, **kwargs)
return wrapper
# Attempt to retrieve and set locale
def api18n(callback):
try:
locale = request.params.pop('locale')
locale = request.params.pop("locale")
except KeyError:
locale = m18n.default_locale
m18n.set_locale(locale)
@ -741,17 +775,17 @@ class Interface(BaseInterface):
app.install(_ActionsMapPlugin(actionsmap, use_websocket, log_queues))
# Append default routes
# app.route(['/api', '/api/<category:re:[a-z]+>'], method='GET',
# callback=self.doc, skip=['actionsmap'])
# app.route(['/api', '/api/<category:re:[a-z]+>'], method='GET',
# callback=self.doc, skip=['actionsmap'])
# Append additional routes
# TODO: Add optional authentication to those routes?
for (m, p), c in routes.items():
app.route(p, method=m, callback=c, skip=['actionsmap'])
app.route(p, method=m, callback=c, skip=["actionsmap"])
self._app = app
def run(self, host='localhost', port=80):
def run(self, host="localhost", port=80):
"""Run the moulinette
Start a server instance on the given port to serve moulinette
@ -762,25 +796,29 @@ class Interface(BaseInterface):
- port -- Server port to bind to
"""
logger.debug("starting the server instance in %s:%d with websocket=%s",
host, port, self.use_websocket)
logger.debug(
"starting the server instance in %s:%d with websocket=%s",
host,
port,
self.use_websocket,
)
try:
if self.use_websocket:
from gevent.pywsgi import WSGIServer
from geventwebsocket.handler import WebSocketHandler
server = WSGIServer((host, port), self._app,
handler_class=WebSocketHandler)
server = WSGIServer(
(host, port), self._app, handler_class=WebSocketHandler
)
server.serve_forever()
else:
run(self._app, host=host, port=port)
except IOError as e:
logger.exception("unable to start the server instance on %s:%d",
host, port)
logger.exception("unable to start the server instance on %s:%d", host, port)
if e.args[0] == errno.EADDRINUSE:
raise MoulinetteError('server_already_running')
raise MoulinetteError('error_see_log')
raise MoulinetteError("server_already_running")
raise MoulinetteError("error_see_log")
# Routes handlers
@ -792,14 +830,14 @@ class Interface(BaseInterface):
category -- Name of the category
"""
DATA_DIR = env()['DATA_DIR']
DATA_DIR = env()["DATA_DIR"]
if category is None:
with open('%s/../doc/resources.json' % DATA_DIR) as f:
with open("%s/../doc/resources.json" % DATA_DIR) as f:
return f.read()
try:
with open('%s/../doc/%s.json' % (DATA_DIR, category)) as f:
with open("%s/../doc/%s.json" % (DATA_DIR, category)) as f:
return f.read()
except IOError:
return None

View file

@ -15,27 +15,29 @@ import argcomplete
from moulinette import msignals, m18n
from moulinette.core import MoulinetteError
from moulinette.interfaces import (
BaseActionsMapParser, BaseInterface, ExtendedArgumentParser,
BaseActionsMapParser,
BaseInterface,
ExtendedArgumentParser,
)
from moulinette.utils import log
logger = log.getLogger('moulinette.cli')
logger = log.getLogger("moulinette.cli")
# CLI helpers ----------------------------------------------------------
CLI_COLOR_TEMPLATE = '\033[{:d}m\033[1m'
END_CLI_COLOR = '\033[m'
CLI_COLOR_TEMPLATE = "\033[{:d}m\033[1m"
END_CLI_COLOR = "\033[m"
colors_codes = {
'red': CLI_COLOR_TEMPLATE.format(31),
'green': CLI_COLOR_TEMPLATE.format(32),
'yellow': CLI_COLOR_TEMPLATE.format(33),
'blue': CLI_COLOR_TEMPLATE.format(34),
'purple': CLI_COLOR_TEMPLATE.format(35),
'cyan': CLI_COLOR_TEMPLATE.format(36),
'white': CLI_COLOR_TEMPLATE.format(37),
"red": CLI_COLOR_TEMPLATE.format(31),
"green": CLI_COLOR_TEMPLATE.format(32),
"yellow": CLI_COLOR_TEMPLATE.format(33),
"blue": CLI_COLOR_TEMPLATE.format(34),
"purple": CLI_COLOR_TEMPLATE.format(35),
"cyan": CLI_COLOR_TEMPLATE.format(36),
"white": CLI_COLOR_TEMPLATE.format(37),
}
@ -50,7 +52,7 @@ def colorize(astr, color):
"""
if os.isatty(1):
return '{:s}{:s}{:s}'.format(colors_codes[color], astr, END_CLI_COLOR)
return "{:s}{:s}{:s}".format(colors_codes[color], astr, END_CLI_COLOR)
else:
return astr
@ -91,7 +93,7 @@ def plain_print_dict(d, depth=0):
plain_print_dict(v, depth + 1)
else:
if isinstance(d, unicode):
d = d.encode('utf-8')
d = d.encode("utf-8")
print(d)
@ -107,7 +109,7 @@ def pretty_date(_date):
nowtz = nowtz.replace(tzinfo=pytz.utc)
offsetHour = nowutc - nowtz
offsetHour = int(round(offsetHour.total_seconds() / 3600))
localtz = 'Etc/GMT%+d' % offsetHour
localtz = "Etc/GMT%+d" % offsetHour
# Transform naive date into UTC date
if _date.tzinfo is None:
@ -136,7 +138,7 @@ def pretty_print_dict(d, depth=0):
keys = sorted(keys)
for k in keys:
v = d[k]
k = colorize(str(k), 'purple')
k = colorize(str(k), "purple")
if isinstance(v, (tuple, set)):
v = list(v)
if isinstance(v, list) and len(v) == 1:
@ -153,13 +155,13 @@ def pretty_print_dict(d, depth=0):
pretty_print_dict({key: value}, depth + 1)
else:
if isinstance(value, unicode):
value = value.encode('utf-8')
value = value.encode("utf-8")
elif isinstance(v, date):
v = pretty_date(v)
print("{:s}- {}".format(" " * (depth + 1), value))
else:
if isinstance(v, unicode):
v = v.encode('utf-8')
v = v.encode("utf-8")
elif isinstance(v, date):
v = pretty_date(v)
print("{:s}{}: {}".format(" " * depth, k, v))
@ -169,12 +171,13 @@ def get_locale():
"""Return current user locale"""
lang = locale.getdefaultlocale()[0]
if not lang:
return ''
return ""
return lang[:2]
# CLI Classes Implementation -------------------------------------------
class TTYHandler(logging.StreamHandler):
"""TTY log handler
@ -192,17 +195,18 @@ class TTYHandler(logging.StreamHandler):
stderr. Otherwise, they are sent to stdout.
"""
LEVELS_COLOR = {
log.NOTSET: 'white',
log.DEBUG: 'white',
log.INFO: 'cyan',
log.SUCCESS: 'green',
log.WARNING: 'yellow',
log.ERROR: 'red',
log.CRITICAL: 'red',
log.NOTSET: "white",
log.DEBUG: "white",
log.INFO: "cyan",
log.SUCCESS: "green",
log.WARNING: "yellow",
log.ERROR: "red",
log.CRITICAL: "red",
}
def __init__(self, message_key='fmessage'):
def __init__(self, message_key="fmessage"):
logging.StreamHandler.__init__(self)
self.message_key = message_key
@ -210,16 +214,15 @@ class TTYHandler(logging.StreamHandler):
"""Enhance message with level and colors if supported."""
msg = record.getMessage()
if self.supports_color():
level = ''
level = ""
if self.level <= log.DEBUG:
# add level name before message
level = '%s ' % record.levelname
elif record.levelname in ['SUCCESS', 'WARNING', 'ERROR', 'INFO']:
level = "%s " % record.levelname
elif record.levelname in ["SUCCESS", "WARNING", "ERROR", "INFO"]:
# add translated level name before message
level = '%s ' % m18n.g(record.levelname.lower())
color = self.LEVELS_COLOR.get(record.levelno, 'white')
msg = '{0}{1}{2}{3}'.format(
colors_codes[color], level, END_CLI_COLOR, msg)
level = "%s " % m18n.g(record.levelname.lower())
color = self.LEVELS_COLOR.get(record.levelno, "white")
msg = "{0}{1}{2}{3}".format(colors_codes[color], level, END_CLI_COLOR, msg)
if self.formatter:
# use user-defined formatter
record.__dict__[self.message_key] = msg
@ -236,7 +239,7 @@ class TTYHandler(logging.StreamHandler):
def supports_color(self):
"""Check whether current stream supports color."""
if hasattr(self.stream, 'isatty') and self.stream.isatty():
if hasattr(self.stream, "isatty") and self.stream.isatty():
return True
return False
@ -256,12 +259,13 @@ class ActionsMapParser(BaseActionsMapParser):
"""
def __init__(self, parent=None, parser=None, subparser_kwargs=None,
top_parser=None, **kwargs):
def __init__(
self, parent=None, parser=None, subparser_kwargs=None, top_parser=None, **kwargs
):
super(ActionsMapParser, self).__init__(parent)
if subparser_kwargs is None:
subparser_kwargs = {'title': "categories", 'required': False}
subparser_kwargs = {"title": "categories", "required": False}
self._parser = parser or ExtendedArgumentParser()
self._subparsers = self._parser.add_subparsers(**subparser_kwargs)
@ -277,13 +281,13 @@ class ActionsMapParser(BaseActionsMapParser):
# Implement virtual properties
interface = 'cli'
interface = "cli"
# Implement virtual methods
@staticmethod
def format_arg_names(name, full):
if name[0] == '-' and full:
if name[0] == "-" and full:
return [name, full]
return [name]
@ -300,13 +304,10 @@ class ActionsMapParser(BaseActionsMapParser):
A new ActionsMapParser object for the category
"""
parser = self._subparsers.add_parser(name,
description=category_help,
help=category_help,
**kwargs)
return self.__class__(self, parser, {
'title': "subcommands", 'required': True
})
parser = self._subparsers.add_parser(
name, description=category_help, help=category_help, **kwargs
)
return self.__class__(self, parser, {"title": "subcommands", "required": True})
def add_subcategory_parser(self, name, subcategory_help=None, **kwargs):
"""Add a parser for a subcategory
@ -318,17 +319,24 @@ class ActionsMapParser(BaseActionsMapParser):
A new ActionsMapParser object for the category
"""
parser = self._subparsers.add_parser(name,
type_="subcategory",
description=subcategory_help,
help=subcategory_help,
**kwargs)
return self.__class__(self, parser, {
'title': "actions", 'required': True
})
parser = self._subparsers.add_parser(
name,
type_="subcategory",
description=subcategory_help,
help=subcategory_help,
**kwargs
)
return self.__class__(self, parser, {"title": "actions", "required": True})
def add_action_parser(self, name, tid, action_help=None, deprecated=False,
deprecated_alias=[], **kwargs):
def add_action_parser(
self,
name,
tid,
action_help=None,
deprecated=False,
deprecated_alias=[],
**kwargs
):
"""Add a parser for an action
Keyword arguments:
@ -340,18 +348,21 @@ class ActionsMapParser(BaseActionsMapParser):
A new ExtendedArgumentParser object for the action
"""
return self._subparsers.add_parser(name,
type_="action",
help=action_help,
description=action_help,
deprecated=deprecated,
deprecated_alias=deprecated_alias)
return self._subparsers.add_parser(
name,
type_="action",
help=action_help,
description=action_help,
deprecated=deprecated,
deprecated_alias=deprecated_alias,
)
def add_global_arguments(self, arguments):
for argument_name, argument_options in arguments.items():
# will adapt arguments name for cli or api context
names = self.format_arg_names(str(argument_name),
argument_options.pop('full', None))
names = self.format_arg_names(
str(argument_name), argument_options.pop("full", None)
)
self.global_parser.add_argument(*names, **argument_options)
@ -363,12 +374,12 @@ class ActionsMapParser(BaseActionsMapParser):
except SystemExit:
raise
except:
logger.exception("unable to parse arguments '%s'", ' '.join(args))
raise MoulinetteError('error_see_log')
logger.exception("unable to parse arguments '%s'", " ".join(args))
raise MoulinetteError("error_see_log")
tid = getattr(ret, '_tid', None)
if self.get_conf(tid, 'authenticate'):
return self.get_conf(tid, 'authenticator')
tid = getattr(ret, "_tid", None)
if self.get_conf(tid, "authenticate"):
return self.get_conf(tid, "authenticator")
else:
return False
@ -378,10 +389,10 @@ class ActionsMapParser(BaseActionsMapParser):
except SystemExit:
raise
except:
logger.exception("unable to parse arguments '%s'", ' '.join(args))
raise MoulinetteError('error_see_log')
logger.exception("unable to parse arguments '%s'", " ".join(args))
raise MoulinetteError("error_see_log")
else:
self.prepare_action_namespace(getattr(ret, '_tid', None), ret)
self.prepare_action_namespace(getattr(ret, "_tid", None), ret)
self._parser.dequeue_callbacks(ret)
return ret
@ -403,10 +414,10 @@ class Interface(BaseInterface):
m18n.set_locale(get_locale())
# Connect signals to handlers
msignals.set_handler('display', self._do_display)
msignals.set_handler("display", self._do_display)
if os.isatty(1):
msignals.set_handler('authenticate', self._do_authenticate)
msignals.set_handler('prompt', self._do_prompt)
msignals.set_handler("authenticate", self._do_authenticate)
msignals.set_handler("prompt", self._do_prompt)
self.actionsmap = actionsmap
@ -426,30 +437,30 @@ class Interface(BaseInterface):
- timeout -- Number of seconds before this command will timeout because it can't acquire the lock (meaning that another command is currently running), by default there is no timeout and the command will wait until it can get the lock
"""
if output_as and output_as not in ['json', 'plain', 'none']:
raise MoulinetteError('invalid_usage')
if output_as and output_as not in ["json", "plain", "none"]:
raise MoulinetteError("invalid_usage")
# auto-complete
argcomplete.autocomplete(self.actionsmap.parser._parser)
# Set handler for authentication
if password:
msignals.set_handler('authenticate',
lambda a: a(password=password))
msignals.set_handler("authenticate", lambda a: a(password=password))
try:
ret = self.actionsmap.process(args, timeout=timeout)
except (KeyboardInterrupt, EOFError):
raise MoulinetteError('operation_interrupted')
raise MoulinetteError("operation_interrupted")
if ret is None or output_as == 'none':
if ret is None or output_as == "none":
return
# Format and print result
if output_as:
if output_as == 'json':
if output_as == "json":
import json
from moulinette.utils.serialize import JSONExtendedEncoder
print(json.dumps(ret, cls=JSONExtendedEncoder))
else:
plain_print_dict(ret)
@ -468,11 +479,10 @@ class Interface(BaseInterface):
"""
# TODO: Allow token authentication?
help = authenticator.extra.get("help")
msg = m18n.n(help) if help else m18n.g('password')
return authenticator(password=self._do_prompt(msg, True, False,
color='yellow'))
msg = m18n.n(help) if help else m18n.g("password")
return authenticator(password=self._do_prompt(msg, True, False, color="yellow"))
def _do_prompt(self, message, is_password, confirm, color='blue'):
def _do_prompt(self, message, is_password, confirm, color="blue"):
"""Prompt for a value
Handle the core.MoulinetteSignals.prompt signal.
@ -482,16 +492,15 @@ class Interface(BaseInterface):
"""
if is_password:
prompt = lambda m: getpass.getpass(colorize(m18n.g('colon', m),
color))
prompt = lambda m: getpass.getpass(colorize(m18n.g("colon", m), color))
else:
prompt = lambda m: raw_input(colorize(m18n.g('colon', m), color))
prompt = lambda m: raw_input(colorize(m18n.g("colon", m), color))
value = prompt(message)
if confirm:
m = message[0].lower() + message[1:]
if prompt(m18n.g('confirm', prompt=m)) != value:
raise MoulinetteError('values_mismatch')
if prompt(m18n.g("confirm", prompt=m)) != value:
raise MoulinetteError("values_mismatch")
return value
@ -502,12 +511,12 @@ class Interface(BaseInterface):
"""
if isinstance(message, unicode):
message = message.encode('utf-8')
if style == 'success':
print('{} {}'.format(colorize(m18n.g('success'), 'green'), message))
elif style == 'warning':
print('{} {}'.format(colorize(m18n.g('warning'), 'yellow'), message))
elif style == 'error':
print('{} {}'.format(colorize(m18n.g('error'), 'red'), message))
message = message.encode("utf-8")
if style == "success":
print("{} {}".format(colorize(m18n.g("success"), "green"), message))
elif style == "warning":
print("{} {}".format(colorize(m18n.g("warning"), "yellow"), message))
elif style == "error":
print("{} {}".format(colorize(m18n.g("error"), "red"), message))
else:
print(message)

View file

@ -22,21 +22,25 @@ def read_file(file_path):
Keyword argument:
file_path -- Path to the text file
"""
assert isinstance(file_path, basestring), "Error: file_path '%s' should be a string but is of type '%s' instead" % (file_path, type(file_path))
assert isinstance(file_path, basestring), (
"Error: file_path '%s' should be a string but is of type '%s' instead"
% (file_path, type(file_path))
)
# Check file exists
if not os.path.isfile(file_path):
raise MoulinetteError('file_not_exist', path=file_path)
raise MoulinetteError("file_not_exist", path=file_path)
# Open file and read content
try:
with open(file_path, "r") as f:
file_content = f.read()
except IOError as e:
raise MoulinetteError('cannot_open_file', file=file_path, error=str(e))
raise MoulinetteError("cannot_open_file", file=file_path, error=str(e))
except Exception:
raise MoulinetteError('unknown_error_reading_file',
file=file_path, error=str(e))
raise MoulinetteError(
"unknown_error_reading_file", file=file_path, error=str(e)
)
return file_content
@ -56,7 +60,7 @@ def read_json(file_path):
try:
loaded_json = json.loads(file_content)
except ValueError as e:
raise MoulinetteError('corrupted_json', ressource=file_path, error=str(e))
raise MoulinetteError("corrupted_json", ressource=file_path, error=str(e))
return loaded_json
@ -76,7 +80,7 @@ def read_yaml(file_path):
try:
loaded_yaml = yaml.safe_load(file_content)
except Exception as e:
raise MoulinetteError('corrupted_yaml', ressource=file_path, error=str(e))
raise MoulinetteError("corrupted_yaml", ressource=file_path, error=str(e))
return loaded_yaml
@ -96,9 +100,9 @@ def read_toml(file_path):
try:
loaded_toml = toml.loads(file_content, _dict=OrderedDict)
except Exception as e:
raise MoulinetteError(errno.EINVAL,
m18n.g('corrupted_toml',
ressource=file_path, error=str(e)))
raise MoulinetteError(
errno.EINVAL, m18n.g("corrupted_toml", ressource=file_path, error=str(e))
)
return loaded_toml
@ -129,10 +133,11 @@ def read_ldif(file_path, filtred_entries=[]):
parser = LDIFPar(f)
parser.parse()
except IOError as e:
raise MoulinetteError('cannot_open_file', file=file_path, error=str(e))
raise MoulinetteError("cannot_open_file", file=file_path, error=str(e))
except Exception as e:
raise MoulinetteError('unknown_error_reading_file',
file=file_path, error=str(e))
raise MoulinetteError(
"unknown_error_reading_file", file=file_path, error=str(e)
)
return parser.all_records
@ -148,23 +153,34 @@ def write_to_file(file_path, data, file_mode="w"):
file_mode -- Mode used when writing the file. Option meant to be used
by append_to_file to avoid duplicating the code of this function.
"""
assert isinstance(data, basestring) or isinstance(data, list), "Error: data '%s' should be either a string or a list but is of type '%s'" % (data, type(data))
assert not os.path.isdir(file_path), "Error: file_path '%s' point to a dir, it should be a file" % file_path
assert os.path.isdir(os.path.dirname(file_path)), "Error: the path ('%s') base dir ('%s') is not a dir" % (file_path, os.path.dirname(file_path))
assert isinstance(data, basestring) or isinstance(data, list), (
"Error: data '%s' should be either a string or a list but is of type '%s'"
% (data, type(data))
)
assert not os.path.isdir(file_path), (
"Error: file_path '%s' point to a dir, it should be a file" % file_path
)
assert os.path.isdir(os.path.dirname(file_path)), (
"Error: the path ('%s') base dir ('%s') is not a dir"
% (file_path, os.path.dirname(file_path))
)
# If data is a list, check elements are strings and build a single string
if not isinstance(data, basestring):
for element in data:
assert isinstance(element, basestring), "Error: element '%s' should be a string but is of type '%s' instead" % (element, type(element))
data = '\n'.join(data)
assert isinstance(element, basestring), (
"Error: element '%s' should be a string but is of type '%s' instead"
% (element, type(element))
)
data = "\n".join(data)
try:
with open(file_path, file_mode) as f:
f.write(data)
except IOError as e:
raise MoulinetteError('cannot_write_file', file=file_path, error=str(e))
raise MoulinetteError("cannot_write_file", file=file_path, error=str(e))
except Exception as e:
raise MoulinetteError('error_writing_file', file=file_path, error=str(e))
raise MoulinetteError("error_writing_file", file=file_path, error=str(e))
def append_to_file(file_path, data):
@ -189,19 +205,30 @@ def write_to_json(file_path, data):
"""
# Assumptions
assert isinstance(file_path, basestring), "Error: file_path '%s' should be a string but is of type '%s' instead" % (file_path, type(file_path))
assert isinstance(data, dict) or isinstance(data, list), "Error: data '%s' should be a dict or a list but is of type '%s' instead" % (data, type(data))
assert not os.path.isdir(file_path), "Error: file_path '%s' point to a dir, it should be a file" % file_path
assert os.path.isdir(os.path.dirname(file_path)), "Error: the path ('%s') base dir ('%s') is not a dir" % (file_path, os.path.dirname(file_path))
assert isinstance(file_path, basestring), (
"Error: file_path '%s' should be a string but is of type '%s' instead"
% (file_path, type(file_path))
)
assert isinstance(data, dict) or isinstance(data, list), (
"Error: data '%s' should be a dict or a list but is of type '%s' instead"
% (data, type(data))
)
assert not os.path.isdir(file_path), (
"Error: file_path '%s' point to a dir, it should be a file" % file_path
)
assert os.path.isdir(os.path.dirname(file_path)), (
"Error: the path ('%s') base dir ('%s') is not a dir"
% (file_path, os.path.dirname(file_path))
)
# Write dict to file
try:
with open(file_path, "w") as f:
json.dump(data, f)
except IOError as e:
raise MoulinetteError('cannot_write_file', file=file_path, error=str(e))
raise MoulinetteError("cannot_write_file", file=file_path, error=str(e))
except Exception as e:
raise MoulinetteError('error_writing_file', file=file_path, error=str(e))
raise MoulinetteError("error_writing_file", file=file_path, error=str(e))
def write_to_yaml(file_path, data):
@ -223,9 +250,9 @@ def write_to_yaml(file_path, data):
with open(file_path, "w") as f:
yaml.safe_dump(data, f, default_flow_style=False)
except IOError as e:
raise MoulinetteError('cannot_write_file', file=file_path, error=str(e))
raise MoulinetteError("cannot_write_file", file=file_path, error=str(e))
except Exception as e:
raise MoulinetteError('error_writing_file', file=file_path, error=str(e))
raise MoulinetteError("error_writing_file", file=file_path, error=str(e))
def mkdir(path, mode=0o777, parents=False, uid=None, gid=None, force=False):
@ -245,7 +272,7 @@ def mkdir(path, mode=0o777, parents=False, uid=None, gid=None, force=False):
"""
if os.path.exists(path) and not force:
raise OSError(errno.EEXIST, m18n.g('folder_exists', path=path))
raise OSError(errno.EEXIST, m18n.g("folder_exists", path=path))
if parents:
# Create parents directories as needed
@ -290,14 +317,14 @@ def chown(path, uid=None, gid=None, recursive=False):
try:
uid = getpwnam(uid).pw_uid
except KeyError:
raise MoulinetteError('unknown_user', user=uid)
raise MoulinetteError("unknown_user", user=uid)
elif uid is None:
uid = -1
if isinstance(gid, basestring):
try:
gid = grp.getgrnam(gid).gr_gid
except KeyError:
raise MoulinetteError('unknown_group', group=gid)
raise MoulinetteError("unknown_group", group=gid)
elif gid is None:
gid = -1
@ -310,7 +337,9 @@ def chown(path, uid=None, gid=None, recursive=False):
for f in files:
os.chown(os.path.join(root, f), uid, gid)
except Exception as e:
raise MoulinetteError('error_changing_file_permissions', path=path, error=str(e))
raise MoulinetteError(
"error_changing_file_permissions", path=path, error=str(e)
)
def chmod(path, mode, fmode=None, recursive=False):
@ -334,7 +363,9 @@ def chmod(path, mode, fmode=None, recursive=False):
for f in files:
os.chmod(os.path.join(root, f), fmode)
except Exception as e:
raise MoulinetteError('error_changing_file_permissions', path=path, error=str(e))
raise MoulinetteError(
"error_changing_file_permissions", path=path, error=str(e)
)
def rm(path, recursive=False, force=False):
@ -353,4 +384,4 @@ def rm(path, recursive=False, force=False):
os.remove(path)
except OSError as e:
if not force:
raise MoulinetteError('error_removing', path=path, error=str(e))
raise MoulinetteError("error_removing", path=path, error=str(e))

View file

@ -3,8 +3,18 @@ import logging
# import all constants because other modules try to import them from this
# module because SUCCESS is defined in this module
from logging import (addLevelName, setLoggerClass, Logger, getLogger, NOTSET, # noqa
DEBUG, INFO, WARNING, ERROR, CRITICAL)
from logging import (
addLevelName,
setLoggerClass,
Logger,
getLogger,
NOTSET, # noqa
DEBUG,
INFO,
WARNING,
ERROR,
CRITICAL,
)
# Global configuration and functions -----------------------------------
@ -12,27 +22,20 @@ from logging import (addLevelName, setLoggerClass, Logger, getLogger, NOTSET, #
SUCCESS = 25
DEFAULT_LOGGING = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'simple': {
'format': '%(asctime)-15s %(levelname)-8s %(name)s - %(message)s'
},
},
'handlers': {
'console': {
'level': 'DEBUG',
'formatter': 'simple',
'class': 'logging.StreamHandler',
'stream': 'ext://sys.stdout',
},
},
'loggers': {
'moulinette': {
'level': 'DEBUG',
'handlers': ['console'],
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"simple": {"format": "%(asctime)-15s %(levelname)-8s %(name)s - %(message)s"},
},
"handlers": {
"console": {
"level": "DEBUG",
"formatter": "simple",
"class": "logging.StreamHandler",
"stream": "ext://sys.stdout",
},
},
"loggers": {"moulinette": {"level": "DEBUG", "handlers": ["console"],},},
}
@ -46,7 +49,7 @@ def configure_logging(logging_config=None):
from logging.config import dictConfig
# add custom logging level and class
addLevelName(SUCCESS, 'SUCCESS')
addLevelName(SUCCESS, "SUCCESS")
setLoggerClass(MoulinetteLogger)
# load configuration from dict
@ -65,7 +68,7 @@ def getHandlersByClass(classinfo, limit=0):
return o
handlers.append(o)
if limit != 0 and len(handlers) > limit:
return handlers[:limit - 1]
return handlers[: limit - 1]
return handlers
@ -79,6 +82,7 @@ class MoulinetteLogger(Logger):
LogRecord extra and can be used with the ActionFilter.
"""
action_id = None
def success(self, msg, *args, **kwargs):
@ -105,11 +109,11 @@ class MoulinetteLogger(Logger):
def _log(self, *args, **kwargs):
"""Append action_id if available to the extra."""
if self.action_id is not None:
extra = kwargs.get('extra', {})
if 'action_id' not in extra:
extra = kwargs.get("extra", {})
if "action_id" not in extra:
# FIXME: Get real action_id instead of logger/current one
extra['action_id'] = _get_action_id()
kwargs['extra'] = extra
extra["action_id"] = _get_action_id()
kwargs["extra"] = extra
return Logger._log(self, *args, **kwargs)
@ -120,7 +124,7 @@ action_id = 0
def _get_action_id():
return '%d.%d' % (pid, action_id)
return "%d.%d" % (pid, action_id)
def start_action_logging():
@ -146,7 +150,7 @@ def getActionLogger(name=None, logger=None, action_id=None):
"""
if not name and not logger:
raise ValueError('Either a name or a logger must be specified')
raise ValueError("Either a name or a logger must be specified")
logger = logger or getLogger(name)
logger.action_id = action_id if action_id else _get_action_id()
@ -164,15 +168,15 @@ class ActionFilter(object):
"""
def __init__(self, message_key='fmessage', strict=False):
def __init__(self, message_key="fmessage", strict=False):
self.message_key = message_key
self.strict = strict
def filter(self, record):
msg = record.getMessage()
action_id = record.__dict__.get('action_id', None)
action_id = record.__dict__.get("action_id", None)
if action_id is not None:
msg = '[{:s}] {:s}'.format(action_id, msg)
msg = "[{:s}] {:s}".format(action_id, msg)
elif self.strict:
return False
record.__dict__[self.message_key] = msg

View file

@ -15,6 +15,7 @@ def download_text(url, timeout=30, expected_status_code=200):
None to ignore the status code.
"""
import requests # lazy loading this module for performance reasons
# Assumptions
assert isinstance(url, str)
@ -23,22 +24,21 @@ def download_text(url, timeout=30, expected_status_code=200):
r = requests.get(url, timeout=timeout)
# Invalid URL
except requests.exceptions.ConnectionError:
raise MoulinetteError('invalid_url', url=url)
raise MoulinetteError("invalid_url", url=url)
# SSL exceptions
except requests.exceptions.SSLError:
raise MoulinetteError('download_ssl_error', url=url)
raise MoulinetteError("download_ssl_error", url=url)
# Timeout exceptions
except requests.exceptions.Timeout:
raise MoulinetteError('download_timeout', url=url)
raise MoulinetteError("download_timeout", url=url)
# Unknown stuff
except Exception as e:
raise MoulinetteError('download_unknown_error',
url=url, error=str(e))
raise MoulinetteError("download_unknown_error", url=url, error=str(e))
# Assume error if status code is not 200 (OK)
if expected_status_code is not None \
and r.status_code != expected_status_code:
raise MoulinetteError('download_bad_status_code',
url=url, code=str(r.status_code))
if expected_status_code is not None and r.status_code != expected_status_code:
raise MoulinetteError(
"download_bad_status_code", url=url, code=str(r.status_code)
)
return r.text
@ -59,6 +59,6 @@ def download_json(url, timeout=30, expected_status_code=200):
try:
loaded_json = json.loads(text)
except ValueError as e:
raise MoulinetteError('corrupted_json', ressource=url, error=e)
raise MoulinetteError("corrupted_json", ressource=url, error=e)
return loaded_json

View file

@ -11,6 +11,7 @@ except ImportError:
from shlex import quote # Python3 >= 3.3
from .stream import async_file_reading
quote # This line is here to avoid W0611 PEP8 error (see comments above)
# Prevent to import subprocess only for common classes
@ -19,6 +20,7 @@ CalledProcessError = subprocess.CalledProcessError
# Alternative subprocess methods ---------------------------------------
def check_output(args, stderr=subprocess.STDOUT, shell=True, **kwargs):
"""Run command with arguments and return its output as a byte string
@ -31,6 +33,7 @@ def check_output(args, stderr=subprocess.STDOUT, shell=True, **kwargs):
# Call with stream access ----------------------------------------------
def call_async_output(args, callback, **kwargs):
"""Run command and provide its output asynchronously
@ -52,10 +55,9 @@ def call_async_output(args, callback, **kwargs):
Exit status of the command
"""
for a in ['stdout', 'stderr']:
for a in ["stdout", "stderr"]:
if a in kwargs:
raise ValueError('%s argument not allowed, '
'it will be overridden.' % a)
raise ValueError("%s argument not allowed, " "it will be overridden." % a)
if "stdinfo" in kwargs and kwargs["stdinfo"] is not None:
assert len(callback) == 3
@ -72,16 +74,16 @@ def call_async_output(args, callback, **kwargs):
# Validate callback argument
if isinstance(callback, tuple):
if len(callback) < 2:
raise ValueError('callback argument should be a 2-tuple')
kwargs['stdout'] = kwargs['stderr'] = subprocess.PIPE
raise ValueError("callback argument should be a 2-tuple")
kwargs["stdout"] = kwargs["stderr"] = subprocess.PIPE
separate_stderr = True
elif callable(callback):
kwargs['stdout'] = subprocess.PIPE
kwargs['stderr'] = subprocess.STDOUT
kwargs["stdout"] = subprocess.PIPE
kwargs["stderr"] = subprocess.STDOUT
separate_stderr = False
callback = (callback,)
else:
raise ValueError('callback argument must be callable or a 2-tuple')
raise ValueError("callback argument must be callable or a 2-tuple")
# Run the command
p = subprocess.Popen(args, **kwargs)
@ -101,7 +103,7 @@ def call_async_output(args, callback, **kwargs):
stderr_consum.process_next_line()
if stdinfo:
stdinfo_consum.process_next_line()
time.sleep(.1)
time.sleep(0.1)
stderr_reader.join()
# clear the queues
stdout_consum.process_current_queue()
@ -111,7 +113,7 @@ def call_async_output(args, callback, **kwargs):
else:
while not stdout_reader.eof():
stdout_consum.process_current_queue()
time.sleep(.1)
time.sleep(0.1)
stdout_reader.join()
# clear the queue
stdout_consum.process_current_queue()
@ -131,15 +133,15 @@ def call_async_output(args, callback, **kwargs):
while time.time() - start < 10:
if p.poll() is not None:
return p.poll()
time.sleep(.1)
time.sleep(0.1)
return p.poll()
# Call multiple commands -----------------------------------------------
def run_commands(cmds, callback=None, separate_stderr=False, shell=True,
**kwargs):
def run_commands(cmds, callback=None, separate_stderr=False, shell=True, **kwargs):
"""Run multiple commands with error management
Run a list of commands and allow to manage how to treat errors either
@ -176,18 +178,18 @@ def run_commands(cmds, callback=None, separate_stderr=False, shell=True,
# stdout and stderr are specified by this code later, so they cannot be
# overriden by user input
for a in ['stdout', 'stderr']:
for a in ["stdout", "stderr"]:
if a in kwargs:
raise ValueError('%s argument not allowed, '
'it will be overridden.' % a)
raise ValueError("%s argument not allowed, " "it will be overridden." % a)
# If no callback specified...
if callback is None:
# Raise CalledProcessError on command failure
def callback(r, c, o):
raise CalledProcessError(r, c, o)
elif not callable(callback):
raise ValueError('callback argument must be callable')
raise ValueError("callback argument must be callable")
# Manage stderr
if separate_stderr:
@ -201,8 +203,9 @@ def run_commands(cmds, callback=None, separate_stderr=False, shell=True,
error = 0
for cmd in cmds:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=_stderr, shell=shell, **kwargs)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=_stderr, shell=shell, **kwargs
)
output = _get_output(*process.communicate())
retcode = process.poll()

View file

@ -3,11 +3,12 @@ from json.encoder import JSONEncoder
import datetime
import pytz
logger = logging.getLogger('moulinette.utils.serialize')
logger = logging.getLogger("moulinette.utils.serialize")
# JSON utilities -------------------------------------------------------
class JSONExtendedEncoder(JSONEncoder):
"""Extended JSON encoder
@ -24,8 +25,7 @@ class JSONExtendedEncoder(JSONEncoder):
def default(self, o):
"""Return a serializable object"""
# Convert compatible containers into list
if isinstance(o, set) or (
hasattr(o, '__iter__') and hasattr(o, 'next')):
if isinstance(o, set) or (hasattr(o, "__iter__") and hasattr(o, "next")):
return list(o)
# Display the date in its iso format ISO-8601 Internet Profile (RFC 3339)
@ -35,6 +35,9 @@ class JSONExtendedEncoder(JSONEncoder):
return o.isoformat()
# Return the repr for object that json can't encode
logger.warning('cannot properly encode in JSON the object %s, '
'returned repr is: %r', type(o), o)
logger.warning(
"cannot properly encode in JSON the object %s, " "returned repr is: %r",
type(o),
o,
)
return repr(o)

View file

@ -7,6 +7,7 @@ from multiprocessing.queues import SimpleQueue
# Read from a stream ---------------------------------------------------
class AsynchronousFileReader(Process):
"""
@ -20,8 +21,8 @@ class AsynchronousFileReader(Process):
"""
def __init__(self, fd, queue):
assert hasattr(queue, 'put')
assert hasattr(queue, 'empty')
assert hasattr(queue, "put")
assert hasattr(queue, "empty")
assert isinstance(fd, int) or callable(fd.readline)
Process.__init__(self)
self._fd = fd
@ -34,7 +35,7 @@ class AsynchronousFileReader(Process):
# Typically that's for stdout/stderr pipes
# We can read the stuff easily with 'readline'
if not isinstance(self._fd, int):
for line in iter(self._fd.readline, ''):
for line in iter(self._fd.readline, ""):
self._queue.put(line)
# Else, it got opened with os.open() and we have to read it
@ -52,10 +53,10 @@ class AsynchronousFileReader(Process):
# If we have data, extract a line (ending with \n) and feed
# it to the consumer
if data and '\n' in data:
lines = data.split('\n')
if data and "\n" in data:
lines = data.split("\n")
self._queue.put(lines[0])
data = '\n'.join(lines[1:])
data = "\n".join(lines[1:])
else:
time.sleep(0.05)
@ -75,7 +76,6 @@ class AsynchronousFileReader(Process):
class Consummer(object):
def __init__(self, queue, callback):
self.queue = queue
self.callback = callback

View file

@ -6,6 +6,7 @@ import binascii
# Pattern searching ----------------------------------------------------
def search(pattern, text, count=0, flags=0):
"""Search for pattern in a text
@ -46,7 +47,7 @@ def searchf(pattern, path, count=0, flags=re.MULTILINE):
content by using the search function.
"""
with open(path, 'r+') as f:
with open(path, "r+") as f:
data = mmap.mmap(f.fileno(), 0)
match = search(pattern, data, count, flags)
data.close()
@ -55,6 +56,7 @@ def searchf(pattern, path, count=0, flags=re.MULTILINE):
# Text formatting ------------------------------------------------------
def prependlines(text, prepend):
"""Prepend a string to each line of a text"""
lines = text.splitlines(True)
@ -63,6 +65,7 @@ def prependlines(text, prepend):
# Randomize ------------------------------------------------------------
def random_ascii(length=20):
"""Return a random ascii string"""
return binascii.hexlify(os.urandom(length)).decode('ascii')
return binascii.hexlify(os.urandom(length)).decode("ascii")

View file

@ -10,9 +10,9 @@ def patch_init(moulinette):
"""Configure moulinette to use the YunoHost namespace."""
old_init = moulinette.core.Moulinette18n.__init__
def monkey_path_i18n_init(self, package, default_locale='en'):
def monkey_path_i18n_init(self, package, default_locale="en"):
old_init(self, package, default_locale)
self.load_namespace('moulinette')
self.load_namespace("moulinette")
moulinette.core.Moulinette18n.__init__ = monkey_path_i18n_init
@ -23,7 +23,7 @@ def patch_translate(moulinette):
def new_translate(self, key, *args, **kwargs):
if key not in self._translations[self.default_locale].keys():
message = 'Unable to retrieve key %s for default locale!' % key
message = "Unable to retrieve key %s for default locale!" % key
raise KeyError(message)
return old_translate(self, key, *args, **kwargs)
@ -38,59 +38,46 @@ def patch_translate(moulinette):
def patch_logging(moulinette):
"""Configure logging to use the custom logger."""
handlers = set(['tty', 'api'])
handlers = set(["tty", "api"])
root_handlers = set(handlers)
level = 'INFO'
tty_level = 'INFO'
level = "INFO"
tty_level = "INFO"
return {
'version': 1,
'disable_existing_loggers': True,
'formatters': {
'tty-debug': {
'format': '%(relativeCreated)-4d %(fmessage)s'
},
'precise': {
'format': '%(asctime)-15s %(levelname)-8s %(name)s %(funcName)s - %(fmessage)s' # noqa
"version": 1,
"disable_existing_loggers": True,
"formatters": {
"tty-debug": {"format": "%(relativeCreated)-4d %(fmessage)s"},
"precise": {
"format": "%(asctime)-15s %(levelname)-8s %(name)s %(funcName)s - %(fmessage)s" # noqa
},
},
'filters': {
'action': {
'()': 'moulinette.utils.log.ActionFilter',
"filters": {"action": {"()": "moulinette.utils.log.ActionFilter",},},
"handlers": {
"api": {
"level": level,
"class": "moulinette.interfaces.api.APIQueueHandler",
},
"tty": {
"level": tty_level,
"class": "moulinette.interfaces.cli.TTYHandler",
"formatter": "",
},
},
'handlers': {
'api': {
'level': level,
'class': 'moulinette.interfaces.api.APIQueueHandler',
},
'tty': {
'level': tty_level,
'class': 'moulinette.interfaces.cli.TTYHandler',
'formatter': '',
"loggers": {
"moulinette": {"level": level, "handlers": [], "propagate": True,},
"moulinette.interface": {
"level": level,
"handlers": handlers,
"propagate": False,
},
},
'loggers': {
'moulinette': {
'level': level,
'handlers': [],
'propagate': True,
},
'moulinette.interface': {
'level': level,
'handlers': handlers,
'propagate': False,
},
},
'root': {
'level': level,
'handlers': root_handlers,
},
"root": {"level": level, "handlers": root_handlers,},
}
@pytest.fixture(scope='session', autouse=True)
@pytest.fixture(scope="session", autouse=True)
def moulinette(tmp_path_factory):
import moulinette
@ -100,9 +87,9 @@ def moulinette(tmp_path_factory):
tmp_cache = str(tmp_path_factory.mktemp("cache"))
tmp_data = str(tmp_path_factory.mktemp("data"))
tmp_lib = str(tmp_path_factory.mktemp("lib"))
os.environ['MOULINETTE_CACHE_DIR'] = tmp_cache
os.environ['MOULINETTE_DATA_DIR'] = tmp_data
os.environ['MOULINETTE_LIB_DIR'] = tmp_lib
os.environ["MOULINETTE_CACHE_DIR"] = tmp_cache
os.environ["MOULINETTE_DATA_DIR"] = tmp_data
os.environ["MOULINETTE_LIB_DIR"] = tmp_lib
shutil.copytree("./test/actionsmap", "%s/actionsmap" % tmp_data)
shutil.copytree("./test/src", "%s/%s" % (tmp_lib, namespace))
shutil.copytree("./test/locales", "%s/%s/locales" % (tmp_lib, namespace))
@ -111,10 +98,7 @@ def moulinette(tmp_path_factory):
patch_translate(moulinette)
logging = patch_logging(moulinette)
moulinette.init(
logging_config=logging,
_from_source=False
)
moulinette.init(logging_config=logging, _from_source=False)
return moulinette
@ -129,12 +113,13 @@ def moulinette_webapi(moulinette):
# sure why :|
def return_true(self, cookie, request):
return True
CookiePolicy.return_ok_secure = return_true
moulinette_webapi = moulinette.core.init_interface(
'api',
kwargs={'routes': {}, 'use_websocket': False},
actionsmap={'namespaces': ["moulitest"], 'use_cache': True}
"api",
kwargs={"routes": {}, "use_websocket": False},
actionsmap={"namespaces": ["moulitest"], "use_cache": True},
)
return TestApp(moulinette_webapi._app)
@ -142,16 +127,16 @@ def moulinette_webapi(moulinette):
@pytest.fixture
def test_file(tmp_path):
test_text = 'foo\nbar\n'
test_file = tmp_path / 'test.txt'
test_text = "foo\nbar\n"
test_file = tmp_path / "test.txt"
test_file.write_bytes(test_text)
return test_file
@pytest.fixture
def test_json(tmp_path):
test_json = json.dumps({'foo': 'bar'})
test_file = tmp_path / 'test.json'
test_json = json.dumps({"foo": "bar"})
test_file = tmp_path / "test.json"
test_file.write_bytes(test_json)
return test_file
@ -163,4 +148,4 @@ def user():
@pytest.fixture
def test_url():
return 'https://some.test.url/yolo.txt'
return "https://some.test.url/yolo.txt"

View file

@ -5,7 +5,7 @@ from moulinette.actionsmap import (
AskParameter,
PatternParameter,
RequiredParameter,
ActionsMap
ActionsMap,
)
from moulinette.interfaces import BaseActionsMapParser
from moulinette.core import MoulinetteError
@ -13,75 +13,73 @@ from moulinette.core import MoulinetteError
@pytest.fixture
def iface():
return 'iface'
return "iface"
def test_comment_parameter_bad_bool_value(iface, caplog):
comment = CommentParameter(iface)
assert comment.validate(True, 'a') == 'a'
assert any('expecting a non-empty string' in message for message in caplog.messages)
assert comment.validate(True, "a") == "a"
assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_comment_parameter_bad_empty_string(iface, caplog):
comment = CommentParameter(iface)
assert comment.validate('', 'a') == 'a'
assert any('expecting a non-empty string' in message for message in caplog.messages)
assert comment.validate("", "a") == "a"
assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_comment_parameter_bad_type(iface):
comment = CommentParameter(iface)
with pytest.raises(TypeError):
comment.validate({}, 'b')
comment.validate({}, "b")
def test_ask_parameter_bad_bool_value(iface, caplog):
ask = AskParameter(iface)
assert ask.validate(True, 'a') == 'a'
assert any('expecting a non-empty string' in message for message in caplog.messages)
assert ask.validate(True, "a") == "a"
assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_ask_parameter_bad_empty_string(iface, caplog):
ask = AskParameter(iface)
assert ask.validate('', 'a') == 'a'
assert any('expecting a non-empty string' in message for message in caplog.messages)
assert ask.validate("", "a") == "a"
assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_ask_parameter_bad_type(iface):
ask = AskParameter(iface)
with pytest.raises(TypeError):
ask.validate({}, 'b')
ask.validate({}, "b")
def test_pattern_parameter_bad_str_value(iface, caplog):
pattern = PatternParameter(iface)
assert pattern.validate('', 'a') == ['', 'pattern_not_match']
assert any('expecting a list' in message for message in caplog.messages)
assert pattern.validate("", "a") == ["", "pattern_not_match"]
assert any("expecting a list" in message for message in caplog.messages)
@pytest.mark.parametrize('iface', [
[],
['pattern_alone'],
['pattern', 'message', 'extra stuff']
])
@pytest.mark.parametrize(
"iface", [[], ["pattern_alone"], ["pattern", "message", "extra stuff"]]
)
def test_pattern_parameter_bad_list_len(iface):
pattern = PatternParameter(iface)
with pytest.raises(TypeError):
pattern.validate(iface, 'a')
pattern.validate(iface, "a")
def test_required_paremeter_missing_value(iface):
required = RequiredParameter(iface)
with pytest.raises(MoulinetteError) as exception:
required(True, 'a', '')
assert 'is required' in str(exception)
required(True, "a", "")
assert "is required" in str(exception)
def test_actions_map_unknown_authenticator(monkeypatch, tmp_path):
monkeypatch.setenv('MOULINETTE_DATA_DIR', str(tmp_path))
actionsmap_dir = actionsmap_dir = tmp_path / 'actionsmap'
monkeypatch.setenv("MOULINETTE_DATA_DIR", str(tmp_path))
actionsmap_dir = actionsmap_dir = tmp_path / "actionsmap"
actionsmap_dir.mkdir()
amap = ActionsMap(BaseActionsMapParser)
with pytest.raises(ValueError) as exception:
amap.get_authenticator_for_profile('unknown')
assert 'Unknown authenticator' in str(exception)
amap.get_authenticator_for_profile("unknown")
assert "Unknown authenticator" in str(exception)

View file

@ -7,19 +7,28 @@ def login(webapi, csrf=False, profile=None, status=200):
if profile:
data["profile"] = profile
return webapi.post("/login", data,
status=status,
headers=None if csrf else {"X-Requested-With": ""})
return webapi.post(
"/login",
data,
status=status,
headers=None if csrf else {"X-Requested-With": ""},
)
def test_request_no_auth_needed(moulinette_webapi):
assert moulinette_webapi.get("/test-auth/none", status=200).text == '"some_data_from_none"'
assert (
moulinette_webapi.get("/test-auth/none", status=200).text
== '"some_data_from_none"'
)
def test_request_with_auth_but_not_logged(moulinette_webapi):
assert moulinette_webapi.get("/test-auth/default", status=401).text == "Authentication required"
assert (
moulinette_webapi.get("/test-auth/default", status=401).text
== "Authentication required"
)
def test_login(moulinette_webapi):
@ -29,8 +38,10 @@ def test_login(moulinette_webapi):
assert "session.id" in moulinette_webapi.cookies
assert "session.tokens" in moulinette_webapi.cookies
cache_session_default = os.environ['MOULINETTE_CACHE_DIR'] + "/session/default/"
assert moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(cache_session_default)
cache_session_default = os.environ["MOULINETTE_CACHE_DIR"] + "/session/default/"
assert moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(
cache_session_default
)
def test_login_csrf_attempt(moulinette_webapi):
@ -57,7 +68,10 @@ def test_login_then_legit_request(moulinette_webapi):
login(moulinette_webapi)
assert moulinette_webapi.get("/test-auth/default", status=200).text == '"some_data_from_default"'
assert (
moulinette_webapi.get("/test-auth/default", status=200).text
== '"some_data_from_default"'
)
def test_login_then_logout(moulinette_webapi):
@ -66,7 +80,12 @@ def test_login_then_logout(moulinette_webapi):
moulinette_webapi.get("/logout", status=200)
cache_session_default = os.environ['MOULINETTE_CACHE_DIR'] + "/session/default/"
assert not moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(cache_session_default)
cache_session_default = os.environ["MOULINETTE_CACHE_DIR"] + "/session/default/"
assert not moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(
cache_session_default
)
assert moulinette_webapi.get("/test-auth/default", status=401).text == "Authentication required"
assert (
moulinette_webapi.get("/test-auth/default", status=401).text
== "Authentication required"
)

View file

@ -2,11 +2,11 @@ import os.path
def test_open_cachefile_creates(monkeypatch, tmp_path):
monkeypatch.setenv('MOULINETTE_CACHE_DIR', str(tmp_path))
monkeypatch.setenv("MOULINETTE_CACHE_DIR", str(tmp_path))
from moulinette.cache import open_cachefile
handle = open_cachefile('foo.cache', mode='w')
handle = open_cachefile("foo.cache", mode="w")
assert handle.mode == 'w'
assert handle.name == os.path.join(str(tmp_path), 'foo.cache')
assert handle.mode == "w"
assert handle.name == os.path.join(str(tmp_path), "foo.cache")

View file

@ -4,150 +4,156 @@ import pytest
from moulinette import m18n
from moulinette.core import MoulinetteError
from moulinette.utils.filesystem import (append_to_file, read_file, read_json,
rm, write_to_file, write_to_json)
from moulinette.utils.filesystem import (
append_to_file,
read_file,
read_json,
rm,
write_to_file,
write_to_json,
)
def test_read_file(test_file):
content = read_file(str(test_file))
assert content == 'foo\nbar\n'
assert content == "foo\nbar\n"
def test_read_file_missing_file():
bad_file = 'doesnt-exist'
bad_file = "doesnt-exist"
with pytest.raises(MoulinetteError) as exception:
read_file(bad_file)
translation = m18n.g('file_not_exist', path=bad_file)
translation = m18n.g("file_not_exist", path=bad_file)
expected_msg = translation.format(path=bad_file)
assert expected_msg in str(exception)
def test_read_file_cannot_read_ioerror(test_file, mocker):
error = 'foobar'
error = "foobar"
with mocker.patch('__builtin__.open', side_effect=IOError(error)):
with mocker.patch("__builtin__.open", side_effect=IOError(error)):
with pytest.raises(MoulinetteError) as exception:
read_file(str(test_file))
translation = m18n.g('cannot_open_file', file=str(test_file), error=error)
translation = m18n.g("cannot_open_file", file=str(test_file), error=error)
expected_msg = translation.format(file=str(test_file), error=error)
assert expected_msg in str(exception)
def test_read_json(test_json):
content = read_json(str(test_json))
assert 'foo' in content.keys()
assert content['foo'] == 'bar'
assert "foo" in content.keys()
assert content["foo"] == "bar"
def test_read_json_cannot_read(test_json, mocker):
error = 'foobar'
error = "foobar"
with mocker.patch('json.loads', side_effect=ValueError(error)):
with mocker.patch("json.loads", side_effect=ValueError(error)):
with pytest.raises(MoulinetteError) as exception:
read_json(str(test_json))
translation = m18n.g('corrupted_json', ressource=str(test_json), error=error)
translation = m18n.g("corrupted_json", ressource=str(test_json), error=error)
expected_msg = translation.format(ressource=str(test_json), error=error)
assert expected_msg in str(exception)
def test_write_to_existing_file(test_file):
write_to_file(str(test_file), 'yolo\nswag')
assert read_file(str(test_file)) == 'yolo\nswag'
write_to_file(str(test_file), "yolo\nswag")
assert read_file(str(test_file)) == "yolo\nswag"
def test_write_to_new_file(tmp_path):
new_file = tmp_path / 'newfile.txt'
new_file = tmp_path / "newfile.txt"
write_to_file(str(new_file), 'yolo\nswag')
write_to_file(str(new_file), "yolo\nswag")
assert os.path.exists(str(new_file))
assert read_file(str(new_file)) == 'yolo\nswag'
assert read_file(str(new_file)) == "yolo\nswag"
def test_write_to_existing_file_bad_perms(test_file, mocker):
error = 'foobar'
error = "foobar"
with mocker.patch('__builtin__.open', side_effect=IOError(error)):
with mocker.patch("__builtin__.open", side_effect=IOError(error)):
with pytest.raises(MoulinetteError) as exception:
write_to_file(str(test_file), 'yolo\nswag')
write_to_file(str(test_file), "yolo\nswag")
translation = m18n.g('cannot_write_file', file=str(test_file), error=error)
translation = m18n.g("cannot_write_file", file=str(test_file), error=error)
expected_msg = translation.format(file=str(test_file), error=error)
assert expected_msg in str(exception)
def test_write_cannot_write_folder(tmp_path):
with pytest.raises(AssertionError):
write_to_file(str(tmp_path), 'yolo\nswag')
write_to_file(str(tmp_path), "yolo\nswag")
def test_write_cannot_write_to_non_existant_folder():
with pytest.raises(AssertionError):
write_to_file('/toto/test', 'yolo\nswag')
write_to_file("/toto/test", "yolo\nswag")
def test_write_to_file_with_a_list(test_file):
write_to_file(str(test_file), ['yolo', 'swag'])
assert read_file(str(test_file)) == 'yolo\nswag'
write_to_file(str(test_file), ["yolo", "swag"])
assert read_file(str(test_file)) == "yolo\nswag"
def test_append_to_existing_file(test_file):
append_to_file(str(test_file), 'yolo\nswag')
assert read_file(str(test_file)) == 'foo\nbar\nyolo\nswag'
append_to_file(str(test_file), "yolo\nswag")
assert read_file(str(test_file)) == "foo\nbar\nyolo\nswag"
def test_append_to_new_file(tmp_path):
new_file = tmp_path / 'newfile.txt'
new_file = tmp_path / "newfile.txt"
append_to_file(str(new_file), 'yolo\nswag')
append_to_file(str(new_file), "yolo\nswag")
assert os.path.exists(str(new_file))
assert read_file(str(new_file)) == 'yolo\nswag'
assert read_file(str(new_file)) == "yolo\nswag"
def text_write_dict_to_json(tmp_path):
new_file = tmp_path / 'newfile.json'
new_file = tmp_path / "newfile.json"
dummy_dict = {'foo': 42, 'bar': ['a', 'b', 'c']}
dummy_dict = {"foo": 42, "bar": ["a", "b", "c"]}
write_to_json(str(new_file), dummy_dict)
_json = read_json(str(new_file))
assert 'foo' in _json.keys()
assert 'bar' in _json.keys()
assert "foo" in _json.keys()
assert "bar" in _json.keys()
assert _json['foo'] == 42
assert _json['bar'] == ['a', 'b', 'c']
assert _json["foo"] == 42
assert _json["bar"] == ["a", "b", "c"]
def text_write_list_to_json(tmp_path):
new_file = tmp_path / 'newfile.json'
new_file = tmp_path / "newfile.json"
dummy_list = ['foo', 'bar', 'baz']
dummy_list = ["foo", "bar", "baz"]
write_to_json(str(new_file), dummy_list)
_json = read_json(str(new_file))
assert _json == ['foo', 'bar', 'baz']
assert _json == ["foo", "bar", "baz"]
def test_write_to_json_bad_perms(test_json, mocker):
error = 'foobar'
error = "foobar"
with mocker.patch('__builtin__.open', side_effect=IOError(error)):
with mocker.patch("__builtin__.open", side_effect=IOError(error)):
with pytest.raises(MoulinetteError) as exception:
write_to_json(str(test_json), {'a': 1})
write_to_json(str(test_json), {"a": 1})
translation = m18n.g('cannot_write_file', file=str(test_json), error=error)
translation = m18n.g("cannot_write_file", file=str(test_json), error=error)
expected_msg = translation.format(file=str(test_json), error=error)
assert expected_msg in str(exception)
def test_write_json_cannot_write_to_non_existant_folder():
with pytest.raises(AssertionError):
write_to_json('/toto/test.json', ['a', 'b'])
write_to_json("/toto/test.json", ["a", "b"])
def test_remove_file(test_file):
@ -157,13 +163,13 @@ def test_remove_file(test_file):
def test_remove_file_bad_perms(test_file, mocker):
error = 'foobar'
error = "foobar"
with mocker.patch('os.remove', side_effect=OSError(error)):
with mocker.patch("os.remove", side_effect=OSError(error)):
with pytest.raises(MoulinetteError) as exception:
rm(str(test_file))
translation = m18n.g('error_removing', path=str(test_file), error=error)
translation = m18n.g("error_removing", path=str(test_file), error=error)
expected_msg = translation.format(path=str(test_file), error=error)
assert expected_msg in str(exception)

View file

@ -8,19 +8,19 @@ from moulinette.utils.network import download_json, download_text
def test_download(test_url):
with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, text='some text')
mock.register_uri("GET", test_url, text="some text")
fetched_text = download_text(test_url)
assert fetched_text == 'some text'
assert fetched_text == "some text"
def test_download_bad_url():
with pytest.raises(MoulinetteError):
download_text('Nowhere')
download_text("Nowhere")
def test_download_404(test_url):
with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, status_code=404)
mock.register_uri("GET", test_url, status_code=404)
with pytest.raises(MoulinetteError):
download_text(test_url)
@ -28,7 +28,7 @@ def test_download_404(test_url):
def test_download_ssl_error(test_url):
with requests_mock.Mocker() as mock:
exception = requests.exceptions.SSLError
mock.register_uri('GET', test_url, exc=exception)
mock.register_uri("GET", test_url, exc=exception)
with pytest.raises(MoulinetteError):
download_text(test_url)
@ -36,21 +36,21 @@ def test_download_ssl_error(test_url):
def test_download_timeout(test_url):
with requests_mock.Mocker() as mock:
exception = requests.exceptions.ConnectTimeout
mock.register_uri('GET', test_url, exc=exception)
mock.register_uri("GET", test_url, exc=exception)
with pytest.raises(MoulinetteError):
download_text(test_url)
def test_download_json(test_url):
with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, text='{"foo":"bar"}')
mock.register_uri("GET", test_url, text='{"foo":"bar"}')
fetched_json = download_json(test_url)
assert 'foo' in fetched_json.keys()
assert fetched_json['foo'] == 'bar'
assert "foo" in fetched_json.keys()
assert fetched_json["foo"] == "bar"
def test_download_json_bad_json(test_url):
with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, text='notjsonlol')
mock.register_uri("GET", test_url, text="notjsonlol")
with pytest.raises(MoulinetteError):
download_json(test_url)

View file

@ -8,10 +8,10 @@ from moulinette.utils.process import run_commands
def test_run_shell_command_list(test_file):
assert os.path.exists(str(test_file))
run_commands(['rm -f %s' % str(test_file)])
run_commands(["rm -f %s" % str(test_file)])
assert not os.path.exists(str(test_file))
def test_run_shell_bad_cmd():
with pytest.raises(CalledProcessError):
run_commands(['yolo swag'])
run_commands(["yolo swag"])

View file

@ -7,8 +7,8 @@ def test_json_extended_encoder(caplog):
assert encoder.default(set([1, 2, 3])) == [1, 2, 3]
assert encoder.default(dt(1917, 3, 8)) == '1917-03-08T00:00:00+00:00'
assert encoder.default(dt(1917, 3, 8)) == "1917-03-08T00:00:00+00:00"
assert encoder.default(None) == 'None'
assert encoder.default(None) == "None"
for message in caplog.messages:
assert 'cannot properly encode in JSON' in message
assert "cannot properly encode in JSON" in message

View file

@ -2,19 +2,19 @@ from moulinette.utils.text import search, searchf, prependlines, random_ascii
def test_search():
assert search('a', 'a a a') == ['a', 'a', 'a']
assert search('a', 'a a a', count=2) == ['a', 'a']
assert not search('a', 'c c d')
assert search("a", "a a a") == ["a", "a", "a"]
assert search("a", "a a a", count=2) == ["a", "a"]
assert not search("a", "c c d")
def test_searchf(test_file):
assert searchf('bar', str(test_file)) == ['bar']
assert not searchf('baz', str(test_file))
assert searchf("bar", str(test_file)) == ["bar"]
assert not searchf("baz", str(test_file))
def test_prependlines():
assert prependlines('abc\nedf\nghi', 'XXX') == 'XXXabc\nXXXedf\nXXXghi'
assert prependlines('', 'XXX') == 'XXX'
assert prependlines("abc\nedf\nghi", "XXX") == "XXXabc\nXXXedf\nXXXghi"
assert prependlines("", "XXX") == "XXX"
def test_random_ascii():