Apply Black on all the code base

Authored by decentral1se on 2019-11-25 17:21:13 +01:00, committed by Alexandre Aubin
parent 6f5daa0c38
commit 54b8cab133
27 changed files with 1171 additions and 875 deletions

(first changed file)

@@ -1,16 +1,17 @@
 # -*- coding: utf-8 -*-
-from moulinette.core import init_interface, MoulinetteError, MoulinetteSignals, Moulinette18n
+from moulinette.core import (
+    init_interface,
+    MoulinetteError,
+    MoulinetteSignals,
+    Moulinette18n,
+)
 from moulinette.globals import init_moulinette_env
-__title__ = 'moulinette'
-__version__ = '0.1'
-__author__ = ['Kload',
-              'jlebleu',
-              'titoko',
-              'beudbeud',
-              'npze']
-__license__ = 'AGPL 3.0'
+__title__ = "moulinette"
+__version__ = "0.1"
+__author__ = ["Kload", "jlebleu", "titoko", "beudbeud", "npze"]
+__license__ = "AGPL 3.0"
 __credits__ = """
     Copyright (C) 2014 YUNOHOST.ORG

@@ -28,8 +29,13 @@ __credits__ = """
     along with this program; if not, see http://www.gnu.org/licenses
 """
 __all__ = [
-    'init', 'api', 'cli', 'm18n', 'env',
-    'init_interface', 'MoulinetteError',
+    "init",
+    "api",
+    "cli",
+    "m18n",
+    "env",
+    "init_interface",
+    "MoulinetteError",
 ]

@@ -40,6 +46,7 @@ m18n = Moulinette18n()
 # Package functions
 def init(logging_config=None, **kwargs):
     """Package initialization

@@ -61,13 +68,15 @@ def init(logging_config=None, **kwargs):
     configure_logging(logging_config)
     # Add library directory to python path
-    sys.path.insert(0, init_moulinette_env()['LIB_DIR'])
+    sys.path.insert(0, init_moulinette_env()["LIB_DIR"])
 # Easy access to interfaces
-def api(namespaces, host='localhost', port=80, routes={},
-        use_websocket=True, use_cache=True):
+def api(
+    namespaces, host="localhost", port=80, routes={}, use_websocket=True, use_cache=True
+):
     """Web server (API) interface
     Run a HTTP server with the moulinette for an API usage.

@@ -84,29 +93,33 @@ def api(namespaces, host='localhost', port=80, routes={},
     """
     try:
-        moulinette = init_interface('api',
-                                    kwargs={
-                                        'routes': routes,
-                                        'use_websocket': use_websocket
-                                    },
-                                    actionsmap={
-                                        'namespaces': namespaces,
-                                        'use_cache': use_cache
-                                    }
-                                    )
+        moulinette = init_interface(
+            "api",
+            kwargs={"routes": routes, "use_websocket": use_websocket},
+            actionsmap={"namespaces": namespaces, "use_cache": use_cache},
+        )
         moulinette.run(host, port)
     except MoulinetteError as e:
         import logging
         logging.getLogger(namespaces[0]).error(e.strerror)
         return e.errno if hasattr(e, "errno") else 1
     except KeyboardInterrupt:
         import logging
-        logging.getLogger(namespaces[0]).info(m18n.g('operation_interrupted'))
+        logging.getLogger(namespaces[0]).info(m18n.g("operation_interrupted"))
         return 0
-def cli(namespaces, args, use_cache=True, output_as=None,
-        password=None, timeout=None, parser_kwargs={}):
+def cli(
+    namespaces,
+    args,
+    use_cache=True,
+    output_as=None,
+    password=None,
+    timeout=None,
+    parser_kwargs={},
+):
     """Command line interface
     Execute an action with the moulinette from the CLI and print its

@@ -125,16 +138,18 @@ def cli(namespaces, args, use_cache=True, output_as=None,
     """
     try:
-        moulinette = init_interface('cli',
-                                    actionsmap={
-                                        'namespaces': namespaces,
-                                        'use_cache': use_cache,
-                                        'parser_kwargs': parser_kwargs,
-                                    },
-                                    )
+        moulinette = init_interface(
+            "cli",
+            actionsmap={
+                "namespaces": namespaces,
+                "use_cache": use_cache,
+                "parser_kwargs": parser_kwargs,
+            },
+        )
         moulinette.run(args, output_as=output_as, password=password, timeout=timeout)
     except MoulinetteError as e:
         import logging
         logging.getLogger(namespaces[0]).error(e.strerror)
         return 1
     return 0

(next changed file)

@ -12,19 +12,18 @@ from importlib import import_module
from moulinette import m18n, msignals from moulinette import m18n, msignals
from moulinette.cache import open_cachefile from moulinette.cache import open_cachefile
from moulinette.globals import init_moulinette_env from moulinette.globals import init_moulinette_env
from moulinette.core import (MoulinetteError, MoulinetteLock) from moulinette.core import MoulinetteError, MoulinetteLock
from moulinette.interfaces import ( from moulinette.interfaces import BaseActionsMapParser, GLOBAL_SECTION, TO_RETURN_PROP
BaseActionsMapParser, GLOBAL_SECTION, TO_RETURN_PROP
)
from moulinette.utils.log import start_action_logging from moulinette.utils.log import start_action_logging
logger = logging.getLogger('moulinette.actionsmap') logger = logging.getLogger("moulinette.actionsmap")
# Extra parameters ---------------------------------------------------- # Extra parameters ----------------------------------------------------
# Extra parameters definition # Extra parameters definition
class _ExtraParameter(object): class _ExtraParameter(object):
""" """
@ -87,7 +86,7 @@ class _ExtraParameter(object):
class CommentParameter(_ExtraParameter): class CommentParameter(_ExtraParameter):
name = "comment" name = "comment"
skipped_iface = ['api'] skipped_iface = ["api"]
def __call__(self, message, arg_name, arg_value): def __call__(self, message, arg_name, arg_value):
return msignals.display(m18n.n(message)) return msignals.display(m18n.n(message))
@ -96,12 +95,15 @@ class CommentParameter(_ExtraParameter):
def validate(klass, value, arg_name): def validate(klass, value, arg_name):
# Deprecated boolean or empty string # Deprecated boolean or empty string
if isinstance(value, bool) or (isinstance(value, str) and not value): if isinstance(value, bool) or (isinstance(value, str) and not value):
logger.warning("expecting a non-empty string for extra parameter '%s' of " logger.warning(
"argument '%s'", klass.name, arg_name) "expecting a non-empty string for extra parameter '%s' of "
"argument '%s'",
klass.name,
arg_name,
)
value = arg_name value = arg_name
elif not isinstance(value, str): elif not isinstance(value, str):
raise TypeError("parameter value must be a string, got %r" raise TypeError("parameter value must be a string, got %r" % value)
% value)
return value return value
@ -114,8 +116,9 @@ class AskParameter(_ExtraParameter):
when asking the argument value. when asking the argument value.
""" """
name = 'ask'
skipped_iface = ['api'] name = "ask"
skipped_iface = ["api"]
def __call__(self, message, arg_name, arg_value): def __call__(self, message, arg_name, arg_value):
if arg_value: if arg_value:
@ -131,12 +134,15 @@ class AskParameter(_ExtraParameter):
def validate(klass, value, arg_name): def validate(klass, value, arg_name):
# Deprecated boolean or empty string # Deprecated boolean or empty string
if isinstance(value, bool) or (isinstance(value, str) and not value): if isinstance(value, bool) or (isinstance(value, str) and not value):
logger.warning("expecting a non-empty string for extra parameter '%s' of " logger.warning(
"argument '%s'", klass.name, arg_name) "expecting a non-empty string for extra parameter '%s' of "
"argument '%s'",
klass.name,
arg_name,
)
value = arg_name value = arg_name
elif not isinstance(value, str): elif not isinstance(value, str):
raise TypeError("parameter value must be a string, got %r" raise TypeError("parameter value must be a string, got %r" % value)
% value)
return value return value
@ -149,7 +155,8 @@ class PasswordParameter(AskParameter):
when asking the password. when asking the password.
""" """
name = 'password'
name = "password"
def __call__(self, message, arg_name, arg_value): def __call__(self, message, arg_name, arg_value):
if arg_value: if arg_value:
@ -171,40 +178,45 @@ class PatternParameter(_ExtraParameter):
the message to display if it doesn't match. the message to display if it doesn't match.
""" """
name = 'pattern'
name = "pattern"
def __call__(self, arguments, arg_name, arg_value): def __call__(self, arguments, arg_name, arg_value):
pattern, message = (arguments[0], arguments[1]) pattern, message = (arguments[0], arguments[1])
# Use temporarly utf-8 encoded value # Use temporarly utf-8 encoded value
try: try:
v = unicode(arg_value, 'utf-8') v = unicode(arg_value, "utf-8")
except: except:
v = arg_value v = arg_value
if v and not re.match(pattern, v or '', re.UNICODE): if v and not re.match(pattern, v or "", re.UNICODE):
logger.debug("argument value '%s' for '%s' doesn't match pattern '%s'", logger.debug(
v, arg_name, pattern) "argument value '%s' for '%s' doesn't match pattern '%s'",
v,
arg_name,
pattern,
)
# Attempt to retrieve message translation # Attempt to retrieve message translation
msg = m18n.n(message) msg = m18n.n(message)
if msg == message: if msg == message:
msg = m18n.g(message) msg = m18n.g(message)
raise MoulinetteError('invalid_argument', raise MoulinetteError("invalid_argument", argument=arg_name, error=msg)
argument=arg_name, error=msg)
return arg_value return arg_value
@staticmethod @staticmethod
def validate(value, arg_name): def validate(value, arg_name):
# Deprecated string type # Deprecated string type
if isinstance(value, str): if isinstance(value, str):
logger.warning("expecting a list as extra parameter 'pattern' of " logger.warning(
"argument '%s'", arg_name) "expecting a list as extra parameter 'pattern' of " "argument '%s'",
value = [value, 'pattern_not_match'] arg_name,
)
value = [value, "pattern_not_match"]
elif not isinstance(value, list) or len(value) != 2: elif not isinstance(value, list) or len(value) != 2:
raise TypeError("parameter value must be a list, got %r" raise TypeError("parameter value must be a list, got %r" % value)
% value)
return value return value
@ -216,21 +228,19 @@ class RequiredParameter(_ExtraParameter):
The value of this parameter must be a boolean which is set to False by The value of this parameter must be a boolean which is set to False by
default. default.
""" """
name = 'required'
name = "required"
def __call__(self, required, arg_name, arg_value): def __call__(self, required, arg_name, arg_value):
if required and (arg_value is None or arg_value == ''): if required and (arg_value is None or arg_value == ""):
logger.debug("argument '%s' is required", logger.debug("argument '%s' is required", arg_name)
arg_name) raise MoulinetteError("argument_required", argument=arg_name)
raise MoulinetteError('argument_required',
argument=arg_name)
return arg_value return arg_value
@staticmethod @staticmethod
def validate(value, arg_name): def validate(value, arg_name):
if not isinstance(value, bool): if not isinstance(value, bool):
raise TypeError("parameter value must be a list, got %r" raise TypeError("parameter value must be a list, got %r" % value)
% value)
return value return value
@ -239,8 +249,13 @@ The list of available extra parameters classes. It will keep to this list
order on argument parsing. order on argument parsing.
""" """
extraparameters_list = [CommentParameter, AskParameter, PasswordParameter, extraparameters_list = [
RequiredParameter, PatternParameter] CommentParameter,
AskParameter,
PasswordParameter,
RequiredParameter,
PatternParameter,
]
# Extra parameters argument Parser # Extra parameters argument Parser
@ -265,7 +280,7 @@ class ExtraArgumentParser(object):
if iface in klass.skipped_iface: if iface in klass.skipped_iface:
continue continue
self.extra[klass.name] = klass self.extra[klass.name] = klass
logger.debug('extra parameter classes loaded: %s', self.extra.keys()) logger.debug("extra parameter classes loaded: %s", self.extra.keys())
def validate(self, arg_name, parameters): def validate(self, arg_name, parameters):
""" """
@ -287,9 +302,14 @@ class ExtraArgumentParser(object):
# Validate parameter value # Validate parameter value
parameters[p] = klass.validate(v, arg_name) parameters[p] = klass.validate(v, arg_name)
except Exception as e: except Exception as e:
logger.error("unable to validate extra parameter '%s' " logger.error(
"for argument '%s': %s", p, arg_name, e) "unable to validate extra parameter '%s' "
raise MoulinetteError('error_see_log') "for argument '%s': %s",
p,
arg_name,
e,
)
raise MoulinetteError("error_see_log")
return parameters return parameters
@ -354,12 +374,15 @@ class ExtraArgumentParser(object):
# Main class ---------------------------------------------------------- # Main class ----------------------------------------------------------
def ordered_yaml_load(stream): def ordered_yaml_load(stream):
class OrderedLoader(yaml.Loader): class OrderedLoader(yaml.Loader):
pass pass
OrderedLoader.add_constructor( OrderedLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
lambda loader, node: OrderedDict(loader.construct_pairs(node))) lambda loader, node: OrderedDict(loader.construct_pairs(node)),
)
return yaml.load(stream, OrderedLoader) return yaml.load(stream, OrderedLoader)
@ -387,16 +410,15 @@ class ActionsMap(object):
""" """
def __init__(self, parser_class, namespaces=[], use_cache=True, def __init__(self, parser_class, namespaces=[], use_cache=True, parser_kwargs={}):
parser_kwargs={}):
if not issubclass(parser_class, BaseActionsMapParser): if not issubclass(parser_class, BaseActionsMapParser):
raise ValueError("Invalid parser class '%s'" % parser_class.__name__) raise ValueError("Invalid parser class '%s'" % parser_class.__name__)
self.parser_class = parser_class self.parser_class = parser_class
self.use_cache = use_cache self.use_cache = use_cache
moulinette_env = init_moulinette_env() moulinette_env = init_moulinette_env()
DATA_DIR = moulinette_env['DATA_DIR'] DATA_DIR = moulinette_env["DATA_DIR"]
CACHE_DIR = moulinette_env['CACHE_DIR'] CACHE_DIR = moulinette_env["CACHE_DIR"]
if len(namespaces) == 0: if len(namespaces) == 0:
namespaces = self.get_namespaces() namespaces = self.get_namespaces()
@ -406,13 +428,13 @@ class ActionsMap(object):
for n in namespaces: for n in namespaces:
logger.debug("loading actions map namespace '%s'", n) logger.debug("loading actions map namespace '%s'", n)
actionsmap_yml = '%s/actionsmap/%s.yml' % (DATA_DIR, n) actionsmap_yml = "%s/actionsmap/%s.yml" % (DATA_DIR, n)
actionsmap_yml_stat = os.stat(actionsmap_yml) actionsmap_yml_stat = os.stat(actionsmap_yml)
actionsmap_pkl = '%s/actionsmap/%s-%d-%d.pkl' % ( actionsmap_pkl = "%s/actionsmap/%s-%d-%d.pkl" % (
CACHE_DIR, CACHE_DIR,
n, n,
actionsmap_yml_stat.st_size, actionsmap_yml_stat.st_size,
actionsmap_yml_stat.st_mtime actionsmap_yml_stat.st_mtime,
) )
if use_cache and os.path.exists(actionsmap_pkl): if use_cache and os.path.exists(actionsmap_pkl):
@ -447,16 +469,18 @@ class ActionsMap(object):
# Fetch the configuration for the authenticator module as defined in the actionmap # Fetch the configuration for the authenticator module as defined in the actionmap
try: try:
auth_conf = self.parser.global_conf['authenticator'][auth_profile] auth_conf = self.parser.global_conf["authenticator"][auth_profile]
except KeyError: except KeyError:
raise ValueError("Unknown authenticator profile '%s'" % auth_profile) raise ValueError("Unknown authenticator profile '%s'" % auth_profile)
# Load and initialize the authenticator module # Load and initialize the authenticator module
try: try:
mod = import_module('moulinette.authenticators.%s' % auth_conf["vendor"]) mod = import_module("moulinette.authenticators.%s" % auth_conf["vendor"])
except ImportError: except ImportError:
logger.exception("unable to load authenticator vendor '%s'", auth_conf["vendor"]) logger.exception(
raise MoulinetteError('error_see_log') "unable to load authenticator vendor '%s'", auth_conf["vendor"]
)
raise MoulinetteError("error_see_log")
else: else:
return mod.Authenticator(**auth_conf) return mod.Authenticator(**auth_conf)
@ -471,7 +495,7 @@ class ActionsMap(object):
auth = msignals.authenticate(authenticator) auth = msignals.authenticate(authenticator)
if not auth.is_authenticated: if not auth.is_authenticated:
raise MoulinetteError('authentication_required_long') raise MoulinetteError("authentication_required_long")
def process(self, args, timeout=None, **kwargs): def process(self, args, timeout=None, **kwargs):
""" """
@ -492,7 +516,7 @@ class ActionsMap(object):
arguments = vars(self.parser.parse_args(args, **kwargs)) arguments = vars(self.parser.parse_args(args, **kwargs))
# Retrieve tid and parse arguments with extra parameters # Retrieve tid and parse arguments with extra parameters
tid = arguments.pop('_tid') tid = arguments.pop("_tid")
arguments = self.extraparser.parse_args(tid, arguments) arguments = self.extraparser.parse_args(tid, arguments)
# Return immediately if a value is defined # Return immediately if a value is defined
@ -502,38 +526,55 @@ class ActionsMap(object):
# Retrieve action information # Retrieve action information
if len(tid) == 4: if len(tid) == 4:
namespace, category, subcategory, action = tid namespace, category, subcategory, action = tid
func_name = '%s_%s_%s' % (category, subcategory.replace('-', '_'), action.replace('-', '_')) func_name = "%s_%s_%s" % (
full_action_name = "%s.%s.%s.%s" % (namespace, category, subcategory, action) category,
subcategory.replace("-", "_"),
action.replace("-", "_"),
)
full_action_name = "%s.%s.%s.%s" % (
namespace,
category,
subcategory,
action,
)
else: else:
assert len(tid) == 3 assert len(tid) == 3
namespace, category, action = tid namespace, category, action = tid
subcategory = None subcategory = None
func_name = '%s_%s' % (category, action.replace('-', '_')) func_name = "%s_%s" % (category, action.replace("-", "_"))
full_action_name = "%s.%s.%s" % (namespace, category, action) full_action_name = "%s.%s.%s" % (namespace, category, action)
# Lock the moulinette for the namespace # Lock the moulinette for the namespace
with MoulinetteLock(namespace, timeout): with MoulinetteLock(namespace, timeout):
start = time() start = time()
try: try:
mod = __import__('%s.%s' % (namespace, category), mod = __import__(
globals=globals(), level=0, "%s.%s" % (namespace, category),
fromlist=[func_name]) globals=globals(),
logger.debug('loading python module %s took %.3fs', level=0,
'%s.%s' % (namespace, category), time() - start) fromlist=[func_name],
)
logger.debug(
"loading python module %s took %.3fs",
"%s.%s" % (namespace, category),
time() - start,
)
func = getattr(mod, func_name) func = getattr(mod, func_name)
except (AttributeError, ImportError): except (AttributeError, ImportError):
logger.exception("unable to load function %s.%s", logger.exception("unable to load function %s.%s", namespace, func_name)
namespace, func_name) raise MoulinetteError("error_see_log")
raise MoulinetteError('error_see_log')
else: else:
log_id = start_action_logging() log_id = start_action_logging()
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
# Log arguments in debug mode only for safety reasons # Log arguments in debug mode only for safety reasons
logger.info('processing action [%s]: %s with args=%s', logger.info(
log_id, full_action_name, arguments) "processing action [%s]: %s with args=%s",
log_id,
full_action_name,
arguments,
)
else: else:
logger.info('processing action [%s]: %s', logger.info("processing action [%s]: %s", log_id, full_action_name)
log_id, full_action_name)
# Load translation and process the action # Load translation and process the action
m18n.load_namespace(namespace) m18n.load_namespace(namespace)
@ -542,8 +583,7 @@ class ActionsMap(object):
return func(**arguments) return func(**arguments)
finally: finally:
stop = time() stop = time()
logger.debug('action [%s] executed in %.3fs', logger.debug("action [%s] executed in %.3fs", log_id, stop - start)
log_id, stop - start)
@staticmethod @staticmethod
def get_namespaces(): def get_namespaces():
@ -557,10 +597,10 @@ class ActionsMap(object):
namespaces = [] namespaces = []
moulinette_env = init_moulinette_env() moulinette_env = init_moulinette_env()
DATA_DIR = moulinette_env['DATA_DIR'] DATA_DIR = moulinette_env["DATA_DIR"]
for f in os.listdir('%s/actionsmap' % DATA_DIR): for f in os.listdir("%s/actionsmap" % DATA_DIR):
if f.endswith('.yml'): if f.endswith(".yml"):
namespaces.append(f[:-4]) namespaces.append(f[:-4])
return namespaces return namespaces
@ -577,8 +617,8 @@ class ActionsMap(object):
""" """
moulinette_env = init_moulinette_env() moulinette_env = init_moulinette_env()
CACHE_DIR = moulinette_env['CACHE_DIR'] CACHE_DIR = moulinette_env["CACHE_DIR"]
DATA_DIR = moulinette_env['DATA_DIR'] DATA_DIR = moulinette_env["DATA_DIR"]
actionsmaps = {} actionsmaps = {}
if not namespaces: if not namespaces:
@ -589,23 +629,23 @@ class ActionsMap(object):
logger.debug("generating cache for actions map namespace '%s'", n) logger.debug("generating cache for actions map namespace '%s'", n)
# Read actions map from yaml file # Read actions map from yaml file
am_file = '%s/actionsmap/%s.yml' % (DATA_DIR, n) am_file = "%s/actionsmap/%s.yml" % (DATA_DIR, n)
with open(am_file, 'r') as f: with open(am_file, "r") as f:
actionsmaps[n] = ordered_yaml_load(f) actionsmaps[n] = ordered_yaml_load(f)
# at installation, cachedir might not exists # at installation, cachedir might not exists
if os.path.exists('%s/actionsmap/' % CACHE_DIR): if os.path.exists("%s/actionsmap/" % CACHE_DIR):
# clean old cached files # clean old cached files
for i in os.listdir('%s/actionsmap/' % CACHE_DIR): for i in os.listdir("%s/actionsmap/" % CACHE_DIR):
if i.endswith(".pkl"): if i.endswith(".pkl"):
os.remove('%s/actionsmap/%s' % (CACHE_DIR, i)) os.remove("%s/actionsmap/%s" % (CACHE_DIR, i))
# Cache actions map into pickle file # Cache actions map into pickle file
am_file_stat = os.stat(am_file) am_file_stat = os.stat(am_file)
pkl = '%s-%d-%d.pkl' % (n, am_file_stat.st_size, am_file_stat.st_mtime) pkl = "%s-%d-%d.pkl" % (n, am_file_stat.st_size, am_file_stat.st_mtime)
with open_cachefile(pkl, 'w', subdir='actionsmap') as f: with open_cachefile(pkl, "w", subdir="actionsmap") as f:
pickle.dump(actionsmaps[n], f) pickle.dump(actionsmaps[n], f)
return actionsmaps return actionsmaps
@ -646,86 +686,97 @@ class ActionsMap(object):
# * actionsmap is the actual actionsmap that we care about # * actionsmap is the actual actionsmap that we care about
for namespace, actionsmap in actionsmaps.items(): for namespace, actionsmap in actionsmaps.items():
# Retrieve global parameters # Retrieve global parameters
_global = actionsmap.pop('_global', {}) _global = actionsmap.pop("_global", {})
# Set the global configuration to use for the parser. # Set the global configuration to use for the parser.
top_parser.set_global_conf(_global['configuration']) top_parser.set_global_conf(_global["configuration"])
if top_parser.has_global_parser(): if top_parser.has_global_parser():
top_parser.add_global_arguments(_global['arguments']) top_parser.add_global_arguments(_global["arguments"])
# category_name is stuff like "user", "domain", "hooks"... # category_name is stuff like "user", "domain", "hooks"...
# category_values is the values of this category (like actions) # category_values is the values of this category (like actions)
for category_name, category_values in actionsmap.items(): for category_name, category_values in actionsmap.items():
if "actions" in category_values: if "actions" in category_values:
actions = category_values.pop('actions') actions = category_values.pop("actions")
else: else:
actions = {} actions = {}
if "subcategories" in category_values: if "subcategories" in category_values:
subcategories = category_values.pop('subcategories') subcategories = category_values.pop("subcategories")
else: else:
subcategories = {} subcategories = {}
# Get category parser # Get category parser
category_parser = top_parser.add_category_parser(category_name, category_parser = top_parser.add_category_parser(
**category_values) category_name, **category_values
)
# action_name is like "list" of "domain list" # action_name is like "list" of "domain list"
# action_options are the values # action_options are the values
for action_name, action_options in actions.items(): for action_name, action_options in actions.items():
arguments = action_options.pop('arguments', {}) arguments = action_options.pop("arguments", {})
tid = (namespace, category_name, action_name) tid = (namespace, category_name, action_name)
# Get action parser # Get action parser
action_parser = category_parser.add_action_parser(action_name, action_parser = category_parser.add_action_parser(
tid, action_name, tid, **action_options
**action_options) )
if action_parser is None: # No parser for the action if action_parser is None: # No parser for the action
continue continue
# Store action identifier and add arguments # Store action identifier and add arguments
action_parser.set_defaults(_tid=tid) action_parser.set_defaults(_tid=tid)
action_parser.add_arguments(arguments, action_parser.add_arguments(
extraparser=self.extraparser, arguments,
format_arg_names=top_parser.format_arg_names, extraparser=self.extraparser,
validate_extra=validate_extra) format_arg_names=top_parser.format_arg_names,
validate_extra=validate_extra,
)
if 'configuration' in action_options: if "configuration" in action_options:
category_parser.set_conf(tid, action_options['configuration']) category_parser.set_conf(tid, action_options["configuration"])
# subcategory_name is like "cert" in "domain cert status" # subcategory_name is like "cert" in "domain cert status"
# subcategory_values is the values of this subcategory (like actions) # subcategory_values is the values of this subcategory (like actions)
for subcategory_name, subcategory_values in subcategories.items(): for subcategory_name, subcategory_values in subcategories.items():
actions = subcategory_values.pop('actions') actions = subcategory_values.pop("actions")
# Get subcategory parser # Get subcategory parser
subcategory_parser = category_parser.add_subcategory_parser(subcategory_name, **subcategory_values) subcategory_parser = category_parser.add_subcategory_parser(
subcategory_name, **subcategory_values
)
# action_name is like "status" of "domain cert status" # action_name is like "status" of "domain cert status"
# action_options are the values # action_options are the values
for action_name, action_options in actions.items(): for action_name, action_options in actions.items():
arguments = action_options.pop('arguments', {}) arguments = action_options.pop("arguments", {})
tid = (namespace, category_name, subcategory_name, action_name) tid = (namespace, category_name, subcategory_name, action_name)
try: try:
# Get action parser # Get action parser
action_parser = subcategory_parser.add_action_parser(action_name, tid, **action_options) action_parser = subcategory_parser.add_action_parser(
action_name, tid, **action_options
)
except AttributeError: except AttributeError:
# No parser for the action # No parser for the action
continue continue
# Store action identifier and add arguments # Store action identifier and add arguments
action_parser.set_defaults(_tid=tid) action_parser.set_defaults(_tid=tid)
action_parser.add_arguments(arguments, action_parser.add_arguments(
extraparser=self.extraparser, arguments,
format_arg_names=top_parser.format_arg_names, extraparser=self.extraparser,
validate_extra=validate_extra) format_arg_names=top_parser.format_arg_names,
validate_extra=validate_extra,
)
if 'configuration' in action_options: if "configuration" in action_options:
category_parser.set_conf(tid, action_options['configuration']) category_parser.set_conf(
tid, action_options["configuration"]
)
return top_parser return top_parser

(next changed file)

@@ -8,11 +8,12 @@ import hmac
 from moulinette.cache import open_cachefile, get_cachedir
 from moulinette.core import MoulinetteError
-logger = logging.getLogger('moulinette.authenticator')
+logger = logging.getLogger("moulinette.authenticator")
 # Base Class -----------------------------------------------------------
 class BaseAuthenticator(object):
     """Authenticator base representation

@@ -59,8 +60,9 @@ class BaseAuthenticator(object):
             - password -- A clear text password
         """
-        raise NotImplementedError("derived class '%s' must override this method" %
-                                  self.__class__.__name__)
+        raise NotImplementedError(
+            "derived class '%s' must override this method" % self.__class__.__name__
+        )
     # Authentication methods

@@ -95,9 +97,13 @@ class BaseAuthenticator(object):
             except MoulinetteError:
                 raise
             except Exception as e:
-                logger.exception("authentication (name: '%s', vendor: '%s') fails because '%s'",
-                                 self.name, self.vendor, e)
-                raise MoulinetteError('unable_authenticate')
+                logger.exception(
+                    "authentication (name: '%s', vendor: '%s') fails because '%s'",
+                    self.name,
+                    self.vendor,
+                    e,
+                )
+                raise MoulinetteError("unable_authenticate")
             self.is_authenticated = True

@@ -108,6 +114,7 @@ class BaseAuthenticator(object):
                 self._store_session(s_id, s_token)
             except Exception as e:
                 import traceback
                 traceback.print_exc()
                 logger.exception("unable to store session because %s", e)
             else:

@@ -124,9 +131,13 @@ class BaseAuthenticator(object):
             except MoulinetteError as e:
                 raise
             except Exception as e:
-                logger.exception("authentication (name: '%s', vendor: '%s') fails because '%s'",
-                                 self.name, self.vendor, e)
-                raise MoulinetteError('unable_authenticate')
+                logger.exception(
+                    "authentication (name: '%s', vendor: '%s') fails because '%s'",
+                    self.name,
+                    self.vendor,
+                    e,
+                )
+                raise MoulinetteError("unable_authenticate")
             else:
                 self.is_authenticated = True

@@ -134,16 +145,17 @@ class BaseAuthenticator(object):
         # No credentials given, can't authenticate
         #
         else:
-            raise MoulinetteError('unable_authenticate')
+            raise MoulinetteError("unable_authenticate")
         return self
     # Private methods
-    def _open_sessionfile(self, session_id, mode='r'):
+    def _open_sessionfile(self, session_id, mode="r"):
         """Open a session file for this instance in given mode"""
-        return open_cachefile('%s.asc' % session_id, mode,
-                              subdir='session/%s' % self.name)
+        return open_cachefile(
+            "%s.asc" % session_id, mode, subdir="session/%s" % self.name
+        )
     def _store_session(self, session_id, session_token):
         """Store a session to be able to use it later to reauthenticate"""

@@ -151,7 +163,7 @@ class BaseAuthenticator(object):
         # We store a hash of the session_id and the session_token (the token is assumed to be secret)
         to_hash = "{id}:{token}".format(id=session_id, token=session_token)
         hash_ = hashlib.sha256(to_hash).hexdigest()
-        with self._open_sessionfile(session_id, 'w') as f:
+        with self._open_sessionfile(session_id, "w") as f:
             f.write(hash_)
     def _authenticate_session(self, session_id, session_token):

@@ -160,11 +172,11 @@ class BaseAuthenticator(object):
             # FIXME : shouldn't we also add a check that this session file
             # is not too old ? e.g. not older than 24 hours ? idk...
-            with self._open_sessionfile(session_id, 'r') as f:
+            with self._open_sessionfile(session_id, "r") as f:
                 stored_hash = f.read()
         except IOError as e:
             logger.debug("unable to retrieve session", exc_info=1)
-            raise MoulinetteError('unable_retrieve_session', exception=e)
+            raise MoulinetteError("unable_retrieve_session", exception=e)
         else:
             #
             # session_id (or just id) : This is unique id for the current session from the user. Not too important

@@ -186,7 +198,7 @@ class BaseAuthenticator(object):
             hash_ = hashlib.sha256(to_hash).hexdigest()
             if not hmac.compare_digest(hash_, stored_hash):
-                raise MoulinetteError('invalid_token')
+                raise MoulinetteError("invalid_token")
             else:
                 return

@@ -198,9 +210,9 @@ class BaseAuthenticator(object):
         Keyword arguments:
             - session_id -- The session id to clean
         """
-        sessiondir = get_cachedir('session')
+        sessiondir = get_cachedir("session")
         try:
-            os.remove(os.path.join(sessiondir, self.name, '%s.asc' % session_id))
+            os.remove(os.path.join(sessiondir, self.name, "%s.asc" % session_id))
         except OSError:
             pass

(next changed file)

@@ -4,7 +4,7 @@ import logging
 from moulinette.core import MoulinetteError
 from moulinette.authenticators import BaseAuthenticator
-logger = logging.getLogger('moulinette.authenticator.dummy')
+logger = logging.getLogger("moulinette.authenticator.dummy")
 # Dummy authenticator implementation

@@ -14,7 +14,7 @@ class Authenticator(BaseAuthenticator):
     """Dummy authenticator used for tests
     """
-    vendor = 'dummy'
+    vendor = "dummy"
     def __init__(self, name, vendor, parameters, extra):
         logger.debug("initialize authenticator dummy")

(next changed file)

@ -13,11 +13,12 @@ import ldap.modlist as modlist
from moulinette.core import MoulinetteError from moulinette.core import MoulinetteError
from moulinette.authenticators import BaseAuthenticator from moulinette.authenticators import BaseAuthenticator
logger = logging.getLogger('moulinette.authenticator.ldap') logger = logging.getLogger("moulinette.authenticator.ldap")
# LDAP Class Implementation -------------------------------------------- # LDAP Class Implementation --------------------------------------------
class Authenticator(BaseAuthenticator): class Authenticator(BaseAuthenticator):
"""LDAP Authenticator """LDAP Authenticator
@ -38,12 +39,18 @@ class Authenticator(BaseAuthenticator):
self.basedn = parameters["base_dn"] self.basedn = parameters["base_dn"]
self.userdn = parameters["user_rdn"] self.userdn = parameters["user_rdn"]
self.extra = extra self.extra = extra
logger.debug("initialize authenticator '%s' with: uri='%s', " logger.debug(
"base_dn='%s', user_rdn='%s'", name, self.uri, self.basedn, self.userdn) "initialize authenticator '%s' with: uri='%s', "
"base_dn='%s', user_rdn='%s'",
name,
self.uri,
self.basedn,
self.userdn,
)
super(Authenticator, self).__init__(name) super(Authenticator, self).__init__(name)
if self.userdn: if self.userdn:
if 'cn=external,cn=auth' in self.userdn: if "cn=external,cn=auth" in self.userdn:
self.authenticate(None) self.authenticate(None)
else: else:
self.con = None self.con = None
@ -55,25 +62,27 @@ class Authenticator(BaseAuthenticator):
# Implement virtual properties # Implement virtual properties
vendor = 'ldap' vendor = "ldap"
# Implement virtual methods # Implement virtual methods
def authenticate(self, password): def authenticate(self, password):
try: try:
con = ldap.ldapobject.ReconnectLDAPObject(self.uri, retry_max=10, retry_delay=0.5) con = ldap.ldapobject.ReconnectLDAPObject(
self.uri, retry_max=10, retry_delay=0.5
)
if self.userdn: if self.userdn:
if 'cn=external,cn=auth' in self.userdn: if "cn=external,cn=auth" in self.userdn:
con.sasl_non_interactive_bind_s('EXTERNAL') con.sasl_non_interactive_bind_s("EXTERNAL")
else: else:
con.simple_bind_s(self.userdn, password) con.simple_bind_s(self.userdn, password)
else: else:
con.simple_bind_s() con.simple_bind_s()
except ldap.INVALID_CREDENTIALS: except ldap.INVALID_CREDENTIALS:
raise MoulinetteError('invalid_password') raise MoulinetteError("invalid_password")
except ldap.SERVER_DOWN: except ldap.SERVER_DOWN:
logger.exception('unable to reach the server to authenticate') logger.exception("unable to reach the server to authenticate")
raise MoulinetteError('ldap_server_down') raise MoulinetteError("ldap_server_down")
# Check that we are indeed logged in with the right identity # Check that we are indeed logged in with the right identity
try: try:
@ -91,13 +100,16 @@ class Authenticator(BaseAuthenticator):
def _ensure_password_uses_strong_hash(self, password): def _ensure_password_uses_strong_hash(self, password):
# XXX this has been copy pasted from YunoHost, should we put that into moulinette? # XXX this has been copy pasted from YunoHost, should we put that into moulinette?
def _hash_user_password(password): def _hash_user_password(password):
char_set = string.ascii_uppercase + string.ascii_lowercase + string.digits + "./" char_set = (
salt = ''.join([random.SystemRandom().choice(char_set) for x in range(16)]) string.ascii_uppercase + string.ascii_lowercase + string.digits + "./"
salt = '$6$' + salt + '$' )
return '{CRYPT}' + crypt.crypt(str(password), salt) salt = "".join([random.SystemRandom().choice(char_set) for x in range(16)])
salt = "$6$" + salt + "$"
return "{CRYPT}" + crypt.crypt(str(password), salt)
hashed_password = self.search("cn=admin,dc=yunohost,dc=org", hashed_password = self.search(
attrs=["userPassword"])[0] "cn=admin,dc=yunohost,dc=org", attrs=["userPassword"]
)[0]
# post-install situation, password is not already set # post-install situation, password is not already set
if "userPassword" not in hashed_password or not hashed_password["userPassword"]: if "userPassword" not in hashed_password or not hashed_password["userPassword"]:
@ -105,14 +117,12 @@ class Authenticator(BaseAuthenticator):
# we aren't using sha-512 but something else that is weaker, proceed to upgrade # we aren't using sha-512 but something else that is weaker, proceed to upgrade
if not hashed_password["userPassword"][0].startswith("{CRYPT}$6$"): if not hashed_password["userPassword"][0].startswith("{CRYPT}$6$"):
self.update("cn=admin", { self.update("cn=admin", {"userPassword": _hash_user_password(password),})
"userPassword": _hash_user_password(password),
})
# Additional LDAP methods # Additional LDAP methods
# TODO: Review these methods # TODO: Review these methods
def search(self, base=None, filter='(objectClass=*)', attrs=['dn']): def search(self, base=None, filter="(objectClass=*)", attrs=["dn"]):
"""Search in LDAP base """Search in LDAP base
Perform an LDAP search operation with given arguments and return Perform an LDAP search operation with given arguments and return
@ -133,16 +143,22 @@ class Authenticator(BaseAuthenticator):
try: try:
result = self.con.search_s(base, ldap.SCOPE_SUBTREE, filter, attrs) result = self.con.search_s(base, ldap.SCOPE_SUBTREE, filter, attrs)
except Exception as e: except Exception as e:
logger.exception("error during LDAP search operation with: base='%s', " logger.exception(
"filter='%s', attrs=%s and exception %s", base, filter, attrs, e) "error during LDAP search operation with: base='%s', "
raise MoulinetteError('ldap_operation_error') "filter='%s', attrs=%s and exception %s",
base,
filter,
attrs,
e,
)
raise MoulinetteError("ldap_operation_error")
result_list = [] result_list = []
if not attrs or 'dn' not in attrs: if not attrs or "dn" not in attrs:
result_list = [entry for dn, entry in result] result_list = [entry for dn, entry in result]
else: else:
for dn, entry in result: for dn, entry in result:
entry['dn'] = [dn] entry["dn"] = [dn]
result_list.append(entry) result_list.append(entry)
return result_list return result_list
@ -158,15 +174,20 @@ class Authenticator(BaseAuthenticator):
Boolean | MoulinetteError Boolean | MoulinetteError
""" """
dn = rdn + ',' + self.basedn dn = rdn + "," + self.basedn
ldif = modlist.addModlist(attr_dict) ldif = modlist.addModlist(attr_dict)
try: try:
self.con.add_s(dn, ldif) self.con.add_s(dn, ldif)
except Exception as e: except Exception as e:
logger.exception("error during LDAP add operation with: rdn='%s', " logger.exception(
"attr_dict=%s and exception %s", rdn, attr_dict, e) "error during LDAP add operation with: rdn='%s', "
raise MoulinetteError('ldap_operation_error') "attr_dict=%s and exception %s",
rdn,
attr_dict,
e,
)
raise MoulinetteError("ldap_operation_error")
else: else:
return True return True
@ -181,12 +202,16 @@ class Authenticator(BaseAuthenticator):
Boolean | MoulinetteError Boolean | MoulinetteError
""" """
dn = rdn + ',' + self.basedn dn = rdn + "," + self.basedn
try: try:
self.con.delete_s(dn) self.con.delete_s(dn)
except Exception as e: except Exception as e:
logger.exception("error during LDAP delete operation with: rdn='%s' and exception %s", rdn, e) logger.exception(
raise MoulinetteError('ldap_operation_error') "error during LDAP delete operation with: rdn='%s' and exception %s",
rdn,
e,
)
raise MoulinetteError("ldap_operation_error")
else: else:
return True return True
@ -203,21 +228,26 @@ class Authenticator(BaseAuthenticator):
Boolean | MoulinetteError Boolean | MoulinetteError
""" """
dn = rdn + ',' + self.basedn dn = rdn + "," + self.basedn
actual_entry = self.search(base=dn, attrs=None) actual_entry = self.search(base=dn, attrs=None)
ldif = modlist.modifyModlist(actual_entry[0], attr_dict, ignore_oldexistent=1) ldif = modlist.modifyModlist(actual_entry[0], attr_dict, ignore_oldexistent=1)
try: try:
if new_rdn: if new_rdn:
self.con.rename_s(dn, new_rdn) self.con.rename_s(dn, new_rdn)
dn = new_rdn + ',' + self.basedn dn = new_rdn + "," + self.basedn
self.con.modify_ext_s(dn, ldif) self.con.modify_ext_s(dn, ldif)
except Exception as e: except Exception as e:
logger.exception("error during LDAP update operation with: rdn='%s', " logger.exception(
"attr_dict=%s, new_rdn=%s and exception: %s", rdn, attr_dict, "error during LDAP update operation with: rdn='%s', "
new_rdn, e) "attr_dict=%s, new_rdn=%s and exception: %s",
raise MoulinetteError('ldap_operation_error') rdn,
attr_dict,
new_rdn,
e,
)
raise MoulinetteError("ldap_operation_error")
else: else:
return True return True
@ -234,11 +264,16 @@ class Authenticator(BaseAuthenticator):
""" """
attr_found = self.get_conflict(value_dict) attr_found = self.get_conflict(value_dict)
if attr_found: if attr_found:
logger.info("attribute '%s' with value '%s' is not unique", logger.info(
attr_found[0], attr_found[1]) "attribute '%s' with value '%s' is not unique",
raise MoulinetteError('ldap_attribute_already_exists', attr_found[0],
attribute=attr_found[0], attr_found[1],
value=attr_found[1]) )
raise MoulinetteError(
"ldap_attribute_already_exists",
attribute=attr_found[0],
value=attr_found[1],
)
return True return True
def get_conflict(self, value_dict, base_dn=None): def get_conflict(self, value_dict, base_dn=None):
@ -253,7 +288,7 @@ class Authenticator(BaseAuthenticator):
""" """
for attr, value in value_dict.items(): for attr, value in value_dict.items():
if not self.search(base=base_dn, filter=attr + '=' + value): if not self.search(base=base_dn, filter=attr + "=" + value):
continue continue
else: else:
return (attr, value) return (attr, value)

(next changed file)

@@ -5,7 +5,7 @@ import os
 from moulinette.globals import init_moulinette_env
-def get_cachedir(subdir='', make_dir=True):
+def get_cachedir(subdir="", make_dir=True):
     """Get the path to a cache directory
     Return the path to the cache directory from an optional

@@ -16,7 +16,7 @@ def get_cachedir(subdir='', make_dir=True):
         - make_dir -- False to not make directory if it not exists
     """
-    CACHE_DIR = init_moulinette_env()['CACHE_DIR']
+    CACHE_DIR = init_moulinette_env()["CACHE_DIR"]
     path = os.path.join(CACHE_DIR, subdir)

@@ -25,7 +25,7 @@ def get_cachedir(subdir='', make_dir=True):
     return path
-def open_cachefile(filename, mode='r', subdir=''):
+def open_cachefile(filename, mode="r", subdir=""):
     """Open a cache file and return a stream
     Attempt to open in 'mode' the cache file 'filename' from the

@@ -39,6 +39,6 @@ def open_cachefile(filename, mode='r', subdir=''):
        - **kwargs -- Optional arguments for get_cachedir
     """
-    cache_dir = get_cachedir(subdir, make_dir=True if mode[0] == 'w' else False)
+    cache_dir = get_cachedir(subdir, make_dir=True if mode[0] == "w" else False)
     file_path = os.path.join(cache_dir, filename)
     return open(file_path, mode)

(next changed file)

@ -11,7 +11,7 @@ import moulinette
from moulinette.globals import init_moulinette_env from moulinette.globals import init_moulinette_env
logger = logging.getLogger('moulinette.core') logger = logging.getLogger("moulinette.core")
def during_unittests_run(): def during_unittests_run():
@ -20,6 +20,7 @@ def during_unittests_run():
# Internationalization ------------------------------------------------- # Internationalization -------------------------------------------------
class Translator(object): class Translator(object):
"""Internationalization class """Internationalization class
@ -33,15 +34,16 @@ class Translator(object):
""" """
def __init__(self, locale_dir, default_locale='en'): def __init__(self, locale_dir, default_locale="en"):
self.locale_dir = locale_dir self.locale_dir = locale_dir
self.locale = default_locale self.locale = default_locale
self._translations = {} self._translations = {}
# Attempt to load default translations # Attempt to load default translations
if not self._load_translations(default_locale): if not self._load_translations(default_locale):
logger.error("unable to load locale '%s' from '%s'", logger.error(
default_locale, locale_dir) "unable to load locale '%s' from '%s'", default_locale, locale_dir
)
self.default_locale = default_locale self.default_locale = default_locale
def get_locales(self): def get_locales(self):
@ -49,7 +51,7 @@ class Translator(object):
locales = [] locales = []
for f in os.listdir(self.locale_dir): for f in os.listdir(self.locale_dir):
if f.endswith('.json'): if f.endswith(".json"):
# TODO: Validate locale # TODO: Validate locale
locales.append(f[:-5]) locales.append(f[:-5])
return locales return locales
@ -69,8 +71,11 @@ class Translator(object):
""" """
if locale not in self._translations: if locale not in self._translations:
if not self._load_translations(locale): if not self._load_translations(locale):
logger.debug("unable to load locale '%s' from '%s'", logger.debug(
self.default_locale, self.locale_dir) "unable to load locale '%s' from '%s'",
self.default_locale,
self.locale_dir,
)
# Revert to default locale # Revert to default locale
self.locale = self.default_locale self.locale = self.default_locale
@ -93,11 +98,18 @@ class Translator(object):
failed_to_format = False failed_to_format = False
if key in self._translations.get(self.locale, {}): if key in self._translations.get(self.locale, {}):
try: try:
return self._translations[self.locale][key].encode('utf-8').format(*args, **kwargs) return (
self._translations[self.locale][key]
.encode("utf-8")
.format(*args, **kwargs)
)
except KeyError as e: except KeyError as e:
unformatted_string = self._translations[self.locale][key].encode('utf-8') unformatted_string = self._translations[self.locale][key].encode(
error_message = "Failed to format translated string '%s': '%s' with arguments '%s' and '%s, raising error: %s(%s) (don't panic this is just a warning)" % ( "utf-8"
key, unformatted_string, args, kwargs, e.__class__.__name__, e )
error_message = (
"Failed to format translated string '%s': '%s' with arguments '%s' and '%s, raising error: %s(%s) (don't panic this is just a warning)"
% (key, unformatted_string, args, kwargs, e.__class__.__name__, e)
) )
if not during_unittests_run(): if not during_unittests_run():
@ -107,25 +119,37 @@ class Translator(object):
failed_to_format = True failed_to_format = True
if failed_to_format or (self.default_locale != self.locale and key in self._translations.get(self.default_locale, {})): if failed_to_format or (
logger.info("untranslated key '%s' for locale '%s'", self.default_locale != self.locale
key, self.locale) and key in self._translations.get(self.default_locale, {})
):
logger.info("untranslated key '%s' for locale '%s'", key, self.locale)
try: try:
return self._translations[self.default_locale][key].encode('utf-8').format(*args, **kwargs) return (
self._translations[self.default_locale][key]
.encode("utf-8")
.format(*args, **kwargs)
)
except KeyError as e: except KeyError as e:
unformatted_string = self._translations[self.default_locale][key].encode('utf-8') unformatted_string = self._translations[self.default_locale][
error_message = "Failed to format translatable string '%s': '%s' with arguments '%s' and '%s', raising error: %s(%s) (don't panic this is just a warning)" % ( key
key, unformatted_string, args, kwargs, e.__class__.__name__, e ].encode("utf-8")
error_message = (
"Failed to format translatable string '%s': '%s' with arguments '%s' and '%s', raising error: %s(%s) (don't panic this is just a warning)"
% (key, unformatted_string, args, kwargs, e.__class__.__name__, e)
) )
if not during_unittests_run(): if not during_unittests_run():
logger.exception(error_message) logger.exception(error_message)
else: else:
raise Exception(error_message) raise Exception(error_message)
return self._translations[self.default_locale][key].encode('utf-8') return self._translations[self.default_locale][key].encode("utf-8")
error_message = "unable to retrieve string to translate with key '%s' for default locale 'locales/%s.json' file (don't panic this is just a warning)" % (key, self.default_locale) error_message = (
"unable to retrieve string to translate with key '%s' for default locale 'locales/%s.json' file (don't panic this is just a warning)"
% (key, self.default_locale)
)
if not during_unittests_run(): if not during_unittests_run():
logger.exception(error_message) logger.exception(error_message)
@ -152,8 +176,8 @@ class Translator(object):
return True return True
try: try:
with open('%s/%s.json' % (self.locale_dir, locale), 'r') as f: with open("%s/%s.json" % (self.locale_dir, locale), "r") as f:
j = json.load(f, 'utf-8') j = json.load(f, "utf-8")
except IOError: except IOError:
return False return False
else: else:
@ -174,12 +198,12 @@ class Moulinette18n(object):
""" """
def __init__(self, default_locale='en'): def __init__(self, default_locale="en"):
self.default_locale = default_locale self.default_locale = default_locale
self.locale = default_locale self.locale = default_locale
moulinette_env = init_moulinette_env() moulinette_env = init_moulinette_env()
self.locales_dir = moulinette_env['LOCALES_DIR'] self.locales_dir = moulinette_env["LOCALES_DIR"]
# Init global translator # Init global translator
self._global = Translator(self.locales_dir, default_locale) self._global = Translator(self.locales_dir, default_locale)
@ -201,8 +225,9 @@ class Moulinette18n(object):
if namespace not in self._namespaces: if namespace not in self._namespaces:
# Create new Translator object # Create new Translator object
lib_dir = init_moulinette_env()["LIB_DIR"] lib_dir = init_moulinette_env()["LIB_DIR"]
translator = Translator('%s/%s/locales' % (lib_dir, namespace), translator = Translator(
self.default_locale) "%s/%s/locales" % (lib_dir, namespace), self.default_locale
)
translator.set_locale(self.locale) translator.set_locale(self.locale)
self._namespaces[namespace] = translator self._namespaces[namespace] = translator
@ -272,19 +297,19 @@ class MoulinetteSignals(object):
if signal not in self.signals: if signal not in self.signals:
logger.error("unknown signal '%s'", signal) logger.error("unknown signal '%s'", signal)
return return
setattr(self, '_%s' % signal, handler) setattr(self, "_%s" % signal, handler)
def clear_handler(self, signal): def clear_handler(self, signal):
"""Clear the handler of a signal""" """Clear the handler of a signal"""
if signal not in self.signals: if signal not in self.signals:
logger.error("unknown signal '%s'", signal) logger.error("unknown signal '%s'", signal)
return return
setattr(self, '_%s' % signal, self._notimplemented) setattr(self, "_%s" % signal, self._notimplemented)
# Signals definitions # Signals definitions
"""The list of available signals""" """The list of available signals"""
signals = {'authenticate', 'prompt', 'display'} signals = {"authenticate", "prompt", "display"}
def authenticate(self, authenticator): def authenticate(self, authenticator):
"""Process the authentication """Process the authentication
@ -305,7 +330,7 @@ class MoulinetteSignals(object):
return authenticator return authenticator
return self._authenticate(authenticator) return self._authenticate(authenticator)
def prompt(self, message, is_password=False, confirm=False, color='blue'): def prompt(self, message, is_password=False, confirm=False, color="blue"):
"""Prompt for a value """Prompt for a value
Prompt the interface for a parameter value which is a password Prompt the interface for a parameter value which is a password
@ -326,7 +351,7 @@ class MoulinetteSignals(object):
""" """
return self._prompt(message, is_password, confirm, color=color) return self._prompt(message, is_password, confirm, color=color)
def display(self, message, style='info'): def display(self, message, style="info"):
"""Display a message """Display a message
Display a message with a given style to the user. Display a message with a given style to the user.
@ -352,6 +377,7 @@ class MoulinetteSignals(object):
# Interfaces & Authenticators management ------------------------------- # Interfaces & Authenticators management -------------------------------
def init_interface(name, kwargs={}, actionsmap={}): def init_interface(name, kwargs={}, actionsmap={}):
"""Return a new interface instance """Return a new interface instance
@ -371,10 +397,10 @@ def init_interface(name, kwargs={}, actionsmap={}):
from moulinette.actionsmap import ActionsMap from moulinette.actionsmap import ActionsMap
try: try:
mod = import_module('moulinette.interfaces.%s' % name) mod = import_module("moulinette.interfaces.%s" % name)
except ImportError as e: except ImportError as e:
logger.exception("unable to load interface '%s' : %s", name, e) logger.exception("unable to load interface '%s' : %s", name, e)
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
else: else:
try: try:
# Retrieve interface classes # Retrieve interface classes
@ -382,22 +408,23 @@ def init_interface(name, kwargs={}, actionsmap={}):
interface = mod.Interface interface = mod.Interface
except AttributeError: except AttributeError:
logger.exception("unable to retrieve classes of interface '%s'", name) logger.exception("unable to retrieve classes of interface '%s'", name)
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
# Instantiate or retrieve ActionsMap # Instantiate or retrieve ActionsMap
if isinstance(actionsmap, dict): if isinstance(actionsmap, dict):
amap = ActionsMap(actionsmap.pop('parser', parser), **actionsmap) amap = ActionsMap(actionsmap.pop("parser", parser), **actionsmap)
elif isinstance(actionsmap, ActionsMap): elif isinstance(actionsmap, ActionsMap):
amap = actionsmap amap = actionsmap
else: else:
logger.error("invalid actionsmap value %r", actionsmap) logger.error("invalid actionsmap value %r", actionsmap)
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
return interface(amap, **kwargs) return interface(amap, **kwargs)
# Moulinette core classes ---------------------------------------------- # Moulinette core classes ----------------------------------------------
class MoulinetteError(Exception): class MoulinetteError(Exception):
"""Moulinette base exception""" """Moulinette base exception"""
@ -427,12 +454,12 @@ class MoulinetteLock(object):
""" """
def __init__(self, namespace, timeout=None, interval=.5): def __init__(self, namespace, timeout=None, interval=0.5):
self.namespace = namespace self.namespace = namespace
self.timeout = timeout self.timeout = timeout
self.interval = interval self.interval = interval
self._lockfile = '/var/run/moulinette_%s.lock' % namespace self._lockfile = "/var/run/moulinette_%s.lock" % namespace
self._stale_checked = False self._stale_checked = False
self._locked = False self._locked = False
@ -453,7 +480,7 @@ class MoulinetteLock(object):
# after 15*4 seconds, then 15*4*4 seconds... # after 15*4 seconds, then 15*4*4 seconds...
warning_treshold = 15 warning_treshold = 15
logger.debug('acquiring lock...') logger.debug("acquiring lock...")
while True: while True:
@ -470,20 +497,24 @@ class MoulinetteLock(object):
# Check whether the locked process still exists and take the lock if it doesn't # Check whether the locked process still exists and take the lock if it doesn't
# FIXME : what to do in the context of multiple locks :| # FIXME : what to do in the context of multiple locks :|
first_lock = lock_pids[0] first_lock = lock_pids[0]
if not os.path.exists(os.path.join('/proc', str(first_lock), 'exe')): if not os.path.exists(os.path.join("/proc", str(first_lock), "exe")):
logger.debug('stale lock file found') logger.debug("stale lock file found")
self._lock() self._lock()
break break
if self.timeout is not None and (time.time() - start_time) > self.timeout: if self.timeout is not None and (time.time() - start_time) > self.timeout:
raise MoulinetteError('instance_already_running') raise MoulinetteError("instance_already_running")
# warn the user if they have been waiting for a long time # warn the user if they have been waiting for a long time
if (time.time() - start_time) > warning_treshold: if (time.time() - start_time) > warning_treshold:
if warning_treshold == 15: if warning_treshold == 15:
logger.warning(moulinette.m18n.g('warn_the_user_about_waiting_lock')) logger.warning(
moulinette.m18n.g("warn_the_user_about_waiting_lock")
)
else: else:
logger.warning(moulinette.m18n.g('warn_the_user_about_waiting_lock_again')) logger.warning(
moulinette.m18n.g("warn_the_user_about_waiting_lock_again")
)
warning_treshold *= 4 warning_treshold *= 4
# Wait before checking again # Wait before checking again
@ -492,8 +523,8 @@ class MoulinetteLock(object):
# we have warned the user that we were waiting; for better UX, also tell them # we have warned the user that we were waiting; for better UX, also tell them
# that we have stopped waiting and that the command is now being processed # that we have stopped waiting and that the command is now being processed
if warning_treshold != 15: if warning_treshold != 15:
logger.warning(moulinette.m18n.g('warn_the_user_that_lock_is_acquired')) logger.warning(moulinette.m18n.g("warn_the_user_that_lock_is_acquired"))
logger.debug('lock has been acquired') logger.debug("lock has been acquired")
self._locked = True self._locked = True
def release(self): def release(self):
@ -506,16 +537,18 @@ class MoulinetteLock(object):
if os.path.exists(self._lockfile): if os.path.exists(self._lockfile):
os.unlink(self._lockfile) os.unlink(self._lockfile)
else: else:
logger.warning("Uhoh, somehow the lock %s did not exist ..." % self._lockfile) logger.warning(
logger.debug('lock has been released') "Uhoh, somehow the lock %s did not exist ..." % self._lockfile
)
logger.debug("lock has been released")
self._locked = False self._locked = False
def _lock(self): def _lock(self):
try: try:
with open(self._lockfile, 'w') as f: with open(self._lockfile, "w") as f:
f.write(str(os.getpid())) f.write(str(os.getpid()))
except IOError: except IOError:
raise MoulinetteError('root_required') raise MoulinetteError("root_required")
def _lock_PIDs(self): def _lock_PIDs(self):
@ -523,10 +556,10 @@ class MoulinetteLock(object):
return [] return []
with open(self._lockfile) as f: with open(self._lockfile) as f:
lock_pids = f.read().strip().split('\n') lock_pids = f.read().strip().split("\n")
# Make sure to convert those pids to integers # Make sure to convert those pids to integers
lock_pids = [int(pid) for pid in lock_pids if pid.strip() != ''] lock_pids = [int(pid) for pid in lock_pids if pid.strip() != ""]
return lock_pids return lock_pids
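# --- editor's example (not part of this commit) ------------------------------
# Sketch of how MoulinetteLock is typically used around an action. The
# acquire() name is the editor's assumption (only release() and the lock file
# handling appear in these hunks); "mydomain" is a made-up namespace.
from moulinette.core import MoulinetteLock

lock = MoulinetteLock("mydomain", timeout=30)
lock.acquire()  # writes our PID to /var/run/moulinette_mydomain.lock, or raises
                # 'instance_already_running' after 30 seconds of waiting
try:
    pass  # ... run the command that must not execute concurrently ...
finally:
    lock.release()  # removes the lock file and logs "lock has been released"
# ------------------------------------------------------------------------------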

View file

@ -5,8 +5,10 @@ from os import environ
def init_moulinette_env(): def init_moulinette_env():
return { return {
'DATA_DIR': environ.get('MOULINETTE_DATA_DIR', '/usr/share/moulinette'), "DATA_DIR": environ.get("MOULINETTE_DATA_DIR", "/usr/share/moulinette"),
'LIB_DIR': environ.get('MOULINETTE_LIB_DIR', '/usr/lib/moulinette'), "LIB_DIR": environ.get("MOULINETTE_LIB_DIR", "/usr/lib/moulinette"),
'LOCALES_DIR': environ.get('MOULINETTE_LOCALES_DIR', '/usr/share/moulinette/locale'), "LOCALES_DIR": environ.get(
'CACHE_DIR': environ.get('MOULINETTE_CACHE_DIR', '/var/cache/moulinette'), "MOULINETTE_LOCALES_DIR", "/usr/share/moulinette/locale"
),
"CACHE_DIR": environ.get("MOULINETTE_CACHE_DIR", "/var/cache/moulinette"),
} }
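# --- editor's example (not part of this commit) ------------------------------
# Each directory above can be overridden per process through its environment
# variable before moulinette reads it; the /tmp paths are placeholders.
import os

os.environ["MOULINETTE_DATA_DIR"] = "/tmp/moulinette/data"
os.environ["MOULINETTE_CACHE_DIR"] = "/tmp/moulinette/cache"

from moulinette.globals import init_moulinette_env

assert init_moulinette_env()["DATA_DIR"] == "/tmp/moulinette/data"
assert init_moulinette_env()["CACHE_DIR"] == "/tmp/moulinette/cache"
# ------------------------------------------------------------------------------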

View file

@ -9,15 +9,16 @@ from collections import deque, OrderedDict
from moulinette import msettings, m18n from moulinette import msettings, m18n
from moulinette.core import MoulinetteError from moulinette.core import MoulinetteError
logger = logging.getLogger('moulinette.interface') logger = logging.getLogger("moulinette.interface")
GLOBAL_SECTION = '_global' GLOBAL_SECTION = "_global"
TO_RETURN_PROP = '_to_return' TO_RETURN_PROP = "_to_return"
CALLBACKS_PROP = '_callbacks' CALLBACKS_PROP = "_callbacks"
# Base Class ----------------------------------------------------------- # Base Class -----------------------------------------------------------
class BaseActionsMapParser(object): class BaseActionsMapParser(object):
"""Actions map's base Parser """Actions map's base Parser
@ -37,9 +38,8 @@ class BaseActionsMapParser(object):
if parent: if parent:
self._o = parent self._o = parent
else: else:
logger.debug('initializing base actions map parser for %s', logger.debug("initializing base actions map parser for %s", self.interface)
self.interface) msettings["interface"] = self.interface
msettings['interface'] = self.interface
self._o = self self._o = self
self._global_conf = {} self._global_conf = {}
@ -70,8 +70,9 @@ class BaseActionsMapParser(object):
A list of option strings A list of option strings
""" """
raise NotImplementedError("derived class '%s' must override this method" % raise NotImplementedError(
self.__class__.__name__) "derived class '%s' must override this method" % self.__class__.__name__
)
def has_global_parser(self): def has_global_parser(self):
return False return False
@ -85,8 +86,9 @@ class BaseActionsMapParser(object):
An ArgumentParser based object An ArgumentParser based object
""" """
raise NotImplementedError("derived class '%s' must override this method" % raise NotImplementedError(
self.__class__.__name__) "derived class '%s' must override this method" % self.__class__.__name__
)
def add_category_parser(self, name, **kwargs): def add_category_parser(self, name, **kwargs):
"""Add a parser for a category """Add a parser for a category
@ -100,8 +102,9 @@ class BaseActionsMapParser(object):
A BaseParser based object A BaseParser based object
""" """
raise NotImplementedError("derived class '%s' must override this method" % raise NotImplementedError(
self.__class__.__name__) "derived class '%s' must override this method" % self.__class__.__name__
)
def add_action_parser(self, name, tid, **kwargs): def add_action_parser(self, name, tid, **kwargs):
"""Add a parser for an action """Add a parser for an action
@ -116,8 +119,9 @@ class BaseActionsMapParser(object):
An ArgumentParser based object An ArgumentParser based object
""" """
raise NotImplementedError("derived class '%s' must override this method" % raise NotImplementedError(
self.__class__.__name__) "derived class '%s' must override this method" % self.__class__.__name__
)
def auth_required(self, args, **kwargs): def auth_required(self, args, **kwargs):
"""Check if authentication is required to run the requested action """Check if authentication is required to run the requested action
@ -129,8 +133,9 @@ class BaseActionsMapParser(object):
False, or the authentication profile required False, or the authentication profile required
""" """
raise NotImplementedError("derived class '%s' must override this method" % raise NotImplementedError(
self.__class__.__name__) "derived class '%s' must override this method" % self.__class__.__name__
)
def parse_args(self, args, **kwargs): def parse_args(self, args, **kwargs):
"""Parse arguments """Parse arguments
@ -145,17 +150,19 @@ class BaseActionsMapParser(object):
The populated namespace The populated namespace
""" """
raise NotImplementedError("derived class '%s' must override this method" % raise NotImplementedError(
self.__class__.__name__) "derived class '%s' must override this method" % self.__class__.__name__
)
# Arguments helpers # Arguments helpers
def prepare_action_namespace(self, tid, namespace=None): def prepare_action_namespace(self, tid, namespace=None):
"""Prepare the namespace for a given action""" """Prepare the namespace for a given action"""
# Validate tid and namespace # Validate tid and namespace
if not isinstance(tid, tuple) and \ if not isinstance(tid, tuple) and (
(namespace is None or not hasattr(namespace, TO_RETURN_PROP)): namespace is None or not hasattr(namespace, TO_RETURN_PROP)
raise MoulinetteError('invalid_usage') ):
raise MoulinetteError("invalid_usage")
elif not tid: elif not tid:
tid = GLOBAL_SECTION tid = GLOBAL_SECTION
@ -229,52 +236,65 @@ class BaseActionsMapParser(object):
# -- 'authenticate' # -- 'authenticate'
try: try:
ifaces = configuration['authenticate'] ifaces = configuration["authenticate"]
except KeyError: except KeyError:
pass pass
else: else:
if ifaces == 'all': if ifaces == "all":
conf['authenticate'] = ifaces conf["authenticate"] = ifaces
elif ifaces is False: elif ifaces is False:
conf['authenticate'] = False conf["authenticate"] = False
elif isinstance(ifaces, list): elif isinstance(ifaces, list):
# Store only if authentication is needed # Store only if authentication is needed
conf['authenticate'] = True if self.interface in ifaces else False conf["authenticate"] = True if self.interface in ifaces else False
else: else:
logger.error("expecting 'all', 'False' or a list for " logger.error(
"configuration 'authenticate', got %r", ifaces) "expecting 'all', 'False' or a list for "
raise MoulinetteError('error_see_log') "configuration 'authenticate', got %r",
ifaces,
)
raise MoulinetteError("error_see_log")
# -- 'authenticator' # -- 'authenticator'
try: try:
auth = configuration['authenticator'] auth = configuration["authenticator"]
except KeyError: except KeyError:
pass pass
else: else:
if not is_global and isinstance(auth, str): if not is_global and isinstance(auth, str):
try: try:
# Store needed authenticator profile # Store needed authenticator profile
conf['authenticator'] = self.global_conf['authenticator'][auth] conf["authenticator"] = self.global_conf["authenticator"][auth]
except KeyError: except KeyError:
logger.error("requesting profile '%s' which is undefined in " logger.error(
"global configuration of 'authenticator'", auth) "requesting profile '%s' which is undefined in "
raise MoulinetteError('error_see_log') "global configuration of 'authenticator'",
auth,
)
raise MoulinetteError("error_see_log")
elif is_global and isinstance(auth, dict): elif is_global and isinstance(auth, dict):
if len(auth) == 0: if len(auth) == 0:
logger.warning('no profile defined in global configuration ' logger.warning(
"for 'authenticator'") "no profile defined in global configuration "
"for 'authenticator'"
)
else: else:
auths = {} auths = {}
for auth_name, auth_conf in auth.items(): for auth_name, auth_conf in auth.items():
auths[auth_name] = {'name': auth_name, auths[auth_name] = {
'vendor': auth_conf.get('vendor'), "name": auth_name,
'parameters': auth_conf.get('parameters', {}), "vendor": auth_conf.get("vendor"),
'extra': {'help': auth_conf.get('help', None)}} "parameters": auth_conf.get("parameters", {}),
conf['authenticator'] = auths "extra": {"help": auth_conf.get("help", None)},
}
conf["authenticator"] = auths
else: else:
logger.error("expecting a dict of profile(s) or a profile name " logger.error(
"for configuration 'authenticator', got %r", auth) "expecting a dict of profile(s) or a profile name "
raise MoulinetteError('error_see_log') "for configuration 'authenticator', got %r",
auth,
)
raise MoulinetteError("error_see_log")
return conf return conf
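# --- editor's example (not part of this commit) ------------------------------
# Sketch of a configuration fragment accepted by the parsing above: the
# 'authenticate' key takes "all", False or a list of interfaces, and the
# global 'authenticator' key maps profile names to vendor/parameters/help.
# All concrete values below are made up.
configuration = {
    "authenticate": ["api"],  # only require authentication on the API interface
    "authenticator": {
        "default": {
            "vendor": "ldap",                    # hypothetical vendor
            "parameters": {"uri": "ldap://localhost:389"},
            "help": "admin_password",
        },
    },
}
# ------------------------------------------------------------------------------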
@ -291,55 +311,60 @@ class BaseInterface(object):
- actionsmap -- The ActionsMap instance to connect to - actionsmap -- The ActionsMap instance to connect to
""" """
# TODO: Add common interface methods and try to standardize default ones # TODO: Add common interface methods and try to standardize default ones
def __init__(self, actionsmap): def __init__(self, actionsmap):
raise NotImplementedError("derived class '%s' must override this method" % raise NotImplementedError(
self.__class__.__name__) "derived class '%s' must override this method" % self.__class__.__name__
)
# Argument parser ------------------------------------------------------ # Argument parser ------------------------------------------------------
class _CallbackAction(argparse.Action):
def __init__(self, class _CallbackAction(argparse.Action):
option_strings, def __init__(
dest, self,
nargs=0, option_strings,
callback={}, dest,
default=argparse.SUPPRESS, nargs=0,
help=None): callback={},
if not callback or 'method' not in callback: default=argparse.SUPPRESS,
raise ValueError('callback must be provided with at least ' help=None,
'a method key') ):
if not callback or "method" not in callback:
raise ValueError("callback must be provided with at least " "a method key")
super(_CallbackAction, self).__init__( super(_CallbackAction, self).__init__(
option_strings=option_strings, option_strings=option_strings,
dest=dest, dest=dest,
nargs=nargs, nargs=nargs,
default=default, default=default,
help=help) help=help,
self.callback_method = callback.get('method') )
self.callback_kwargs = callback.get('kwargs', {}) self.callback_method = callback.get("method")
self.callback_return = callback.get('return', False) self.callback_kwargs = callback.get("kwargs", {})
logger.debug("registering new callback action '{0}' to {1}".format( self.callback_return = callback.get("return", False)
self.callback_method, option_strings)) logger.debug(
"registering new callback action '{0}' to {1}".format(
self.callback_method, option_strings
)
)
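# --- editor's example (not part of this commit) ------------------------------
# Sketch of the 'callback' option consumed above, as it could appear in an
# argument definition; the dotted path is resolved lazily by
# _retrieve_callback() below. "mydomain.tools.check_quota" is made up.
callback = {
    "method": "mydomain.tools.check_quota",  # required: dotted "module.function" path
    "kwargs": {"strict": True},              # extra keyword arguments for the call
    "return": True,                          # how the returned value is used (see __call__)
}
# ------------------------------------------------------------------------------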
@property @property
def callback(self): def callback(self):
if not hasattr(self, '_callback'): if not hasattr(self, "_callback"):
self._retrieve_callback() self._retrieve_callback()
return self._callback return self._callback
def _retrieve_callback(self): def _retrieve_callback(self):
# Attempt to retrieve callback method # Attempt to retrieve callback method
mod_name, func_name = (self.callback_method).rsplit('.', 1) mod_name, func_name = (self.callback_method).rsplit(".", 1)
try: try:
mod = __import__(mod_name, globals=globals(), level=0, mod = __import__(mod_name, globals=globals(), level=0, fromlist=[func_name])
fromlist=[func_name])
func = getattr(mod, func_name) func = getattr(mod, func_name)
except (AttributeError, ImportError): except (AttributeError, ImportError):
raise ValueError('unable to import method {0}'.format( raise ValueError("unable to import method {0}".format(self.callback_method))
self.callback_method))
self._callback = func self._callback = func
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
@ -352,9 +377,11 @@ class _CallbackAction(argparse.Action):
# Execute callback and get returned value # Execute callback and get returned value
value = self.callback(namespace, values, **self.callback_kwargs) value = self.callback(namespace, values, **self.callback_kwargs)
except: except:
logger.exception("cannot get value from callback method " logger.exception(
"'{0}'".format(self.callback_method)) "cannot get value from callback method "
raise MoulinetteError('error_see_log') "'{0}'".format(self.callback_method)
)
raise MoulinetteError("error_see_log")
else: else:
if value: if value:
if self.callback_return: if self.callback_return:
@ -379,23 +406,22 @@ class _ExtendedSubParsersAction(argparse._SubParsersAction):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
required = kwargs.pop('required', False) required = kwargs.pop("required", False)
super(_ExtendedSubParsersAction, self).__init__(*args, **kwargs) super(_ExtendedSubParsersAction, self).__init__(*args, **kwargs)
self.required = required self.required = required
self._deprecated_command_map = {} self._deprecated_command_map = {}
def add_parser(self, name, type_=None, **kwargs): def add_parser(self, name, type_=None, **kwargs):
deprecated = kwargs.pop('deprecated', False) deprecated = kwargs.pop("deprecated", False)
deprecated_alias = kwargs.pop('deprecated_alias', []) deprecated_alias = kwargs.pop("deprecated_alias", [])
if deprecated: if deprecated:
self._deprecated_command_map[name] = None self._deprecated_command_map[name] = None
if 'help' in kwargs: if "help" in kwargs:
del kwargs['help'] del kwargs["help"]
parser = super(_ExtendedSubParsersAction, self).add_parser( parser = super(_ExtendedSubParsersAction, self).add_parser(name, **kwargs)
name, **kwargs)
# Append each deprecated command alias name # Append each deprecated command alias name
for command in deprecated_alias: for command in deprecated_alias:
@ -417,27 +443,34 @@ class _ExtendedSubParsersAction(argparse._SubParsersAction):
else: else:
# Warn the user about deprecated command # Warn the user about deprecated command
if correct_name is None: if correct_name is None:
logger.warning(m18n.g('deprecated_command', prog=parser.prog, logger.warning(
command=parser_name)) m18n.g("deprecated_command", prog=parser.prog, command=parser_name)
)
else: else:
logger.warning(m18n.g('deprecated_command_alias', logger.warning(
old=parser_name, new=correct_name, m18n.g(
prog=parser.prog)) "deprecated_command_alias",
old=parser_name,
new=correct_name,
prog=parser.prog,
)
)
values[0] = correct_name values[0] = correct_name
return super(_ExtendedSubParsersAction, self).__call__( return super(_ExtendedSubParsersAction, self).__call__(
parser, namespace, values, option_string) parser, namespace, values, option_string
)
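# --- editor's example (not part of this commit) ------------------------------
# Sketch of the deprecation keywords handled by add_parser() above, on an
# ExtendedArgumentParser (which registers this subparsers action below).
# Command names are made up. A 'deprecated' command drops its help text so it
# stays hidden; a 'deprecated_alias' entry warns and is rewritten to the
# canonical command name.
from moulinette.interfaces import ExtendedArgumentParser

parser = ExtendedArgumentParser(prog="mydomain")
subparsers = parser.add_subparsers()
subparsers.add_parser("list", deprecated_alias=["ls"], help="List things")
subparsers.add_parser("old-list", deprecated=True, help="(hidden)")
# ------------------------------------------------------------------------------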
class ExtendedArgumentParser(argparse.ArgumentParser): class ExtendedArgumentParser(argparse.ArgumentParser):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ExtendedArgumentParser, self).__init__(formatter_class=PositionalsFirstHelpFormatter, super(ExtendedArgumentParser, self).__init__(
*args, **kwargs) formatter_class=PositionalsFirstHelpFormatter, *args, **kwargs
)
# Register additional actions # Register additional actions
self.register('action', 'callback', _CallbackAction) self.register("action", "callback", _CallbackAction)
self.register('action', 'parsers', _ExtendedSubParsersAction) self.register("action", "parsers", _ExtendedSubParsersAction)
def enqueue_callback(self, namespace, callback, values): def enqueue_callback(self, namespace, callback, values):
queue = self._get_callbacks_queue(namespace) queue = self._get_callbacks_queue(namespace)
@ -465,30 +498,33 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
queue = list() queue = list()
return queue return queue
def add_arguments(self, arguments, extraparser, format_arg_names=None, validate_extra=True): def add_arguments(
self, arguments, extraparser, format_arg_names=None, validate_extra=True
):
for argument_name, argument_options in arguments.items(): for argument_name, argument_options in arguments.items():
# will adapt argument names for the cli or api context # will adapt argument names for the cli or api context
names = format_arg_names(str(argument_name), names = format_arg_names(
argument_options.pop('full', None)) str(argument_name), argument_options.pop("full", None)
)
if "type" in argument_options: if "type" in argument_options:
argument_options['type'] = eval(argument_options['type']) argument_options["type"] = eval(argument_options["type"])
if "extra" in argument_options: if "extra" in argument_options:
extra = argument_options.pop('extra') extra = argument_options.pop("extra")
argument_dest = self.add_argument(*names, **argument_options).dest argument_dest = self.add_argument(*names, **argument_options).dest
extraparser.add_argument(self.get_default("_tid"), extraparser.add_argument(
argument_dest, extra, validate_extra) self.get_default("_tid"), argument_dest, extra, validate_extra
)
continue continue
self.add_argument(*names, **argument_options) self.add_argument(*names, **argument_options)
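# --- editor's example (not part of this commit) ------------------------------
# Sketch of an 'arguments' mapping as consumed by add_arguments() above:
# 'full' feeds format_arg_names(), a string 'type' is eval()'d into a builtin,
# and an 'extra' block is handed to the extra-parameters parser (its exact
# shape is defined elsewhere). All names and values are made up.
arguments = {
    "-f": {
        "full": "--force",
        "action": "store_true",
        "help": "do not ask for confirmation",
    },
    "number": {
        "type": "int",  # eval("int") -> the int builtin
    },
}
# ------------------------------------------------------------------------------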
def _get_nargs_pattern(self, action): def _get_nargs_pattern(self, action):
if action.nargs == argparse.PARSER and not action.required: if action.nargs == argparse.PARSER and not action.required:
return '([-AO]*)' return "([-AO]*)"
else: else:
return super(ExtendedArgumentParser, self)._get_nargs_pattern( return super(ExtendedArgumentParser, self)._get_nargs_pattern(action)
action)
def _get_values(self, action, arg_strings): def _get_values(self, action, arg_strings):
if action.nargs == argparse.PARSER and not action.required: if action.nargs == argparse.PARSER and not action.required:
@ -498,8 +534,7 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
else: else:
value = argparse.SUPPRESS value = argparse.SUPPRESS
else: else:
value = super(ExtendedArgumentParser, self)._get_values( value = super(ExtendedArgumentParser, self)._get_values(action, arg_strings)
action, arg_strings)
return value return value
# Adapted from : # Adapted from :
@ -508,8 +543,7 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
formatter = self._get_formatter() formatter = self._get_formatter()
# usage # usage
formatter.add_usage(self.usage, self._actions, formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)
self._mutually_exclusive_groups)
# description # description
formatter.add_text(self.description) formatter.add_text(self.description)
@ -527,14 +561,30 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
subcategories_subparser = copy.copy(action_group._group_actions[0]) subcategories_subparser = copy.copy(action_group._group_actions[0])
# Filter "action"-type and "subcategory"-type commands # Filter "action"-type and "subcategory"-type commands
actions_subparser.choices = OrderedDict([(k, v) for k, v in actions_subparser.choices.items() if v.type == "action"]) actions_subparser.choices = OrderedDict(
subcategories_subparser.choices = OrderedDict([(k, v) for k, v in subcategories_subparser.choices.items() if v.type == "subcategory"]) [
(k, v)
for k, v in actions_subparser.choices.items()
if v.type == "action"
]
)
subcategories_subparser.choices = OrderedDict(
[
(k, v)
for k, v in subcategories_subparser.choices.items()
if v.type == "subcategory"
]
)
actions_choices = actions_subparser.choices.keys() actions_choices = actions_subparser.choices.keys()
subcategories_choices = subcategories_subparser.choices.keys() subcategories_choices = subcategories_subparser.choices.keys()
actions_subparser._choices_actions = [c for c in choice_actions if c.dest in actions_choices] actions_subparser._choices_actions = [
subcategories_subparser._choices_actions = [c for c in choice_actions if c.dest in subcategories_choices] c for c in choice_actions if c.dest in actions_choices
]
subcategories_subparser._choices_actions = [
c for c in choice_actions if c.dest in subcategories_choices
]
# Display each section (actions and subcategories) # Display each section (actions and subcategories)
if actions_choices != []: if actions_choices != []:
@ -569,11 +619,10 @@ class ExtendedArgumentParser(argparse.ArgumentParser):
# and fix is inspired from here : # and fix is inspired from here :
# https://stackoverflow.com/questions/26985650/argparse-do-not-catch-positional-arguments-with-nargs/26986546#26986546 # https://stackoverflow.com/questions/26985650/argparse-do-not-catch-positional-arguments-with-nargs/26986546#26986546
class PositionalsFirstHelpFormatter(argparse.HelpFormatter): class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
def _format_usage(self, usage, actions, groups, prefix): def _format_usage(self, usage, actions, groups, prefix):
if prefix is None: if prefix is None:
# TWEAK : not using gettext here... # TWEAK : not using gettext here...
prefix = 'usage: ' prefix = "usage: "
# if usage is specified, use that # if usage is specified, use that
if usage is not None: if usage is not None:
@ -581,11 +630,11 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
# if no optionals or positionals are available, usage is just prog # if no optionals or positionals are available, usage is just prog
elif usage is None and not actions: elif usage is None and not actions:
usage = '%(prog)s' % dict(prog=self._prog) usage = "%(prog)s" % dict(prog=self._prog)
# if optionals and positionals are available, calculate usage # if optionals and positionals are available, calculate usage
elif usage is None: elif usage is None:
prog = '%(prog)s' % dict(prog=self._prog) prog = "%(prog)s" % dict(prog=self._prog)
# split optionals from positionals # split optionals from positionals
optionals = [] optionals = []
@ -600,20 +649,20 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
format = self._format_actions_usage format = self._format_actions_usage
# TWEAK here : positionals first # TWEAK here : positionals first
action_usage = format(positionals + optionals, groups) action_usage = format(positionals + optionals, groups)
usage = ' '.join([s for s in [prog, action_usage] if s]) usage = " ".join([s for s in [prog, action_usage] if s])
# wrap the usage parts if it's too long # wrap the usage parts if it's too long
text_width = self._width - self._current_indent text_width = self._width - self._current_indent
if len(prefix) + len(usage) > text_width: if len(prefix) + len(usage) > text_width:
# break usage into wrappable parts # break usage into wrappable parts
part_regexp = r'\(.*?\)+|\[.*?\]+|\S+' part_regexp = r"\(.*?\)+|\[.*?\]+|\S+"
opt_usage = format(optionals, groups) opt_usage = format(optionals, groups)
pos_usage = format(positionals, groups) pos_usage = format(positionals, groups)
opt_parts = re.findall(part_regexp, opt_usage) opt_parts = re.findall(part_regexp, opt_usage)
pos_parts = re.findall(part_regexp, pos_usage) pos_parts = re.findall(part_regexp, pos_usage)
assert ' '.join(opt_parts) == opt_usage assert " ".join(opt_parts) == opt_usage
assert ' '.join(pos_parts) == pos_usage assert " ".join(pos_parts) == pos_usage
# helper for wrapping lines # helper for wrapping lines
def get_lines(parts, indent, prefix=None): def get_lines(parts, indent, prefix=None):
@ -625,20 +674,20 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
line_len = len(indent) - 1 line_len = len(indent) - 1
for part in parts: for part in parts:
if line_len + 1 + len(part) > text_width: if line_len + 1 + len(part) > text_width:
lines.append(indent + ' '.join(line)) lines.append(indent + " ".join(line))
line = [] line = []
line_len = len(indent) - 1 line_len = len(indent) - 1
line.append(part) line.append(part)
line_len += len(part) + 1 line_len += len(part) + 1
if line: if line:
lines.append(indent + ' '.join(line)) lines.append(indent + " ".join(line))
if prefix is not None: if prefix is not None:
lines[0] = lines[0][len(indent):] lines[0] = lines[0][len(indent) :]
return lines return lines
# if prog is short, follow it with optionals or positionals # if prog is short, follow it with optionals or positionals
if len(prefix) + len(prog) <= 0.75 * text_width: if len(prefix) + len(prog) <= 0.75 * text_width:
indent = ' ' * (len(prefix) + len(prog) + 1) indent = " " * (len(prefix) + len(prog) + 1)
# START TWEAK : pos_parts first, then opt_parts # START TWEAK : pos_parts first, then opt_parts
if pos_parts: if pos_parts:
lines = get_lines([prog] + pos_parts, indent, prefix) lines = get_lines([prog] + pos_parts, indent, prefix)
@ -651,7 +700,7 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
# if prog is long, put it on its own line # if prog is long, put it on its own line
else: else:
indent = ' ' * len(prefix) indent = " " * len(prefix)
parts = pos_parts + opt_parts parts = pos_parts + opt_parts
lines = get_lines(parts, indent) lines = get_lines(parts, indent)
if len(lines) > 1: if len(lines) > 1:
@ -662,7 +711,7 @@ class PositionalsFirstHelpFormatter(argparse.HelpFormatter):
lines = [prog] + lines lines = [prog] + lines
# join lines into usage # join lines into usage
usage = '\n'.join(lines) usage = "\n".join(lines)
# prefix with 'usage:' # prefix with 'usage:'
return '%s%s\n\n' % (prefix, usage) return "%s%s\n\n" % (prefix, usage)

View file

@ -16,20 +16,22 @@ from bottle import abort
from moulinette import msignals, m18n, env from moulinette import msignals, m18n, env
from moulinette.core import MoulinetteError from moulinette.core import MoulinetteError
from moulinette.interfaces import ( from moulinette.interfaces import (
BaseActionsMapParser, BaseInterface, ExtendedArgumentParser, BaseActionsMapParser,
BaseInterface,
ExtendedArgumentParser,
) )
from moulinette.utils import log from moulinette.utils import log
from moulinette.utils.serialize import JSONExtendedEncoder from moulinette.utils.serialize import JSONExtendedEncoder
from moulinette.utils.text import random_ascii from moulinette.utils.text import random_ascii
logger = log.getLogger('moulinette.interface.api') logger = log.getLogger("moulinette.interface.api")
# API helpers ---------------------------------------------------------- # API helpers ----------------------------------------------------------
CSRF_TYPES = set(["text/plain", CSRF_TYPES = set(
"application/x-www-form-urlencoded", ["text/plain", "application/x-www-form-urlencoded", "multipart/form-data"]
"multipart/form-data"]) )
def is_csrf(): def is_csrf():
@ -39,7 +41,7 @@ def is_csrf():
return False return False
if request.content_type is None: if request.content_type is None:
return True return True
content_type = request.content_type.lower().split(';')[0] content_type = request.content_type.lower().split(";")[0]
if content_type not in CSRF_TYPES: if content_type not in CSRF_TYPES:
return False return False
@ -53,12 +55,14 @@ def filter_csrf(callback):
abort(403, "CSRF protection") abort(403, "CSRF protection")
else: else:
return callback(*args, **kwargs) return callback(*args, **kwargs)
return wrapper return wrapper
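# --- editor's example (not part of this commit) ------------------------------
# Sketch of what the filter means for clients, assuming the method and header
# checks elided by this hunk: a browser-style form POST is treated as a
# possible CSRF attempt and aborted with 403, while a JSON body (whose content
# type is not in CSRF_TYPES) goes through. URL and payload are made up.
import requests

requests.post("http://localhost:6787/login", data={"password": "secret"})  # form-urlencoded -> 403 "CSRF protection"
requests.post("http://localhost:6787/login", json={"password": "secret"})  # application/json -> passes filter_csrf
# ------------------------------------------------------------------------------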
class LogQueues(dict): class LogQueues(dict):
"""Map of session id to queue.""" """Map of session id to queue."""
pass pass
@ -74,7 +78,7 @@ class APIQueueHandler(logging.Handler):
self.queues = LogQueues() self.queues = LogQueues()
def emit(self, record): def emit(self, record):
sid = request.get_cookie('session.id') sid = request.get_cookie("session.id")
try: try:
queue = self.queues[sid] queue = self.queues[sid]
except KeyError: except KeyError:
@ -99,13 +103,13 @@ class _HTTPArgumentParser(object):
def __init__(self): def __init__(self):
# Initialize the ArgumentParser object # Initialize the ArgumentParser object
self._parser = ExtendedArgumentParser(usage='', self._parser = ExtendedArgumentParser(
prefix_chars='@', usage="", prefix_chars="@", add_help=False
add_help=False) )
self._parser.error = self._error self._parser.error = self._error
self._positional = [] # list(arg_name) self._positional = [] # list(arg_name)
self._optional = {} # dict({arg_name: option_strings}) self._optional = {} # dict({arg_name: option_strings})
def set_defaults(self, **kwargs): def set_defaults(self, **kwargs):
return self._parser.set_defaults(**kwargs) return self._parser.set_defaults(**kwargs)
@ -113,20 +117,24 @@ class _HTTPArgumentParser(object):
def get_default(self, dest): def get_default(self, dest):
return self._parser.get_default(dest) return self._parser.get_default(dest)
def add_arguments(self, arguments, extraparser, format_arg_names=None, validate_extra=True): def add_arguments(
self, arguments, extraparser, format_arg_names=None, validate_extra=True
):
for argument_name, argument_options in arguments.items(): for argument_name, argument_options in arguments.items():
# will adapt argument names for the cli or api context # will adapt argument names for the cli or api context
names = format_arg_names(str(argument_name), names = format_arg_names(
argument_options.pop('full', None)) str(argument_name), argument_options.pop("full", None)
)
if "type" in argument_options: if "type" in argument_options:
argument_options['type'] = eval(argument_options['type']) argument_options["type"] = eval(argument_options["type"])
if "extra" in argument_options: if "extra" in argument_options:
extra = argument_options.pop('extra') extra = argument_options.pop("extra")
argument_dest = self.add_argument(*names, **argument_options).dest argument_dest = self.add_argument(*names, **argument_options).dest
extraparser.add_argument(self.get_default("_tid"), extraparser.add_argument(
argument_dest, extra, validate_extra) self.get_default("_tid"), argument_dest, extra, validate_extra
)
continue continue
self.add_argument(*names, **argument_options) self.add_argument(*names, **argument_options)
@ -166,12 +174,19 @@ class _HTTPArgumentParser(object):
if isinstance(v, str): if isinstance(v, str):
arg_strings.append(v) arg_strings.append(v)
else: else:
logger.warning("unsupported argument value type %r " logger.warning(
"in %s for option string %s", v, value, "unsupported argument value type %r "
option_string) "in %s for option string %s",
v,
value,
option_string,
)
else: else:
logger.warning("unsupported argument type %r for option " logger.warning(
"string %s", value, option_string) "unsupported argument type %r for option " "string %s",
value,
option_string,
)
return arg_strings return arg_strings
@ -208,14 +223,15 @@ class _ActionsMapPlugin(object):
to serve messages coming from the 'display' signal to serve messages coming from the 'display' signal
""" """
name = 'actionsmap'
name = "actionsmap"
api = 2 api = 2
def __init__(self, actionsmap, use_websocket, log_queues={}): def __init__(self, actionsmap, use_websocket, log_queues={}):
# Connect signals to handlers # Connect signals to handlers
msignals.set_handler('authenticate', self._do_authenticate) msignals.set_handler("authenticate", self._do_authenticate)
if use_websocket: if use_websocket:
msignals.set_handler('display', self._do_display) msignals.set_handler("display", self._do_display)
self.actionsmap = actionsmap self.actionsmap = actionsmap
self.use_websocket = use_websocket self.use_websocket = use_websocket
@ -237,34 +253,52 @@ class _ActionsMapPlugin(object):
def wrapper(): def wrapper():
kwargs = {} kwargs = {}
try: try:
kwargs['password'] = request.POST['password'] kwargs["password"] = request.POST["password"]
except KeyError: except KeyError:
raise HTTPBadRequestResponse("Missing password parameter") raise HTTPBadRequestResponse("Missing password parameter")
try: try:
kwargs['profile'] = request.POST['profile'] kwargs["profile"] = request.POST["profile"]
except KeyError: except KeyError:
pass pass
return callback(**kwargs) return callback(**kwargs)
return wrapper return wrapper
# Logout wrapper # Logout wrapper
def _logout(callback): def _logout(callback):
def wrapper(): def wrapper():
kwargs = {} kwargs = {}
kwargs['profile'] = request.POST.get('profile', "default") kwargs["profile"] = request.POST.get("profile", "default")
return callback(**kwargs) return callback(**kwargs)
return wrapper return wrapper
# Append authentication routes # Append authentication routes
app.route('/login', name='login', method='POST', app.route(
callback=self.login, skip=['actionsmap'], apply=_login) "/login",
app.route('/logout', name='logout', method='GET', name="login",
callback=self.logout, skip=['actionsmap'], apply=_logout) method="POST",
callback=self.login,
skip=["actionsmap"],
apply=_login,
)
app.route(
"/logout",
name="logout",
method="GET",
callback=self.logout,
skip=["actionsmap"],
apply=_logout,
)
# Append messages route # Append messages route
if self.use_websocket: if self.use_websocket:
app.route('/messages', name='messages', app.route(
callback=self.messages, skip=['actionsmap']) "/messages",
name="messages",
callback=self.messages,
skip=["actionsmap"],
)
# Append routes from the actions map # Append routes from the actions map
for (m, p) in self.actionsmap.parser.routes: for (m, p) in self.actionsmap.parser.routes:
@ -281,6 +315,7 @@ class _ActionsMapPlugin(object):
context -- An instance of Route context -- An instance of Route
""" """
def _format(value): def _format(value):
if isinstance(value, list) and len(value) == 1: if isinstance(value, list) and len(value) == 1:
return value[0] return value[0]
@ -311,11 +346,12 @@ class _ActionsMapPlugin(object):
# Process the action # Process the action
return callback((request.method, context.rule), params) return callback((request.method, context.rule), params)
return wrapper return wrapper
# Routes callbacks # Routes callbacks
def login(self, password, profile='default'): def login(self, password, profile="default"):
"""Log in to an authenticator profile """Log in to an authenticator profile
Attempt to authenticate to a given authenticator profile and Attempt to authenticate to a given authenticator profile and
@ -328,14 +364,13 @@ class _ActionsMapPlugin(object):
""" """
# Retrieve session values # Retrieve session values
s_id = request.get_cookie('session.id') or random_ascii() s_id = request.get_cookie("session.id") or random_ascii()
try: try:
s_secret = self.secrets[s_id] s_secret = self.secrets[s_id]
except KeyError: except KeyError:
s_tokens = {} s_tokens = {}
else: else:
s_tokens = request.get_cookie('session.tokens', s_tokens = request.get_cookie("session.tokens", secret=s_secret) or {}
secret=s_secret) or {}
s_new_token = random_ascii() s_new_token = random_ascii()
try: try:
@ -354,10 +389,11 @@ class _ActionsMapPlugin(object):
s_tokens[profile] = s_new_token s_tokens[profile] = s_new_token
self.secrets[s_id] = s_secret = random_ascii() self.secrets[s_id] = s_secret = random_ascii()
response.set_cookie('session.id', s_id, secure=True) response.set_cookie("session.id", s_id, secure=True)
response.set_cookie('session.tokens', s_tokens, secure=True, response.set_cookie(
secret=s_secret) "session.tokens", s_tokens, secure=True, secret=s_secret
return m18n.g('logged_in') )
return m18n.g("logged_in")
def logout(self, profile): def logout(self, profile):
"""Log out from an authenticator profile """Log out from an authenticator profile
@ -369,24 +405,23 @@ class _ActionsMapPlugin(object):
- profile -- The authenticator profile name to log out - profile -- The authenticator profile name to log out
""" """
s_id = request.get_cookie('session.id') s_id = request.get_cookie("session.id")
try: try:
# We check that there's a (signed) session.hash available # We check that there's a (signed) session.hash available
# for additional security ? # for additional security ?
# (An attacker could not craft such a signed hash? (FIXME : need to make sure of this)) # (An attacker could not craft such a signed hash? (FIXME : need to make sure of this))
s_secret = self.secrets[s_id] s_secret = self.secrets[s_id]
request.get_cookie('session.tokens', request.get_cookie("session.tokens", secret=s_secret, default={})[profile]
secret=s_secret, default={})[profile]
except KeyError: except KeyError:
raise HTTPUnauthorizedResponse(m18n.g('not_logged_in')) raise HTTPUnauthorizedResponse(m18n.g("not_logged_in"))
else: else:
del self.secrets[s_id] del self.secrets[s_id]
authenticator = self.actionsmap.get_authenticator_for_profile(profile) authenticator = self.actionsmap.get_authenticator_for_profile(profile)
authenticator._clean_session(s_id) authenticator._clean_session(s_id)
# TODO: Clean the session for profile only # TODO: Clean the session for profile only
# Delete cookie and clean the session # Delete cookie and clean the session
response.set_cookie('session.tokens', '', max_age=-1) response.set_cookie("session.tokens", "", max_age=-1)
return m18n.g('logged_out') return m18n.g("logged_out")
def messages(self): def messages(self):
"""Listen to the messages WebSocket stream """Listen to the messages WebSocket stream
@ -396,7 +431,7 @@ class _ActionsMapPlugin(object):
dict { style: message }. dict { style: message }.
""" """
s_id = request.get_cookie('session.id') s_id = request.get_cookie("session.id")
try: try:
queue = self.log_queues[s_id] queue = self.log_queues[s_id]
except KeyError: except KeyError:
@ -404,9 +439,9 @@ class _ActionsMapPlugin(object):
queue = Queue() queue = Queue()
self.log_queues[s_id] = queue self.log_queues[s_id] = queue
wsock = request.environ.get('wsgi.websocket') wsock = request.environ.get("wsgi.websocket")
if not wsock: if not wsock:
raise HTTPErrorResponse(m18n.g('websocket_request_expected')) raise HTTPErrorResponse(m18n.g("websocket_request_expected"))
while True: while True:
item = queue.get() item = queue.get()
@ -447,17 +482,16 @@ class _ActionsMapPlugin(object):
if isinstance(e, HTTPResponse): if isinstance(e, HTTPResponse):
raise e raise e
import traceback import traceback
tb = traceback.format_exc() tb = traceback.format_exc()
logs = {"route": _route, logs = {"route": _route, "arguments": arguments, "traceback": tb}
"arguments": arguments,
"traceback": tb}
return HTTPErrorResponse(json_encode(logs)) return HTTPErrorResponse(json_encode(logs))
else: else:
return format_for_response(ret) return format_for_response(ret)
finally: finally:
# Close opened WebSocket by putting StopIteration in the queue # Close opened WebSocket by putting StopIteration in the queue
try: try:
queue = self.log_queues[request.get_cookie('session.id')] queue = self.log_queues[request.get_cookie("session.id")]
except KeyError: except KeyError:
pass pass
else: else:
@ -471,13 +505,14 @@ class _ActionsMapPlugin(object):
Handle the core.MoulinetteSignals.authenticate signal. Handle the core.MoulinetteSignals.authenticate signal.
""" """
s_id = request.get_cookie('session.id') s_id = request.get_cookie("session.id")
try: try:
s_secret = self.secrets[s_id] s_secret = self.secrets[s_id]
s_token = request.get_cookie('session.tokens', s_token = request.get_cookie("session.tokens", secret=s_secret, default={})[
secret=s_secret, default={})[authenticator.name] authenticator.name
]
except KeyError: except KeyError:
msg = m18n.g('authentication_required') msg = m18n.g("authentication_required")
raise HTTPUnauthorizedResponse(msg) raise HTTPUnauthorizedResponse(msg)
else: else:
return authenticator(token=(s_id, s_token)) return authenticator(token=(s_id, s_token))
@ -488,7 +523,7 @@ class _ActionsMapPlugin(object):
Handle the core.MoulinetteSignals.display signal. Handle the core.MoulinetteSignals.display signal.
""" """
s_id = request.get_cookie('session.id') s_id = request.get_cookie("session.id")
try: try:
queue = self.log_queues[s_id] queue = self.log_queues[s_id]
except KeyError: except KeyError:
@ -504,50 +539,48 @@ class _ActionsMapPlugin(object):
# HTTP Responses ------------------------------------------------------- # HTTP Responses -------------------------------------------------------
class HTTPOKResponse(HTTPResponse):
def __init__(self, output=''): class HTTPOKResponse(HTTPResponse):
def __init__(self, output=""):
super(HTTPOKResponse, self).__init__(output, 200) super(HTTPOKResponse, self).__init__(output, 200)
class HTTPBadRequestResponse(HTTPResponse): class HTTPBadRequestResponse(HTTPResponse):
def __init__(self, output=""):
def __init__(self, output=''):
super(HTTPBadRequestResponse, self).__init__(output, 400) super(HTTPBadRequestResponse, self).__init__(output, 400)
class HTTPUnauthorizedResponse(HTTPResponse): class HTTPUnauthorizedResponse(HTTPResponse):
def __init__(self, output=""):
def __init__(self, output=''):
super(HTTPUnauthorizedResponse, self).__init__(output, 401) super(HTTPUnauthorizedResponse, self).__init__(output, 401)
class HTTPErrorResponse(HTTPResponse): class HTTPErrorResponse(HTTPResponse):
def __init__(self, output=""):
def __init__(self, output=''):
super(HTTPErrorResponse, self).__init__(output, 500) super(HTTPErrorResponse, self).__init__(output, 500)
def format_for_response(content): def format_for_response(content):
"""Format the resulted content of a request for the HTTP response.""" """Format the resulted content of a request for the HTTP response."""
if request.method == 'POST': if request.method == "POST":
response.status = 201 # Created response.status = 201 # Created
elif request.method == 'GET': elif request.method == "GET":
response.status = 200 # Ok response.status = 200 # Ok
else: else:
# Return empty string if no content # Return empty string if no content
if content is None or len(content) == 0: if content is None or len(content) == 0:
response.status = 204 # No Content response.status = 204 # No Content
return '' return ""
response.status = 200 response.status = 200
# Return JSON-style response # Return JSON-style response
response.content_type = 'application/json' response.content_type = "application/json"
return json_encode(content, cls=JSONExtendedEncoder) return json_encode(content, cls=JSONExtendedEncoder)
# API Classes Implementation ------------------------------------------- # API Classes Implementation -------------------------------------------
class ActionsMapParser(BaseActionsMapParser): class ActionsMapParser(BaseActionsMapParser):
"""Actions map's Parser for the API """Actions map's Parser for the API
@ -561,7 +594,7 @@ class ActionsMapParser(BaseActionsMapParser):
super(ActionsMapParser, self).__init__(parent) super(ActionsMapParser, self).__init__(parent)
self._parsers = {} # dict({(method, path): _HTTPArgumentParser}) self._parsers = {} # dict({(method, path): _HTTPArgumentParser})
self._route_re = re.compile(r'(GET|POST|PUT|DELETE) (/\S+)') self._route_re = re.compile(r"(GET|POST|PUT|DELETE) (/\S+)")
@property @property
def routes(self): def routes(self):
@ -570,19 +603,19 @@ class ActionsMapParser(BaseActionsMapParser):
# Implement virtual properties # Implement virtual properties
interface = 'api' interface = "api"
# Implement virtual methods # Implement virtual methods
@staticmethod @staticmethod
def format_arg_names(name, full): def format_arg_names(name, full):
if name[0] != '-': if name[0] != "-":
return [name] return [name]
if full: if full:
return [full.replace('--', '@', 1)] return [full.replace("--", "@", 1)]
if name.startswith('--'): if name.startswith("--"):
return [name.replace('--', '@', 1)] return [name.replace("--", "@", 1)]
return [name.replace('-', '@', 1)] return [name.replace("-", "@", 1)]
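# --- editor's example (not part of this commit) ------------------------------
# The API parser rewrites CLI-style option strings into '@'-prefixed names
# (matching the '@' prefix_chars of _HTTPArgumentParser above); positionals
# are kept as-is. These calls only exercise the static method shown here.
from moulinette.interfaces.api import ActionsMapParser

assert ActionsMapParser.format_arg_names("name", None) == ["name"]
assert ActionsMapParser.format_arg_names("-f", "--force") == ["@force"]
assert ActionsMapParser.format_arg_names("--force", None) == ["@force"]
assert ActionsMapParser.format_arg_names("-f", None) == ["@f"]
# ------------------------------------------------------------------------------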
def add_category_parser(self, name, **kwargs): def add_category_parser(self, name, **kwargs):
return self return self
@ -611,8 +644,9 @@ class ActionsMapParser(BaseActionsMapParser):
try: try:
keys.append(self._extract_route(r)) keys.append(self._extract_route(r))
except ValueError as e: except ValueError as e:
logger.warning("cannot add api route '%s' for " logger.warning(
"action %s: %s", r, tid, e) "cannot add api route '%s' for " "action %s: %s", r, tid, e
)
continue continue
if len(keys) == 0: if len(keys) == 0:
raise ValueError("no valid api route found") raise ValueError("no valid api route found")
@ -631,7 +665,7 @@ class ActionsMapParser(BaseActionsMapParser):
try: try:
# Retrieve the tid for the route # Retrieve the tid for the route
tid, _ = self._parsers[route] tid, _ = self._parsers[route]
if not self.get_conf(tid, 'authenticate'): if not self.get_conf(tid, "authenticate"):
return False return False
else: else:
# TODO: In the future, we could make the authentication # TODO: In the future, we could make the authentication
@ -640,10 +674,10 @@ class ActionsMapParser(BaseActionsMapParser):
# auth with some custom auth system to access some # auth with some custom auth system to access some
# data with something like : # data with something like :
# return self.get_conf(tid, 'authenticator') # return self.get_conf(tid, 'authenticator')
return 'default' return "default"
except KeyError: except KeyError:
logger.error("no argument parser found for route '%s'", route) logger.error("no argument parser found for route '%s'", route)
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
def parse_args(self, args, route, **kwargs): def parse_args(self, args, route, **kwargs):
"""Parse arguments """Parse arguments
@ -657,7 +691,7 @@ class ActionsMapParser(BaseActionsMapParser):
_, parser = self._parsers[route] _, parser = self._parsers[route]
except KeyError: except KeyError:
logger.error("no argument parser found for route '%s'", route) logger.error("no argument parser found for route '%s'", route)
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
ret = argparse.Namespace() ret = argparse.Namespace()
# TODO: Catch errors? # TODO: Catch errors?
@ -705,8 +739,7 @@ class Interface(BaseInterface):
""" """
def __init__(self, actionsmap, routes={}, use_websocket=True, def __init__(self, actionsmap, routes={}, use_websocket=True, log_queues=None):
log_queues=None):
self.use_websocket = use_websocket self.use_websocket = use_websocket
# Attempt to retrieve log queues from an APIQueueHandler # Attempt to retrieve log queues from an APIQueueHandler
@ -721,14 +754,15 @@ class Interface(BaseInterface):
# Wrapper which sets proper header # Wrapper which sets proper header
def apiheader(callback): def apiheader(callback):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
response.set_header('Access-Control-Allow-Origin', '*') response.set_header("Access-Control-Allow-Origin", "*")
return callback(*args, **kwargs) return callback(*args, **kwargs)
return wrapper return wrapper
# Attempt to retrieve and set locale # Attempt to retrieve and set locale
def api18n(callback): def api18n(callback):
try: try:
locale = request.params.pop('locale') locale = request.params.pop("locale")
except KeyError: except KeyError:
locale = m18n.default_locale locale = m18n.default_locale
m18n.set_locale(locale) m18n.set_locale(locale)
@ -741,17 +775,17 @@ class Interface(BaseInterface):
app.install(_ActionsMapPlugin(actionsmap, use_websocket, log_queues)) app.install(_ActionsMapPlugin(actionsmap, use_websocket, log_queues))
# Append default routes # Append default routes
# app.route(['/api', '/api/<category:re:[a-z]+>'], method='GET', # app.route(['/api', '/api/<category:re:[a-z]+>'], method='GET',
# callback=self.doc, skip=['actionsmap']) # callback=self.doc, skip=['actionsmap'])
# Append additional routes # Append additional routes
# TODO: Add optional authentication to those routes? # TODO: Add optional authentication to those routes?
for (m, p), c in routes.items(): for (m, p), c in routes.items():
app.route(p, method=m, callback=c, skip=['actionsmap']) app.route(p, method=m, callback=c, skip=["actionsmap"])
self._app = app self._app = app
def run(self, host='localhost', port=80): def run(self, host="localhost", port=80):
"""Run the moulinette """Run the moulinette
Start a server instance on the given port to serve moulinette Start a server instance on the given port to serve moulinette
@ -762,25 +796,29 @@ class Interface(BaseInterface):
- port -- Server port to bind to - port -- Server port to bind to
""" """
logger.debug("starting the server instance in %s:%d with websocket=%s", logger.debug(
host, port, self.use_websocket) "starting the server instance in %s:%d with websocket=%s",
host,
port,
self.use_websocket,
)
try: try:
if self.use_websocket: if self.use_websocket:
from gevent.pywsgi import WSGIServer from gevent.pywsgi import WSGIServer
from geventwebsocket.handler import WebSocketHandler from geventwebsocket.handler import WebSocketHandler
server = WSGIServer((host, port), self._app, server = WSGIServer(
handler_class=WebSocketHandler) (host, port), self._app, handler_class=WebSocketHandler
)
server.serve_forever() server.serve_forever()
else: else:
run(self._app, host=host, port=port) run(self._app, host=host, port=port)
except IOError as e: except IOError as e:
logger.exception("unable to start the server instance on %s:%d", logger.exception("unable to start the server instance on %s:%d", host, port)
host, port)
if e.args[0] == errno.EADDRINUSE: if e.args[0] == errno.EADDRINUSE:
raise MoulinetteError('server_already_running') raise MoulinetteError("server_already_running")
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
# Routes handlers # Routes handlers
@ -792,14 +830,14 @@ class Interface(BaseInterface):
category -- Name of the category category -- Name of the category
""" """
DATA_DIR = env()['DATA_DIR'] DATA_DIR = env()["DATA_DIR"]
if category is None: if category is None:
with open('%s/../doc/resources.json' % DATA_DIR) as f: with open("%s/../doc/resources.json" % DATA_DIR) as f:
return f.read() return f.read()
try: try:
with open('%s/../doc/%s.json' % (DATA_DIR, category)) as f: with open("%s/../doc/%s.json" % (DATA_DIR, category)) as f:
return f.read() return f.read()
except IOError: except IOError:
return None return None

View file

@ -15,27 +15,29 @@ import argcomplete
from moulinette import msignals, m18n from moulinette import msignals, m18n
from moulinette.core import MoulinetteError from moulinette.core import MoulinetteError
from moulinette.interfaces import ( from moulinette.interfaces import (
BaseActionsMapParser, BaseInterface, ExtendedArgumentParser, BaseActionsMapParser,
BaseInterface,
ExtendedArgumentParser,
) )
from moulinette.utils import log from moulinette.utils import log
logger = log.getLogger('moulinette.cli') logger = log.getLogger("moulinette.cli")
# CLI helpers ---------------------------------------------------------- # CLI helpers ----------------------------------------------------------
CLI_COLOR_TEMPLATE = '\033[{:d}m\033[1m' CLI_COLOR_TEMPLATE = "\033[{:d}m\033[1m"
END_CLI_COLOR = '\033[m' END_CLI_COLOR = "\033[m"
colors_codes = { colors_codes = {
'red': CLI_COLOR_TEMPLATE.format(31), "red": CLI_COLOR_TEMPLATE.format(31),
'green': CLI_COLOR_TEMPLATE.format(32), "green": CLI_COLOR_TEMPLATE.format(32),
'yellow': CLI_COLOR_TEMPLATE.format(33), "yellow": CLI_COLOR_TEMPLATE.format(33),
'blue': CLI_COLOR_TEMPLATE.format(34), "blue": CLI_COLOR_TEMPLATE.format(34),
'purple': CLI_COLOR_TEMPLATE.format(35), "purple": CLI_COLOR_TEMPLATE.format(35),
'cyan': CLI_COLOR_TEMPLATE.format(36), "cyan": CLI_COLOR_TEMPLATE.format(36),
'white': CLI_COLOR_TEMPLATE.format(37), "white": CLI_COLOR_TEMPLATE.format(37),
} }
@ -50,7 +52,7 @@ def colorize(astr, color):
""" """
if os.isatty(1): if os.isatty(1):
return '{:s}{:s}{:s}'.format(colors_codes[color], astr, END_CLI_COLOR) return "{:s}{:s}{:s}".format(colors_codes[color], astr, END_CLI_COLOR)
else: else:
return astr return astr
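# --- editor's example (not part of this commit) ------------------------------
# colorize() only wraps the string in ANSI codes when stdout is a tty, so
# piped or redirected output stays plain; color names are keys of colors_codes.
from moulinette.interfaces.cli import colorize

print(colorize("Success!", "green"))  # ANSI-colored on a terminal, plain text when redirected
# ------------------------------------------------------------------------------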
@ -91,7 +93,7 @@ def plain_print_dict(d, depth=0):
plain_print_dict(v, depth + 1) plain_print_dict(v, depth + 1)
else: else:
if isinstance(d, unicode): if isinstance(d, unicode):
d = d.encode('utf-8') d = d.encode("utf-8")
print(d) print(d)
@ -107,7 +109,7 @@ def pretty_date(_date):
nowtz = nowtz.replace(tzinfo=pytz.utc) nowtz = nowtz.replace(tzinfo=pytz.utc)
offsetHour = nowutc - nowtz offsetHour = nowutc - nowtz
offsetHour = int(round(offsetHour.total_seconds() / 3600)) offsetHour = int(round(offsetHour.total_seconds() / 3600))
localtz = 'Etc/GMT%+d' % offsetHour localtz = "Etc/GMT%+d" % offsetHour
# Transform naive date into UTC date # Transform naive date into UTC date
if _date.tzinfo is None: if _date.tzinfo is None:
@ -136,7 +138,7 @@ def pretty_print_dict(d, depth=0):
keys = sorted(keys) keys = sorted(keys)
for k in keys: for k in keys:
v = d[k] v = d[k]
k = colorize(str(k), 'purple') k = colorize(str(k), "purple")
if isinstance(v, (tuple, set)): if isinstance(v, (tuple, set)):
v = list(v) v = list(v)
if isinstance(v, list) and len(v) == 1: if isinstance(v, list) and len(v) == 1:
@ -153,13 +155,13 @@ def pretty_print_dict(d, depth=0):
pretty_print_dict({key: value}, depth + 1) pretty_print_dict({key: value}, depth + 1)
else: else:
if isinstance(value, unicode): if isinstance(value, unicode):
value = value.encode('utf-8') value = value.encode("utf-8")
elif isinstance(v, date): elif isinstance(v, date):
v = pretty_date(v) v = pretty_date(v)
print("{:s}- {}".format(" " * (depth + 1), value)) print("{:s}- {}".format(" " * (depth + 1), value))
else: else:
if isinstance(v, unicode): if isinstance(v, unicode):
v = v.encode('utf-8') v = v.encode("utf-8")
elif isinstance(v, date): elif isinstance(v, date):
v = pretty_date(v) v = pretty_date(v)
print("{:s}{}: {}".format(" " * depth, k, v)) print("{:s}{}: {}".format(" " * depth, k, v))
@ -169,12 +171,13 @@ def get_locale():
"""Return current user locale""" """Return current user locale"""
lang = locale.getdefaultlocale()[0] lang = locale.getdefaultlocale()[0]
if not lang: if not lang:
return '' return ""
return lang[:2] return lang[:2]
# CLI Classes Implementation ------------------------------------------- # CLI Classes Implementation -------------------------------------------
class TTYHandler(logging.StreamHandler): class TTYHandler(logging.StreamHandler):
"""TTY log handler """TTY log handler
@ -192,17 +195,18 @@ class TTYHandler(logging.StreamHandler):
stderr. Otherwise, they are sent to stdout. stderr. Otherwise, they are sent to stdout.
""" """
LEVELS_COLOR = { LEVELS_COLOR = {
log.NOTSET: 'white', log.NOTSET: "white",
log.DEBUG: 'white', log.DEBUG: "white",
log.INFO: 'cyan', log.INFO: "cyan",
log.SUCCESS: 'green', log.SUCCESS: "green",
log.WARNING: 'yellow', log.WARNING: "yellow",
log.ERROR: 'red', log.ERROR: "red",
log.CRITICAL: 'red', log.CRITICAL: "red",
} }
def __init__(self, message_key='fmessage'): def __init__(self, message_key="fmessage"):
logging.StreamHandler.__init__(self) logging.StreamHandler.__init__(self)
self.message_key = message_key self.message_key = message_key
@ -210,16 +214,15 @@ class TTYHandler(logging.StreamHandler):
"""Enhance message with level and colors if supported.""" """Enhance message with level and colors if supported."""
msg = record.getMessage() msg = record.getMessage()
if self.supports_color(): if self.supports_color():
level = '' level = ""
if self.level <= log.DEBUG: if self.level <= log.DEBUG:
# add level name before message # add level name before message
level = '%s ' % record.levelname level = "%s " % record.levelname
elif record.levelname in ['SUCCESS', 'WARNING', 'ERROR', 'INFO']: elif record.levelname in ["SUCCESS", "WARNING", "ERROR", "INFO"]:
# add translated level name before message # add translated level name before message
level = '%s ' % m18n.g(record.levelname.lower()) level = "%s " % m18n.g(record.levelname.lower())
color = self.LEVELS_COLOR.get(record.levelno, 'white') color = self.LEVELS_COLOR.get(record.levelno, "white")
msg = '{0}{1}{2}{3}'.format( msg = "{0}{1}{2}{3}".format(colors_codes[color], level, END_CLI_COLOR, msg)
colors_codes[color], level, END_CLI_COLOR, msg)
if self.formatter: if self.formatter:
# use user-defined formatter # use user-defined formatter
record.__dict__[self.message_key] = msg record.__dict__[self.message_key] = msg
@ -236,7 +239,7 @@ class TTYHandler(logging.StreamHandler):
def supports_color(self): def supports_color(self):
"""Check whether current stream supports color.""" """Check whether current stream supports color."""
if hasattr(self.stream, 'isatty') and self.stream.isatty(): if hasattr(self.stream, "isatty") and self.stream.isatty():
return True return True
return False return False
@ -256,12 +259,13 @@ class ActionsMapParser(BaseActionsMapParser):
""" """
def __init__(self, parent=None, parser=None, subparser_kwargs=None, def __init__(
top_parser=None, **kwargs): self, parent=None, parser=None, subparser_kwargs=None, top_parser=None, **kwargs
):
super(ActionsMapParser, self).__init__(parent) super(ActionsMapParser, self).__init__(parent)
if subparser_kwargs is None: if subparser_kwargs is None:
subparser_kwargs = {'title': "categories", 'required': False} subparser_kwargs = {"title": "categories", "required": False}
self._parser = parser or ExtendedArgumentParser() self._parser = parser or ExtendedArgumentParser()
self._subparsers = self._parser.add_subparsers(**subparser_kwargs) self._subparsers = self._parser.add_subparsers(**subparser_kwargs)
@ -277,13 +281,13 @@ class ActionsMapParser(BaseActionsMapParser):
# Implement virtual properties # Implement virtual properties
interface = 'cli' interface = "cli"
# Implement virtual methods # Implement virtual methods
@staticmethod @staticmethod
def format_arg_names(name, full): def format_arg_names(name, full):
if name[0] == '-' and full: if name[0] == "-" and full:
return [name, full] return [name, full]
return [name] return [name]
@ -300,13 +304,10 @@ class ActionsMapParser(BaseActionsMapParser):
A new ActionsMapParser object for the category A new ActionsMapParser object for the category
""" """
parser = self._subparsers.add_parser(name, parser = self._subparsers.add_parser(
description=category_help, name, description=category_help, help=category_help, **kwargs
help=category_help, )
**kwargs) return self.__class__(self, parser, {"title": "subcommands", "required": True})
return self.__class__(self, parser, {
'title': "subcommands", 'required': True
})
def add_subcategory_parser(self, name, subcategory_help=None, **kwargs): def add_subcategory_parser(self, name, subcategory_help=None, **kwargs):
"""Add a parser for a subcategory """Add a parser for a subcategory
@ -318,17 +319,24 @@ class ActionsMapParser(BaseActionsMapParser):
A new ActionsMapParser object for the category A new ActionsMapParser object for the category
""" """
parser = self._subparsers.add_parser(name, parser = self._subparsers.add_parser(
type_="subcategory", name,
description=subcategory_help, type_="subcategory",
help=subcategory_help, description=subcategory_help,
**kwargs) help=subcategory_help,
return self.__class__(self, parser, { **kwargs
'title': "actions", 'required': True )
}) return self.__class__(self, parser, {"title": "actions", "required": True})
def add_action_parser(self, name, tid, action_help=None, deprecated=False, def add_action_parser(
deprecated_alias=[], **kwargs): self,
name,
tid,
action_help=None,
deprecated=False,
deprecated_alias=[],
**kwargs
):
"""Add a parser for an action """Add a parser for an action
Keyword arguments: Keyword arguments:
@ -340,18 +348,21 @@ class ActionsMapParser(BaseActionsMapParser):
A new ExtendedArgumentParser object for the action A new ExtendedArgumentParser object for the action
""" """
return self._subparsers.add_parser(name, return self._subparsers.add_parser(
type_="action", name,
help=action_help, type_="action",
description=action_help, help=action_help,
deprecated=deprecated, description=action_help,
deprecated_alias=deprecated_alias) deprecated=deprecated,
deprecated_alias=deprecated_alias,
)
def add_global_arguments(self, arguments): def add_global_arguments(self, arguments):
for argument_name, argument_options in arguments.items(): for argument_name, argument_options in arguments.items():
# will adapt arguments name for cli or api context # will adapt arguments name for cli or api context
names = self.format_arg_names(str(argument_name), names = self.format_arg_names(
argument_options.pop('full', None)) str(argument_name), argument_options.pop("full", None)
)
self.global_parser.add_argument(*names, **argument_options) self.global_parser.add_argument(*names, **argument_options)
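Black only re-wraps the parser-building methods above; their signatures are unchanged. A hedged sketch of assembling a small category/action tree by hand: the "user"/"list" names and the --fields option are invented for illustration, and the real tree is normally built by ActionsMap from the actions map file, with authentication and configuration plumbing this sketch ignores:

top = ActionsMapParser()
user = top.add_category_parser("user", category_help="Manage users")
user_list = user.add_action_parser(
    "list", tid=("moulitest", "user", "list"), action_help="List users"
)
user_list.add_argument("--fields", nargs="+", help="fields to display")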
@ -363,12 +374,12 @@ class ActionsMapParser(BaseActionsMapParser):
except SystemExit: except SystemExit:
raise raise
except: except:
logger.exception("unable to parse arguments '%s'", ' '.join(args)) logger.exception("unable to parse arguments '%s'", " ".join(args))
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
tid = getattr(ret, '_tid', None) tid = getattr(ret, "_tid", None)
if self.get_conf(tid, 'authenticate'): if self.get_conf(tid, "authenticate"):
return self.get_conf(tid, 'authenticator') return self.get_conf(tid, "authenticator")
else: else:
return False return False
@ -378,10 +389,10 @@ class ActionsMapParser(BaseActionsMapParser):
except SystemExit: except SystemExit:
raise raise
except: except:
logger.exception("unable to parse arguments '%s'", ' '.join(args)) logger.exception("unable to parse arguments '%s'", " ".join(args))
raise MoulinetteError('error_see_log') raise MoulinetteError("error_see_log")
else: else:
self.prepare_action_namespace(getattr(ret, '_tid', None), ret) self.prepare_action_namespace(getattr(ret, "_tid", None), ret)
self._parser.dequeue_callbacks(ret) self._parser.dequeue_callbacks(ret)
return ret return ret
@ -403,10 +414,10 @@ class Interface(BaseInterface):
m18n.set_locale(get_locale()) m18n.set_locale(get_locale())
# Connect signals to handlers # Connect signals to handlers
msignals.set_handler('display', self._do_display) msignals.set_handler("display", self._do_display)
if os.isatty(1): if os.isatty(1):
msignals.set_handler('authenticate', self._do_authenticate) msignals.set_handler("authenticate", self._do_authenticate)
msignals.set_handler('prompt', self._do_prompt) msignals.set_handler("prompt", self._do_prompt)
self.actionsmap = actionsmap self.actionsmap = actionsmap
@ -426,30 +437,30 @@ class Interface(BaseInterface):
- timeout -- Number of seconds before this command will timeout because it can't acquire the lock (meaning that another command is currently running), by default there is no timeout and the command will wait until it can get the lock - timeout -- Number of seconds before this command will timeout because it can't acquire the lock (meaning that another command is currently running), by default there is no timeout and the command will wait until it can get the lock
""" """
if output_as and output_as not in ['json', 'plain', 'none']: if output_as and output_as not in ["json", "plain", "none"]:
raise MoulinetteError('invalid_usage') raise MoulinetteError("invalid_usage")
# auto-complete # auto-complete
argcomplete.autocomplete(self.actionsmap.parser._parser) argcomplete.autocomplete(self.actionsmap.parser._parser)
# Set handler for authentication # Set handler for authentication
if password: if password:
msignals.set_handler('authenticate', msignals.set_handler("authenticate", lambda a: a(password=password))
lambda a: a(password=password))
try: try:
ret = self.actionsmap.process(args, timeout=timeout) ret = self.actionsmap.process(args, timeout=timeout)
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
raise MoulinetteError('operation_interrupted') raise MoulinetteError("operation_interrupted")
if ret is None or output_as == 'none': if ret is None or output_as == "none":
return return
# Format and print result # Format and print result
if output_as: if output_as:
if output_as == 'json': if output_as == "json":
import json import json
from moulinette.utils.serialize import JSONExtendedEncoder from moulinette.utils.serialize import JSONExtendedEncoder
print(json.dumps(ret, cls=JSONExtendedEncoder)) print(json.dumps(ret, cls=JSONExtendedEncoder))
else: else:
plain_print_dict(ret) plain_print_dict(ret)
@ -468,11 +479,10 @@ class Interface(BaseInterface):
""" """
# TODO: Allow token authentication? # TODO: Allow token authentication?
help = authenticator.extra.get("help") help = authenticator.extra.get("help")
msg = m18n.n(help) if help else m18n.g('password') msg = m18n.n(help) if help else m18n.g("password")
return authenticator(password=self._do_prompt(msg, True, False, return authenticator(password=self._do_prompt(msg, True, False, color="yellow"))
color='yellow'))
def _do_prompt(self, message, is_password, confirm, color='blue'): def _do_prompt(self, message, is_password, confirm, color="blue"):
"""Prompt for a value """Prompt for a value
Handle the core.MoulinetteSignals.prompt signal. Handle the core.MoulinetteSignals.prompt signal.
@ -482,16 +492,15 @@ class Interface(BaseInterface):
""" """
if is_password: if is_password:
prompt = lambda m: getpass.getpass(colorize(m18n.g('colon', m), prompt = lambda m: getpass.getpass(colorize(m18n.g("colon", m), color))
color))
else: else:
prompt = lambda m: raw_input(colorize(m18n.g('colon', m), color)) prompt = lambda m: raw_input(colorize(m18n.g("colon", m), color))
value = prompt(message) value = prompt(message)
if confirm: if confirm:
m = message[0].lower() + message[1:] m = message[0].lower() + message[1:]
if prompt(m18n.g('confirm', prompt=m)) != value: if prompt(m18n.g("confirm", prompt=m)) != value:
raise MoulinetteError('values_mismatch') raise MoulinetteError("values_mismatch")
return value return value
@ -502,12 +511,12 @@ class Interface(BaseInterface):
""" """
if isinstance(message, unicode): if isinstance(message, unicode):
message = message.encode('utf-8') message = message.encode("utf-8")
if style == 'success': if style == "success":
print('{} {}'.format(colorize(m18n.g('success'), 'green'), message)) print("{} {}".format(colorize(m18n.g("success"), "green"), message))
elif style == 'warning': elif style == "warning":
print('{} {}'.format(colorize(m18n.g('warning'), 'yellow'), message)) print("{} {}".format(colorize(m18n.g("warning"), "yellow"), message))
elif style == 'error': elif style == "error":
print('{} {}'.format(colorize(m18n.g('error'), 'red'), message)) print("{} {}".format(colorize(m18n.g("error"), "red"), message))
else: else:
print(message) print(message)
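TTYHandler above is what gives the CLI its colorized log output. A minimal sketch of wiring it up by hand (the test suite's patch_logging further below does the same thing declaratively through dictConfig); it assumes moulinette.init() has already loaded the translations used for the level names:

import logging

from moulinette.utils import log
from moulinette.interfaces.cli import TTYHandler

handler = TTYHandler()              # routes records to stdout or stderr depending on level
handler.setLevel(log.INFO)

logger = logging.getLogger("moulinette.cli")
logger.setLevel(log.DEBUG)
logger.addHandler(handler)
logger.warning("rendered in yellow when the stream is a TTY")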

View file

@ -22,21 +22,25 @@ def read_file(file_path):
Keyword argument: Keyword argument:
file_path -- Path to the text file file_path -- Path to the text file
""" """
assert isinstance(file_path, basestring), "Error: file_path '%s' should be a string but is of type '%s' instead" % (file_path, type(file_path)) assert isinstance(file_path, basestring), (
"Error: file_path '%s' should be a string but is of type '%s' instead"
% (file_path, type(file_path))
)
# Check file exists # Check file exists
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
raise MoulinetteError('file_not_exist', path=file_path) raise MoulinetteError("file_not_exist", path=file_path)
# Open file and read content # Open file and read content
try: try:
with open(file_path, "r") as f: with open(file_path, "r") as f:
file_content = f.read() file_content = f.read()
except IOError as e: except IOError as e:
raise MoulinetteError('cannot_open_file', file=file_path, error=str(e)) raise MoulinetteError("cannot_open_file", file=file_path, error=str(e))
except Exception: except Exception:
raise MoulinetteError('unknown_error_reading_file', raise MoulinetteError(
file=file_path, error=str(e)) "unknown_error_reading_file", file=file_path, error=str(e)
)
return file_content return file_content
@ -56,7 +60,7 @@ def read_json(file_path):
try: try:
loaded_json = json.loads(file_content) loaded_json = json.loads(file_content)
except ValueError as e: except ValueError as e:
raise MoulinetteError('corrupted_json', ressource=file_path, error=str(e)) raise MoulinetteError("corrupted_json", ressource=file_path, error=str(e))
return loaded_json return loaded_json
@ -76,7 +80,7 @@ def read_yaml(file_path):
try: try:
loaded_yaml = yaml.safe_load(file_content) loaded_yaml = yaml.safe_load(file_content)
except Exception as e: except Exception as e:
raise MoulinetteError('corrupted_yaml', ressource=file_path, error=str(e)) raise MoulinetteError("corrupted_yaml", ressource=file_path, error=str(e))
return loaded_yaml return loaded_yaml
@ -96,9 +100,9 @@ def read_toml(file_path):
try: try:
loaded_toml = toml.loads(file_content, _dict=OrderedDict) loaded_toml = toml.loads(file_content, _dict=OrderedDict)
except Exception as e: except Exception as e:
raise MoulinetteError(errno.EINVAL, raise MoulinetteError(
m18n.g('corrupted_toml', errno.EINVAL, m18n.g("corrupted_toml", ressource=file_path, error=str(e))
ressource=file_path, error=str(e))) )
return loaded_toml return loaded_toml
@ -129,10 +133,11 @@ def read_ldif(file_path, filtred_entries=[]):
parser = LDIFPar(f) parser = LDIFPar(f)
parser.parse() parser.parse()
except IOError as e: except IOError as e:
raise MoulinetteError('cannot_open_file', file=file_path, error=str(e)) raise MoulinetteError("cannot_open_file", file=file_path, error=str(e))
except Exception as e: except Exception as e:
raise MoulinetteError('unknown_error_reading_file', raise MoulinetteError(
file=file_path, error=str(e)) "unknown_error_reading_file", file=file_path, error=str(e)
)
return parser.all_records return parser.all_records
@ -148,23 +153,34 @@ def write_to_file(file_path, data, file_mode="w"):
file_mode -- Mode used when writing the file. Option meant to be used file_mode -- Mode used when writing the file. Option meant to be used
by append_to_file to avoid duplicating the code of this function. by append_to_file to avoid duplicating the code of this function.
""" """
assert isinstance(data, basestring) or isinstance(data, list), "Error: data '%s' should be either a string or a list but is of type '%s'" % (data, type(data)) assert isinstance(data, basestring) or isinstance(data, list), (
assert not os.path.isdir(file_path), "Error: file_path '%s' point to a dir, it should be a file" % file_path "Error: data '%s' should be either a string or a list but is of type '%s'"
assert os.path.isdir(os.path.dirname(file_path)), "Error: the path ('%s') base dir ('%s') is not a dir" % (file_path, os.path.dirname(file_path)) % (data, type(data))
)
assert not os.path.isdir(file_path), (
"Error: file_path '%s' point to a dir, it should be a file" % file_path
)
assert os.path.isdir(os.path.dirname(file_path)), (
"Error: the path ('%s') base dir ('%s') is not a dir"
% (file_path, os.path.dirname(file_path))
)
# If data is a list, check elements are strings and build a single string # If data is a list, check elements are strings and build a single string
if not isinstance(data, basestring): if not isinstance(data, basestring):
for element in data: for element in data:
assert isinstance(element, basestring), "Error: element '%s' should be a string but is of type '%s' instead" % (element, type(element)) assert isinstance(element, basestring), (
data = '\n'.join(data) "Error: element '%s' should be a string but is of type '%s' instead"
% (element, type(element))
)
data = "\n".join(data)
try: try:
with open(file_path, file_mode) as f: with open(file_path, file_mode) as f:
f.write(data) f.write(data)
except IOError as e: except IOError as e:
raise MoulinetteError('cannot_write_file', file=file_path, error=str(e)) raise MoulinetteError("cannot_write_file", file=file_path, error=str(e))
except Exception as e: except Exception as e:
raise MoulinetteError('error_writing_file', file=file_path, error=str(e)) raise MoulinetteError("error_writing_file", file=file_path, error=str(e))
def append_to_file(file_path, data): def append_to_file(file_path, data):
@ -189,19 +205,30 @@ def write_to_json(file_path, data):
""" """
# Assumptions # Assumptions
assert isinstance(file_path, basestring), "Error: file_path '%s' should be a string but is of type '%s' instead" % (file_path, type(file_path)) assert isinstance(file_path, basestring), (
assert isinstance(data, dict) or isinstance(data, list), "Error: data '%s' should be a dict or a list but is of type '%s' instead" % (data, type(data)) "Error: file_path '%s' should be a string but is of type '%s' instead"
assert not os.path.isdir(file_path), "Error: file_path '%s' point to a dir, it should be a file" % file_path % (file_path, type(file_path))
assert os.path.isdir(os.path.dirname(file_path)), "Error: the path ('%s') base dir ('%s') is not a dir" % (file_path, os.path.dirname(file_path)) )
assert isinstance(data, dict) or isinstance(data, list), (
"Error: data '%s' should be a dict or a list but is of type '%s' instead"
% (data, type(data))
)
assert not os.path.isdir(file_path), (
"Error: file_path '%s' point to a dir, it should be a file" % file_path
)
assert os.path.isdir(os.path.dirname(file_path)), (
"Error: the path ('%s') base dir ('%s') is not a dir"
% (file_path, os.path.dirname(file_path))
)
# Write dict to file # Write dict to file
try: try:
with open(file_path, "w") as f: with open(file_path, "w") as f:
json.dump(data, f) json.dump(data, f)
except IOError as e: except IOError as e:
raise MoulinetteError('cannot_write_file', file=file_path, error=str(e)) raise MoulinetteError("cannot_write_file", file=file_path, error=str(e))
except Exception as e: except Exception as e:
raise MoulinetteError('error_writing_file', file=file_path, error=str(e)) raise MoulinetteError("error_writing_file", file=file_path, error=str(e))
def write_to_yaml(file_path, data): def write_to_yaml(file_path, data):
@ -223,9 +250,9 @@ def write_to_yaml(file_path, data):
with open(file_path, "w") as f: with open(file_path, "w") as f:
yaml.safe_dump(data, f, default_flow_style=False) yaml.safe_dump(data, f, default_flow_style=False)
except IOError as e: except IOError as e:
raise MoulinetteError('cannot_write_file', file=file_path, error=str(e)) raise MoulinetteError("cannot_write_file", file=file_path, error=str(e))
except Exception as e: except Exception as e:
raise MoulinetteError('error_writing_file', file=file_path, error=str(e)) raise MoulinetteError("error_writing_file", file=file_path, error=str(e))
def mkdir(path, mode=0o777, parents=False, uid=None, gid=None, force=False): def mkdir(path, mode=0o777, parents=False, uid=None, gid=None, force=False):
@ -245,7 +272,7 @@ def mkdir(path, mode=0o777, parents=False, uid=None, gid=None, force=False):
""" """
if os.path.exists(path) and not force: if os.path.exists(path) and not force:
raise OSError(errno.EEXIST, m18n.g('folder_exists', path=path)) raise OSError(errno.EEXIST, m18n.g("folder_exists", path=path))
if parents: if parents:
# Create parents directories as needed # Create parents directories as needed
@ -290,14 +317,14 @@ def chown(path, uid=None, gid=None, recursive=False):
try: try:
uid = getpwnam(uid).pw_uid uid = getpwnam(uid).pw_uid
except KeyError: except KeyError:
raise MoulinetteError('unknown_user', user=uid) raise MoulinetteError("unknown_user", user=uid)
elif uid is None: elif uid is None:
uid = -1 uid = -1
if isinstance(gid, basestring): if isinstance(gid, basestring):
try: try:
gid = grp.getgrnam(gid).gr_gid gid = grp.getgrnam(gid).gr_gid
except KeyError: except KeyError:
raise MoulinetteError('unknown_group', group=gid) raise MoulinetteError("unknown_group", group=gid)
elif gid is None: elif gid is None:
gid = -1 gid = -1
@ -310,7 +337,9 @@ def chown(path, uid=None, gid=None, recursive=False):
for f in files: for f in files:
os.chown(os.path.join(root, f), uid, gid) os.chown(os.path.join(root, f), uid, gid)
except Exception as e: except Exception as e:
raise MoulinetteError('error_changing_file_permissions', path=path, error=str(e)) raise MoulinetteError(
"error_changing_file_permissions", path=path, error=str(e)
)
def chmod(path, mode, fmode=None, recursive=False): def chmod(path, mode, fmode=None, recursive=False):
@ -334,7 +363,9 @@ def chmod(path, mode, fmode=None, recursive=False):
for f in files: for f in files:
os.chmod(os.path.join(root, f), fmode) os.chmod(os.path.join(root, f), fmode)
except Exception as e: except Exception as e:
raise MoulinetteError('error_changing_file_permissions', path=path, error=str(e)) raise MoulinetteError(
"error_changing_file_permissions", path=path, error=str(e)
)
def rm(path, recursive=False, force=False): def rm(path, recursive=False, force=False):
@ -353,4 +384,4 @@ def rm(path, recursive=False, force=False):
os.remove(path) os.remove(path)
except OSError as e: except OSError as e:
if not force: if not force:
raise MoulinetteError('error_removing', path=path, error=str(e)) raise MoulinetteError("error_removing", path=path, error=str(e))
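The filesystem helpers above keep their behaviour, only the formatting changes. A short usage sketch; the /tmp path is a throwaway example, and it assumes i18n has been initialized so MoulinetteError messages can be rendered:

from moulinette.core import MoulinetteError
from moulinette.utils.filesystem import mkdir, read_json, rm, write_to_file

demo_dir = "/tmp/demo-moulinette"                # throwaway path, illustration only

mkdir(demo_dir, mode=0o750, parents=True)        # raises OSError(EEXIST) if it already exists
write_to_file(demo_dir + "/hello.txt", ["hello", "world"])  # a list is joined with "\n"

try:
    settings = read_json(demo_dir + "/missing.json")
except MoulinetteError:                          # e.g. the "file_not_exist" error key
    settings = {}

rm(demo_dir, recursive=True, force=True)         # clean up; force swallows removal errors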

View file

@ -3,8 +3,18 @@ import logging
# import all constants because other modules try to import them from this # import all constants because other modules try to import them from this
# module because SUCCESS is defined in this module # module because SUCCESS is defined in this module
from logging import (addLevelName, setLoggerClass, Logger, getLogger, NOTSET, # noqa from logging import (
DEBUG, INFO, WARNING, ERROR, CRITICAL) addLevelName,
setLoggerClass,
Logger,
getLogger,
NOTSET, # noqa
DEBUG,
INFO,
WARNING,
ERROR,
CRITICAL,
)
# Global configuration and functions ----------------------------------- # Global configuration and functions -----------------------------------
@ -12,27 +22,20 @@ from logging import (addLevelName, setLoggerClass, Logger, getLogger, NOTSET, #
SUCCESS = 25 SUCCESS = 25
DEFAULT_LOGGING = { DEFAULT_LOGGING = {
'version': 1, "version": 1,
'disable_existing_loggers': False, "disable_existing_loggers": False,
'formatters': { "formatters": {
'simple': { "simple": {"format": "%(asctime)-15s %(levelname)-8s %(name)s - %(message)s"},
'format': '%(asctime)-15s %(levelname)-8s %(name)s - %(message)s' },
}, "handlers": {
}, "console": {
'handlers': { "level": "DEBUG",
'console': { "formatter": "simple",
'level': 'DEBUG', "class": "logging.StreamHandler",
'formatter': 'simple', "stream": "ext://sys.stdout",
'class': 'logging.StreamHandler',
'stream': 'ext://sys.stdout',
},
},
'loggers': {
'moulinette': {
'level': 'DEBUG',
'handlers': ['console'],
}, },
}, },
"loggers": {"moulinette": {"level": "DEBUG", "handlers": ["console"],},},
} }
@ -46,7 +49,7 @@ def configure_logging(logging_config=None):
from logging.config import dictConfig from logging.config import dictConfig
# add custom logging level and class # add custom logging level and class
addLevelName(SUCCESS, 'SUCCESS') addLevelName(SUCCESS, "SUCCESS")
setLoggerClass(MoulinetteLogger) setLoggerClass(MoulinetteLogger)
# load configuration from dict # load configuration from dict
@ -65,7 +68,7 @@ def getHandlersByClass(classinfo, limit=0):
return o return o
handlers.append(o) handlers.append(o)
if limit != 0 and len(handlers) > limit: if limit != 0 and len(handlers) > limit:
return handlers[:limit - 1] return handlers[: limit - 1]
return handlers return handlers
@ -79,6 +82,7 @@ class MoulinetteLogger(Logger):
LogRecord extra and can be used with the ActionFilter. LogRecord extra and can be used with the ActionFilter.
""" """
action_id = None action_id = None
def success(self, msg, *args, **kwargs): def success(self, msg, *args, **kwargs):
@ -105,11 +109,11 @@ class MoulinetteLogger(Logger):
def _log(self, *args, **kwargs): def _log(self, *args, **kwargs):
"""Append action_id if available to the extra.""" """Append action_id if available to the extra."""
if self.action_id is not None: if self.action_id is not None:
extra = kwargs.get('extra', {}) extra = kwargs.get("extra", {})
if 'action_id' not in extra: if "action_id" not in extra:
# FIXME: Get real action_id instead of logger/current one # FIXME: Get real action_id instead of logger/current one
extra['action_id'] = _get_action_id() extra["action_id"] = _get_action_id()
kwargs['extra'] = extra kwargs["extra"] = extra
return Logger._log(self, *args, **kwargs) return Logger._log(self, *args, **kwargs)
@ -120,7 +124,7 @@ action_id = 0
def _get_action_id(): def _get_action_id():
return '%d.%d' % (pid, action_id) return "%d.%d" % (pid, action_id)
def start_action_logging(): def start_action_logging():
@ -146,7 +150,7 @@ def getActionLogger(name=None, logger=None, action_id=None):
""" """
if not name and not logger: if not name and not logger:
raise ValueError('Either a name or a logger must be specified') raise ValueError("Either a name or a logger must be specified")
logger = logger or getLogger(name) logger = logger or getLogger(name)
logger.action_id = action_id if action_id else _get_action_id() logger.action_id = action_id if action_id else _get_action_id()
@ -164,15 +168,15 @@ class ActionFilter(object):
""" """
def __init__(self, message_key='fmessage', strict=False): def __init__(self, message_key="fmessage", strict=False):
self.message_key = message_key self.message_key = message_key
self.strict = strict self.strict = strict
def filter(self, record): def filter(self, record):
msg = record.getMessage() msg = record.getMessage()
action_id = record.__dict__.get('action_id', None) action_id = record.__dict__.get("action_id", None)
if action_id is not None: if action_id is not None:
msg = '[{:s}] {:s}'.format(action_id, msg) msg = "[{:s}] {:s}".format(action_id, msg)
elif self.strict: elif self.strict:
return False return False
record.__dict__[self.message_key] = msg record.__dict__[self.message_key] = msg

View file

@ -15,6 +15,7 @@ def download_text(url, timeout=30, expected_status_code=200):
None to ignore the status code. None to ignore the status code.
""" """
import requests # lazy loading this module for performance reasons import requests # lazy loading this module for performance reasons
# Assumptions # Assumptions
assert isinstance(url, str) assert isinstance(url, str)
@ -23,22 +24,21 @@ def download_text(url, timeout=30, expected_status_code=200):
r = requests.get(url, timeout=timeout) r = requests.get(url, timeout=timeout)
# Invalid URL # Invalid URL
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
raise MoulinetteError('invalid_url', url=url) raise MoulinetteError("invalid_url", url=url)
# SSL exceptions # SSL exceptions
except requests.exceptions.SSLError: except requests.exceptions.SSLError:
raise MoulinetteError('download_ssl_error', url=url) raise MoulinetteError("download_ssl_error", url=url)
# Timeout exceptions # Timeout exceptions
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
raise MoulinetteError('download_timeout', url=url) raise MoulinetteError("download_timeout", url=url)
# Unknown stuff # Unknown stuff
except Exception as e: except Exception as e:
raise MoulinetteError('download_unknown_error', raise MoulinetteError("download_unknown_error", url=url, error=str(e))
url=url, error=str(e))
# Assume error if status code is not 200 (OK) # Assume error if status code is not 200 (OK)
if expected_status_code is not None \ if expected_status_code is not None and r.status_code != expected_status_code:
and r.status_code != expected_status_code: raise MoulinetteError(
raise MoulinetteError('download_bad_status_code', "download_bad_status_code", url=url, code=str(r.status_code)
url=url, code=str(r.status_code)) )
return r.text return r.text
@ -59,6 +59,6 @@ def download_json(url, timeout=30, expected_status_code=200):
try: try:
loaded_json = json.loads(text) loaded_json = json.loads(text)
except ValueError as e: except ValueError as e:
raise MoulinetteError('corrupted_json', ressource=url, error=e) raise MoulinetteError("corrupted_json", ressource=url, error=e)
return loaded_json return loaded_json
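download_text and download_json wrap requests with MoulinetteError reporting for connection, SSL, timeout, status-code and JSON errors. A usage sketch; the moulinette.utils.network module path is inferred from the other utils modules and the URL is a placeholder:

from moulinette.core import MoulinetteError
from moulinette.utils.network import download_json  # assumed module path

try:
    release = download_json("https://example.org/releases.json", timeout=10)
except MoulinetteError:
    release = None                     # any of the failure modes listed above
print(release)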

View file

@ -11,6 +11,7 @@ except ImportError:
from shlex import quote # Python3 >= 3.3 from shlex import quote # Python3 >= 3.3
from .stream import async_file_reading from .stream import async_file_reading
quote # This line is here to avoid W0611 PEP8 error (see comments above) quote # This line is here to avoid W0611 PEP8 error (see comments above)
# Prevent to import subprocess only for common classes # Prevent to import subprocess only for common classes
@ -19,6 +20,7 @@ CalledProcessError = subprocess.CalledProcessError
# Alternative subprocess methods --------------------------------------- # Alternative subprocess methods ---------------------------------------
def check_output(args, stderr=subprocess.STDOUT, shell=True, **kwargs): def check_output(args, stderr=subprocess.STDOUT, shell=True, **kwargs):
"""Run command with arguments and return its output as a byte string """Run command with arguments and return its output as a byte string
@ -31,6 +33,7 @@ def check_output(args, stderr=subprocess.STDOUT, shell=True, **kwargs):
# Call with stream access ---------------------------------------------- # Call with stream access ----------------------------------------------
def call_async_output(args, callback, **kwargs): def call_async_output(args, callback, **kwargs):
"""Run command and provide its output asynchronously """Run command and provide its output asynchronously
@ -52,10 +55,9 @@ def call_async_output(args, callback, **kwargs):
Exit status of the command Exit status of the command
""" """
for a in ['stdout', 'stderr']: for a in ["stdout", "stderr"]:
if a in kwargs: if a in kwargs:
raise ValueError('%s argument not allowed, ' raise ValueError("%s argument not allowed, " "it will be overridden." % a)
'it will be overridden.' % a)
if "stdinfo" in kwargs and kwargs["stdinfo"] is not None: if "stdinfo" in kwargs and kwargs["stdinfo"] is not None:
assert len(callback) == 3 assert len(callback) == 3
@ -72,16 +74,16 @@ def call_async_output(args, callback, **kwargs):
# Validate callback argument # Validate callback argument
if isinstance(callback, tuple): if isinstance(callback, tuple):
if len(callback) < 2: if len(callback) < 2:
raise ValueError('callback argument should be a 2-tuple') raise ValueError("callback argument should be a 2-tuple")
kwargs['stdout'] = kwargs['stderr'] = subprocess.PIPE kwargs["stdout"] = kwargs["stderr"] = subprocess.PIPE
separate_stderr = True separate_stderr = True
elif callable(callback): elif callable(callback):
kwargs['stdout'] = subprocess.PIPE kwargs["stdout"] = subprocess.PIPE
kwargs['stderr'] = subprocess.STDOUT kwargs["stderr"] = subprocess.STDOUT
separate_stderr = False separate_stderr = False
callback = (callback,) callback = (callback,)
else: else:
raise ValueError('callback argument must be callable or a 2-tuple') raise ValueError("callback argument must be callable or a 2-tuple")
# Run the command # Run the command
p = subprocess.Popen(args, **kwargs) p = subprocess.Popen(args, **kwargs)
@ -101,7 +103,7 @@ def call_async_output(args, callback, **kwargs):
stderr_consum.process_next_line() stderr_consum.process_next_line()
if stdinfo: if stdinfo:
stdinfo_consum.process_next_line() stdinfo_consum.process_next_line()
time.sleep(.1) time.sleep(0.1)
stderr_reader.join() stderr_reader.join()
# clear the queues # clear the queues
stdout_consum.process_current_queue() stdout_consum.process_current_queue()
@ -111,7 +113,7 @@ def call_async_output(args, callback, **kwargs):
else: else:
while not stdout_reader.eof(): while not stdout_reader.eof():
stdout_consum.process_current_queue() stdout_consum.process_current_queue()
time.sleep(.1) time.sleep(0.1)
stdout_reader.join() stdout_reader.join()
# clear the queue # clear the queue
stdout_consum.process_current_queue() stdout_consum.process_current_queue()
@ -131,15 +133,15 @@ def call_async_output(args, callback, **kwargs):
while time.time() - start < 10: while time.time() - start < 10:
if p.poll() is not None: if p.poll() is not None:
return p.poll() return p.poll()
time.sleep(.1) time.sleep(0.1)
return p.poll() return p.poll()
# Call multiple commands ----------------------------------------------- # Call multiple commands -----------------------------------------------
def run_commands(cmds, callback=None, separate_stderr=False, shell=True,
**kwargs): def run_commands(cmds, callback=None, separate_stderr=False, shell=True, **kwargs):
"""Run multiple commands with error management """Run multiple commands with error management
Run a list of commands and allow to manage how to treat errors either Run a list of commands and allow to manage how to treat errors either
@ -176,18 +178,18 @@ def run_commands(cmds, callback=None, separate_stderr=False, shell=True,
# stdout and stderr are specified by this code later, so they cannot be # stdout and stderr are specified by this code later, so they cannot be
# overriden by user input # overriden by user input
for a in ['stdout', 'stderr']: for a in ["stdout", "stderr"]:
if a in kwargs: if a in kwargs:
raise ValueError('%s argument not allowed, ' raise ValueError("%s argument not allowed, " "it will be overridden." % a)
'it will be overridden.' % a)
# If no callback specified... # If no callback specified...
if callback is None: if callback is None:
# Raise CalledProcessError on command failure # Raise CalledProcessError on command failure
def callback(r, c, o): def callback(r, c, o):
raise CalledProcessError(r, c, o) raise CalledProcessError(r, c, o)
elif not callable(callback): elif not callable(callback):
raise ValueError('callback argument must be callable') raise ValueError("callback argument must be callable")
# Manage stderr # Manage stderr
if separate_stderr: if separate_stderr:
@ -201,8 +203,9 @@ def run_commands(cmds, callback=None, separate_stderr=False, shell=True,
error = 0 error = 0
for cmd in cmds: for cmd in cmds:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, process = subprocess.Popen(
stderr=_stderr, shell=shell, **kwargs) cmd, stdout=subprocess.PIPE, stderr=_stderr, shell=shell, **kwargs
)
output = _get_output(*process.communicate()) output = _get_output(*process.communicate())
retcode = process.poll() retcode = process.poll()

View file

@ -3,11 +3,12 @@ from json.encoder import JSONEncoder
import datetime import datetime
import pytz import pytz
logger = logging.getLogger('moulinette.utils.serialize') logger = logging.getLogger("moulinette.utils.serialize")
# JSON utilities ------------------------------------------------------- # JSON utilities -------------------------------------------------------
class JSONExtendedEncoder(JSONEncoder): class JSONExtendedEncoder(JSONEncoder):
"""Extended JSON encoder """Extended JSON encoder
@ -24,8 +25,7 @@ class JSONExtendedEncoder(JSONEncoder):
def default(self, o): def default(self, o):
"""Return a serializable object""" """Return a serializable object"""
# Convert compatible containers into list # Convert compatible containers into list
if isinstance(o, set) or ( if isinstance(o, set) or (hasattr(o, "__iter__") and hasattr(o, "next")):
hasattr(o, '__iter__') and hasattr(o, 'next')):
return list(o) return list(o)
# Display the date in its iso format ISO-8601 Internet Profile (RFC 3339) # Display the date in its iso format ISO-8601 Internet Profile (RFC 3339)
@ -35,6 +35,9 @@ class JSONExtendedEncoder(JSONEncoder):
return o.isoformat() return o.isoformat()
# Return the repr for object that json can't encode # Return the repr for object that json can't encode
logger.warning('cannot properly encode in JSON the object %s, ' logger.warning(
'returned repr is: %r', type(o), o) "cannot properly encode in JSON the object %s, " "returned repr is: %r",
type(o),
o,
)
return repr(o) return repr(o)
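JSONExtendedEncoder is what lets the CLI dump action results containing sets and dates as JSON, as cli.py above does when output_as is "json". A short sketch:

import datetime
import json

from moulinette.utils.serialize import JSONExtendedEncoder

result = {
    "groups": {"admins", "all_users"},                        # a set, encoded as a JSON list
    "last_run": datetime.datetime(2019, 11, 25, 17, 21, 13),  # encoded as an ISO-8601 string
}
print(json.dumps(result, cls=JSONExtendedEncoder))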

View file

@ -7,6 +7,7 @@ from multiprocessing.queues import SimpleQueue
# Read from a stream --------------------------------------------------- # Read from a stream ---------------------------------------------------
class AsynchronousFileReader(Process): class AsynchronousFileReader(Process):
""" """
@ -20,8 +21,8 @@ class AsynchronousFileReader(Process):
""" """
def __init__(self, fd, queue): def __init__(self, fd, queue):
assert hasattr(queue, 'put') assert hasattr(queue, "put")
assert hasattr(queue, 'empty') assert hasattr(queue, "empty")
assert isinstance(fd, int) or callable(fd.readline) assert isinstance(fd, int) or callable(fd.readline)
Process.__init__(self) Process.__init__(self)
self._fd = fd self._fd = fd
@ -34,7 +35,7 @@ class AsynchronousFileReader(Process):
# Typically that's for stdout/stderr pipes # Typically that's for stdout/stderr pipes
# We can read the stuff easily with 'readline' # We can read the stuff easily with 'readline'
if not isinstance(self._fd, int): if not isinstance(self._fd, int):
for line in iter(self._fd.readline, ''): for line in iter(self._fd.readline, ""):
self._queue.put(line) self._queue.put(line)
# Else, it got opened with os.open() and we have to read it # Else, it got opened with os.open() and we have to read it
@ -52,10 +53,10 @@ class AsynchronousFileReader(Process):
# If we have data, extract a line (ending with \n) and feed # If we have data, extract a line (ending with \n) and feed
# it to the consumer # it to the consumer
if data and '\n' in data: if data and "\n" in data:
lines = data.split('\n') lines = data.split("\n")
self._queue.put(lines[0]) self._queue.put(lines[0])
data = '\n'.join(lines[1:]) data = "\n".join(lines[1:])
else: else:
time.sleep(0.05) time.sleep(0.05)
@ -75,7 +76,6 @@ class AsynchronousFileReader(Process):
class Consummer(object): class Consummer(object):
def __init__(self, queue, callback): def __init__(self, queue, callback):
self.queue = queue self.queue = queue
self.callback = callback self.callback = callback

View file

@ -6,6 +6,7 @@ import binascii
# Pattern searching ---------------------------------------------------- # Pattern searching ----------------------------------------------------
def search(pattern, text, count=0, flags=0): def search(pattern, text, count=0, flags=0):
"""Search for pattern in a text """Search for pattern in a text
@ -46,7 +47,7 @@ def searchf(pattern, path, count=0, flags=re.MULTILINE):
content by using the search function. content by using the search function.
""" """
with open(path, 'r+') as f: with open(path, "r+") as f:
data = mmap.mmap(f.fileno(), 0) data = mmap.mmap(f.fileno(), 0)
match = search(pattern, data, count, flags) match = search(pattern, data, count, flags)
data.close() data.close()
@ -55,6 +56,7 @@ def searchf(pattern, path, count=0, flags=re.MULTILINE):
# Text formatting ------------------------------------------------------ # Text formatting ------------------------------------------------------
def prependlines(text, prepend): def prependlines(text, prepend):
"""Prepend a string to each line of a text""" """Prepend a string to each line of a text"""
lines = text.splitlines(True) lines = text.splitlines(True)
@ -63,6 +65,7 @@ def prependlines(text, prepend):
# Randomize ------------------------------------------------------------ # Randomize ------------------------------------------------------------
def random_ascii(length=20): def random_ascii(length=20):
"""Return a random ascii string""" """Return a random ascii string"""
return binascii.hexlify(os.urandom(length)).decode('ascii') return binascii.hexlify(os.urandom(length)).decode("ascii")
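The text helpers above are small enough that a few calls show the whole surface; a quick sketch, with the pattern and strings invented for illustration and the module path assumed to be moulinette.utils.text:

from moulinette.utils.text import prependlines, random_ascii, search

match = search(r"ba+r", "foo bar baz")   # regex helper; count/flags behave as documented above
print(match)
print(prependlines("foo\nbar\n", "# "))  # prefix every line with "# "
print(random_ascii(8))                   # 16 hex characters derived from 8 random bytes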

View file

@ -10,9 +10,9 @@ def patch_init(moulinette):
"""Configure moulinette to use the YunoHost namespace.""" """Configure moulinette to use the YunoHost namespace."""
old_init = moulinette.core.Moulinette18n.__init__ old_init = moulinette.core.Moulinette18n.__init__
def monkey_path_i18n_init(self, package, default_locale='en'): def monkey_path_i18n_init(self, package, default_locale="en"):
old_init(self, package, default_locale) old_init(self, package, default_locale)
self.load_namespace('moulinette') self.load_namespace("moulinette")
moulinette.core.Moulinette18n.__init__ = monkey_path_i18n_init moulinette.core.Moulinette18n.__init__ = monkey_path_i18n_init
@ -23,7 +23,7 @@ def patch_translate(moulinette):
def new_translate(self, key, *args, **kwargs): def new_translate(self, key, *args, **kwargs):
if key not in self._translations[self.default_locale].keys(): if key not in self._translations[self.default_locale].keys():
message = 'Unable to retrieve key %s for default locale!' % key message = "Unable to retrieve key %s for default locale!" % key
raise KeyError(message) raise KeyError(message)
return old_translate(self, key, *args, **kwargs) return old_translate(self, key, *args, **kwargs)
@ -38,59 +38,46 @@ def patch_translate(moulinette):
def patch_logging(moulinette): def patch_logging(moulinette):
"""Configure logging to use the custom logger.""" """Configure logging to use the custom logger."""
handlers = set(['tty', 'api']) handlers = set(["tty", "api"])
root_handlers = set(handlers) root_handlers = set(handlers)
level = 'INFO' level = "INFO"
tty_level = 'INFO' tty_level = "INFO"
return { return {
'version': 1, "version": 1,
'disable_existing_loggers': True, "disable_existing_loggers": True,
'formatters': { "formatters": {
'tty-debug': { "tty-debug": {"format": "%(relativeCreated)-4d %(fmessage)s"},
'format': '%(relativeCreated)-4d %(fmessage)s' "precise": {
}, "format": "%(asctime)-15s %(levelname)-8s %(name)s %(funcName)s - %(fmessage)s" # noqa
'precise': {
'format': '%(asctime)-15s %(levelname)-8s %(name)s %(funcName)s - %(fmessage)s' # noqa
}, },
}, },
'filters': { "filters": {"action": {"()": "moulinette.utils.log.ActionFilter",},},
'action': { "handlers": {
'()': 'moulinette.utils.log.ActionFilter', "api": {
"level": level,
"class": "moulinette.interfaces.api.APIQueueHandler",
},
"tty": {
"level": tty_level,
"class": "moulinette.interfaces.cli.TTYHandler",
"formatter": "",
}, },
}, },
'handlers': { "loggers": {
'api': { "moulinette": {"level": level, "handlers": [], "propagate": True,},
'level': level, "moulinette.interface": {
'class': 'moulinette.interfaces.api.APIQueueHandler', "level": level,
}, "handlers": handlers,
'tty': { "propagate": False,
'level': tty_level,
'class': 'moulinette.interfaces.cli.TTYHandler',
'formatter': '',
}, },
}, },
'loggers': { "root": {"level": level, "handlers": root_handlers,},
'moulinette': {
'level': level,
'handlers': [],
'propagate': True,
},
'moulinette.interface': {
'level': level,
'handlers': handlers,
'propagate': False,
},
},
'root': {
'level': level,
'handlers': root_handlers,
},
} }
@pytest.fixture(scope='session', autouse=True) @pytest.fixture(scope="session", autouse=True)
def moulinette(tmp_path_factory): def moulinette(tmp_path_factory):
import moulinette import moulinette
@ -100,9 +87,9 @@ def moulinette(tmp_path_factory):
tmp_cache = str(tmp_path_factory.mktemp("cache")) tmp_cache = str(tmp_path_factory.mktemp("cache"))
tmp_data = str(tmp_path_factory.mktemp("data")) tmp_data = str(tmp_path_factory.mktemp("data"))
tmp_lib = str(tmp_path_factory.mktemp("lib")) tmp_lib = str(tmp_path_factory.mktemp("lib"))
os.environ['MOULINETTE_CACHE_DIR'] = tmp_cache os.environ["MOULINETTE_CACHE_DIR"] = tmp_cache
os.environ['MOULINETTE_DATA_DIR'] = tmp_data os.environ["MOULINETTE_DATA_DIR"] = tmp_data
os.environ['MOULINETTE_LIB_DIR'] = tmp_lib os.environ["MOULINETTE_LIB_DIR"] = tmp_lib
shutil.copytree("./test/actionsmap", "%s/actionsmap" % tmp_data) shutil.copytree("./test/actionsmap", "%s/actionsmap" % tmp_data)
shutil.copytree("./test/src", "%s/%s" % (tmp_lib, namespace)) shutil.copytree("./test/src", "%s/%s" % (tmp_lib, namespace))
shutil.copytree("./test/locales", "%s/%s/locales" % (tmp_lib, namespace)) shutil.copytree("./test/locales", "%s/%s/locales" % (tmp_lib, namespace))
@ -111,10 +98,7 @@ def moulinette(tmp_path_factory):
patch_translate(moulinette) patch_translate(moulinette)
logging = patch_logging(moulinette) logging = patch_logging(moulinette)
moulinette.init( moulinette.init(logging_config=logging, _from_source=False)
logging_config=logging,
_from_source=False
)
return moulinette return moulinette
@ -129,12 +113,13 @@ def moulinette_webapi(moulinette):
# sure why :| # sure why :|
def return_true(self, cookie, request): def return_true(self, cookie, request):
return True return True
CookiePolicy.return_ok_secure = return_true CookiePolicy.return_ok_secure = return_true
moulinette_webapi = moulinette.core.init_interface( moulinette_webapi = moulinette.core.init_interface(
'api', "api",
kwargs={'routes': {}, 'use_websocket': False}, kwargs={"routes": {}, "use_websocket": False},
actionsmap={'namespaces': ["moulitest"], 'use_cache': True} actionsmap={"namespaces": ["moulitest"], "use_cache": True},
) )
return TestApp(moulinette_webapi._app) return TestApp(moulinette_webapi._app)
@ -142,16 +127,16 @@ def moulinette_webapi(moulinette):
@pytest.fixture @pytest.fixture
def test_file(tmp_path): def test_file(tmp_path):
test_text = 'foo\nbar\n' test_text = "foo\nbar\n"
test_file = tmp_path / 'test.txt' test_file = tmp_path / "test.txt"
test_file.write_bytes(test_text) test_file.write_bytes(test_text)
return test_file return test_file
@pytest.fixture @pytest.fixture
def test_json(tmp_path): def test_json(tmp_path):
test_json = json.dumps({'foo': 'bar'}) test_json = json.dumps({"foo": "bar"})
test_file = tmp_path / 'test.json' test_file = tmp_path / "test.json"
test_file.write_bytes(test_json) test_file.write_bytes(test_json)
return test_file return test_file
@ -163,4 +148,4 @@ def user():
@pytest.fixture @pytest.fixture
def test_url(): def test_url():
return 'https://some.test.url/yolo.txt' return "https://some.test.url/yolo.txt"

View file

@ -5,7 +5,7 @@ from moulinette.actionsmap import (
AskParameter, AskParameter,
PatternParameter, PatternParameter,
RequiredParameter, RequiredParameter,
ActionsMap ActionsMap,
) )
from moulinette.interfaces import BaseActionsMapParser from moulinette.interfaces import BaseActionsMapParser
from moulinette.core import MoulinetteError from moulinette.core import MoulinetteError
@ -13,75 +13,73 @@ from moulinette.core import MoulinetteError
@pytest.fixture @pytest.fixture
def iface(): def iface():
return 'iface' return "iface"
def test_comment_parameter_bad_bool_value(iface, caplog): def test_comment_parameter_bad_bool_value(iface, caplog):
comment = CommentParameter(iface) comment = CommentParameter(iface)
assert comment.validate(True, 'a') == 'a' assert comment.validate(True, "a") == "a"
assert any('expecting a non-empty string' in message for message in caplog.messages) assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_comment_parameter_bad_empty_string(iface, caplog): def test_comment_parameter_bad_empty_string(iface, caplog):
comment = CommentParameter(iface) comment = CommentParameter(iface)
assert comment.validate('', 'a') == 'a' assert comment.validate("", "a") == "a"
assert any('expecting a non-empty string' in message for message in caplog.messages) assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_comment_parameter_bad_type(iface): def test_comment_parameter_bad_type(iface):
comment = CommentParameter(iface) comment = CommentParameter(iface)
with pytest.raises(TypeError): with pytest.raises(TypeError):
comment.validate({}, 'b') comment.validate({}, "b")
def test_ask_parameter_bad_bool_value(iface, caplog): def test_ask_parameter_bad_bool_value(iface, caplog):
ask = AskParameter(iface) ask = AskParameter(iface)
assert ask.validate(True, 'a') == 'a' assert ask.validate(True, "a") == "a"
assert any('expecting a non-empty string' in message for message in caplog.messages) assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_ask_parameter_bad_empty_string(iface, caplog): def test_ask_parameter_bad_empty_string(iface, caplog):
ask = AskParameter(iface) ask = AskParameter(iface)
assert ask.validate('', 'a') == 'a' assert ask.validate("", "a") == "a"
assert any('expecting a non-empty string' in message for message in caplog.messages) assert any("expecting a non-empty string" in message for message in caplog.messages)
def test_ask_parameter_bad_type(iface): def test_ask_parameter_bad_type(iface):
ask = AskParameter(iface) ask = AskParameter(iface)
with pytest.raises(TypeError): with pytest.raises(TypeError):
ask.validate({}, 'b') ask.validate({}, "b")
def test_pattern_parameter_bad_str_value(iface, caplog): def test_pattern_parameter_bad_str_value(iface, caplog):
pattern = PatternParameter(iface) pattern = PatternParameter(iface)
assert pattern.validate('', 'a') == ['', 'pattern_not_match'] assert pattern.validate("", "a") == ["", "pattern_not_match"]
assert any('expecting a list' in message for message in caplog.messages) assert any("expecting a list" in message for message in caplog.messages)
@pytest.mark.parametrize('iface', [ @pytest.mark.parametrize(
[], "iface", [[], ["pattern_alone"], ["pattern", "message", "extra stuff"]]
['pattern_alone'], )
['pattern', 'message', 'extra stuff']
])
def test_pattern_parameter_bad_list_len(iface): def test_pattern_parameter_bad_list_len(iface):
pattern = PatternParameter(iface) pattern = PatternParameter(iface)
with pytest.raises(TypeError): with pytest.raises(TypeError):
pattern.validate(iface, 'a') pattern.validate(iface, "a")
def test_required_paremeter_missing_value(iface): def test_required_paremeter_missing_value(iface):
required = RequiredParameter(iface) required = RequiredParameter(iface)
with pytest.raises(MoulinetteError) as exception: with pytest.raises(MoulinetteError) as exception:
required(True, 'a', '') required(True, "a", "")
assert 'is required' in str(exception) assert "is required" in str(exception)
def test_actions_map_unknown_authenticator(monkeypatch, tmp_path): def test_actions_map_unknown_authenticator(monkeypatch, tmp_path):
monkeypatch.setenv('MOULINETTE_DATA_DIR', str(tmp_path)) monkeypatch.setenv("MOULINETTE_DATA_DIR", str(tmp_path))
actionsmap_dir = actionsmap_dir = tmp_path / 'actionsmap' actionsmap_dir = actionsmap_dir = tmp_path / "actionsmap"
actionsmap_dir.mkdir() actionsmap_dir.mkdir()
amap = ActionsMap(BaseActionsMapParser) amap = ActionsMap(BaseActionsMapParser)
with pytest.raises(ValueError) as exception: with pytest.raises(ValueError) as exception:
amap.get_authenticator_for_profile('unknown') amap.get_authenticator_for_profile("unknown")
assert 'Unknown authenticator' in str(exception) assert "Unknown authenticator" in str(exception)

View file

@ -7,19 +7,28 @@ def login(webapi, csrf=False, profile=None, status=200):
if profile: if profile:
data["profile"] = profile data["profile"] = profile
return webapi.post("/login", data, return webapi.post(
status=status, "/login",
headers=None if csrf else {"X-Requested-With": ""}) data,
status=status,
headers=None if csrf else {"X-Requested-With": ""},
)
def test_request_no_auth_needed(moulinette_webapi): def test_request_no_auth_needed(moulinette_webapi):
assert moulinette_webapi.get("/test-auth/none", status=200).text == '"some_data_from_none"' assert (
moulinette_webapi.get("/test-auth/none", status=200).text
== '"some_data_from_none"'
)
def test_request_with_auth_but_not_logged(moulinette_webapi): def test_request_with_auth_but_not_logged(moulinette_webapi):
assert moulinette_webapi.get("/test-auth/default", status=401).text == "Authentication required" assert (
moulinette_webapi.get("/test-auth/default", status=401).text
== "Authentication required"
)
def test_login(moulinette_webapi): def test_login(moulinette_webapi):
@ -29,8 +38,10 @@ def test_login(moulinette_webapi):
assert "session.id" in moulinette_webapi.cookies assert "session.id" in moulinette_webapi.cookies
assert "session.tokens" in moulinette_webapi.cookies assert "session.tokens" in moulinette_webapi.cookies
cache_session_default = os.environ['MOULINETTE_CACHE_DIR'] + "/session/default/" cache_session_default = os.environ["MOULINETTE_CACHE_DIR"] + "/session/default/"
assert moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(cache_session_default) assert moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(
cache_session_default
)
def test_login_csrf_attempt(moulinette_webapi): def test_login_csrf_attempt(moulinette_webapi):
@ -57,7 +68,10 @@ def test_login_then_legit_request(moulinette_webapi):
login(moulinette_webapi) login(moulinette_webapi)
assert moulinette_webapi.get("/test-auth/default", status=200).text == '"some_data_from_default"' assert (
moulinette_webapi.get("/test-auth/default", status=200).text
== '"some_data_from_default"'
)
def test_login_then_logout(moulinette_webapi): def test_login_then_logout(moulinette_webapi):
@ -66,7 +80,12 @@ def test_login_then_logout(moulinette_webapi):
moulinette_webapi.get("/logout", status=200) moulinette_webapi.get("/logout", status=200)
cache_session_default = os.environ['MOULINETTE_CACHE_DIR'] + "/session/default/" cache_session_default = os.environ["MOULINETTE_CACHE_DIR"] + "/session/default/"
assert not moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(cache_session_default) assert not moulinette_webapi.cookies["session.id"] + ".asc" in os.listdir(
cache_session_default
)
assert moulinette_webapi.get("/test-auth/default", status=401).text == "Authentication required" assert (
moulinette_webapi.get("/test-auth/default", status=401).text
== "Authentication required"
)

View file

@ -2,11 +2,11 @@ import os.path
def test_open_cachefile_creates(monkeypatch, tmp_path): def test_open_cachefile_creates(monkeypatch, tmp_path):
monkeypatch.setenv('MOULINETTE_CACHE_DIR', str(tmp_path)) monkeypatch.setenv("MOULINETTE_CACHE_DIR", str(tmp_path))
from moulinette.cache import open_cachefile from moulinette.cache import open_cachefile
handle = open_cachefile('foo.cache', mode='w') handle = open_cachefile("foo.cache", mode="w")
assert handle.mode == 'w' assert handle.mode == "w"
assert handle.name == os.path.join(str(tmp_path), 'foo.cache') assert handle.name == os.path.join(str(tmp_path), "foo.cache")

View file

@ -4,150 +4,156 @@ import pytest
from moulinette import m18n from moulinette import m18n
from moulinette.core import MoulinetteError from moulinette.core import MoulinetteError
from moulinette.utils.filesystem import (append_to_file, read_file, read_json, from moulinette.utils.filesystem import (
rm, write_to_file, write_to_json) append_to_file,
read_file,
read_json,
rm,
write_to_file,
write_to_json,
)
def test_read_file(test_file): def test_read_file(test_file):
content = read_file(str(test_file)) content = read_file(str(test_file))
assert content == 'foo\nbar\n' assert content == "foo\nbar\n"
def test_read_file_missing_file(): def test_read_file_missing_file():
bad_file = 'doesnt-exist' bad_file = "doesnt-exist"
with pytest.raises(MoulinetteError) as exception: with pytest.raises(MoulinetteError) as exception:
read_file(bad_file) read_file(bad_file)
translation = m18n.g('file_not_exist', path=bad_file) translation = m18n.g("file_not_exist", path=bad_file)
expected_msg = translation.format(path=bad_file) expected_msg = translation.format(path=bad_file)
assert expected_msg in str(exception) assert expected_msg in str(exception)
def test_read_file_cannot_read_ioerror(test_file, mocker): def test_read_file_cannot_read_ioerror(test_file, mocker):
error = 'foobar' error = "foobar"
with mocker.patch('__builtin__.open', side_effect=IOError(error)): with mocker.patch("__builtin__.open", side_effect=IOError(error)):
with pytest.raises(MoulinetteError) as exception: with pytest.raises(MoulinetteError) as exception:
read_file(str(test_file)) read_file(str(test_file))
translation = m18n.g('cannot_open_file', file=str(test_file), error=error) translation = m18n.g("cannot_open_file", file=str(test_file), error=error)
expected_msg = translation.format(file=str(test_file), error=error) expected_msg = translation.format(file=str(test_file), error=error)
assert expected_msg in str(exception) assert expected_msg in str(exception)
def test_read_json(test_json): def test_read_json(test_json):
content = read_json(str(test_json)) content = read_json(str(test_json))
assert 'foo' in content.keys() assert "foo" in content.keys()
assert content['foo'] == 'bar' assert content["foo"] == "bar"
def test_read_json_cannot_read(test_json, mocker): def test_read_json_cannot_read(test_json, mocker):
error = 'foobar' error = "foobar"
with mocker.patch('json.loads', side_effect=ValueError(error)): with mocker.patch("json.loads", side_effect=ValueError(error)):
with pytest.raises(MoulinetteError) as exception: with pytest.raises(MoulinetteError) as exception:
read_json(str(test_json)) read_json(str(test_json))
translation = m18n.g('corrupted_json', ressource=str(test_json), error=error) translation = m18n.g("corrupted_json", ressource=str(test_json), error=error)
expected_msg = translation.format(ressource=str(test_json), error=error) expected_msg = translation.format(ressource=str(test_json), error=error)
assert expected_msg in str(exception) assert expected_msg in str(exception)
def test_write_to_existing_file(test_file): def test_write_to_existing_file(test_file):
write_to_file(str(test_file), 'yolo\nswag') write_to_file(str(test_file), "yolo\nswag")
assert read_file(str(test_file)) == 'yolo\nswag' assert read_file(str(test_file)) == "yolo\nswag"
def test_write_to_new_file(tmp_path): def test_write_to_new_file(tmp_path):
new_file = tmp_path / 'newfile.txt' new_file = tmp_path / "newfile.txt"
write_to_file(str(new_file), 'yolo\nswag') write_to_file(str(new_file), "yolo\nswag")
assert os.path.exists(str(new_file)) assert os.path.exists(str(new_file))
assert read_file(str(new_file)) == 'yolo\nswag' assert read_file(str(new_file)) == "yolo\nswag"
def test_write_to_existing_file_bad_perms(test_file, mocker): def test_write_to_existing_file_bad_perms(test_file, mocker):
error = 'foobar' error = "foobar"
with mocker.patch('__builtin__.open', side_effect=IOError(error)): with mocker.patch("__builtin__.open", side_effect=IOError(error)):
with pytest.raises(MoulinetteError) as exception: with pytest.raises(MoulinetteError) as exception:
write_to_file(str(test_file), 'yolo\nswag') write_to_file(str(test_file), "yolo\nswag")
translation = m18n.g('cannot_write_file', file=str(test_file), error=error) translation = m18n.g("cannot_write_file", file=str(test_file), error=error)
expected_msg = translation.format(file=str(test_file), error=error) expected_msg = translation.format(file=str(test_file), error=error)
assert expected_msg in str(exception) assert expected_msg in str(exception)
def test_write_cannot_write_folder(tmp_path): def test_write_cannot_write_folder(tmp_path):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
write_to_file(str(tmp_path), 'yolo\nswag') write_to_file(str(tmp_path), "yolo\nswag")
def test_write_cannot_write_to_non_existant_folder(): def test_write_cannot_write_to_non_existant_folder():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
write_to_file('/toto/test', 'yolo\nswag') write_to_file("/toto/test", "yolo\nswag")
def test_write_to_file_with_a_list(test_file): def test_write_to_file_with_a_list(test_file):
write_to_file(str(test_file), ['yolo', 'swag']) write_to_file(str(test_file), ["yolo", "swag"])
assert read_file(str(test_file)) == 'yolo\nswag' assert read_file(str(test_file)) == "yolo\nswag"
def test_append_to_existing_file(test_file): def test_append_to_existing_file(test_file):
append_to_file(str(test_file), 'yolo\nswag') append_to_file(str(test_file), "yolo\nswag")
assert read_file(str(test_file)) == 'foo\nbar\nyolo\nswag' assert read_file(str(test_file)) == "foo\nbar\nyolo\nswag"
def test_append_to_new_file(tmp_path): def test_append_to_new_file(tmp_path):
new_file = tmp_path / 'newfile.txt' new_file = tmp_path / "newfile.txt"
append_to_file(str(new_file), 'yolo\nswag') append_to_file(str(new_file), "yolo\nswag")
assert os.path.exists(str(new_file)) assert os.path.exists(str(new_file))
assert read_file(str(new_file)) == 'yolo\nswag' assert read_file(str(new_file)) == "yolo\nswag"
def test_write_dict_to_json(tmp_path): def test_write_dict_to_json(tmp_path):
new_file = tmp_path / 'newfile.json' new_file = tmp_path / "newfile.json"
dummy_dict = {'foo': 42, 'bar': ['a', 'b', 'c']} dummy_dict = {"foo": 42, "bar": ["a", "b", "c"]}
write_to_json(str(new_file), dummy_dict) write_to_json(str(new_file), dummy_dict)
_json = read_json(str(new_file)) _json = read_json(str(new_file))
assert 'foo' in _json.keys() assert "foo" in _json.keys()
assert 'bar' in _json.keys() assert "bar" in _json.keys()
assert _json['foo'] == 42 assert _json["foo"] == 42
assert _json['bar'] == ['a', 'b', 'c'] assert _json["bar"] == ["a", "b", "c"]
def test_write_list_to_json(tmp_path): def test_write_list_to_json(tmp_path):
new_file = tmp_path / 'newfile.json' new_file = tmp_path / "newfile.json"
dummy_list = ['foo', 'bar', 'baz'] dummy_list = ["foo", "bar", "baz"]
write_to_json(str(new_file), dummy_list) write_to_json(str(new_file), dummy_list)
_json = read_json(str(new_file)) _json = read_json(str(new_file))
assert _json == ['foo', 'bar', 'baz'] assert _json == ["foo", "bar", "baz"]
def test_write_to_json_bad_perms(test_json, mocker): def test_write_to_json_bad_perms(test_json, mocker):
error = 'foobar' error = "foobar"
with mocker.patch('__builtin__.open', side_effect=IOError(error)): with mocker.patch("__builtin__.open", side_effect=IOError(error)):
with pytest.raises(MoulinetteError) as exception: with pytest.raises(MoulinetteError) as exception:
write_to_json(str(test_json), {'a': 1}) write_to_json(str(test_json), {"a": 1})
translation = m18n.g('cannot_write_file', file=str(test_json), error=error) translation = m18n.g("cannot_write_file", file=str(test_json), error=error)
expected_msg = translation.format(file=str(test_json), error=error) expected_msg = translation.format(file=str(test_json), error=error)
assert expected_msg in str(exception) assert expected_msg in str(exception)
def test_write_json_cannot_write_to_non_existant_folder(): def test_write_json_cannot_write_to_non_existant_folder():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
write_to_json('/toto/test.json', ['a', 'b']) write_to_json("/toto/test.json", ["a", "b"])
def test_remove_file(test_file): def test_remove_file(test_file):
@@ -157,13 +163,13 @@ def test_remove_file(test_file):
def test_remove_file_bad_perms(test_file, mocker): def test_remove_file_bad_perms(test_file, mocker):
error = 'foobar' error = "foobar"
with mocker.patch('os.remove', side_effect=OSError(error)): with mocker.patch("os.remove", side_effect=OSError(error)):
with pytest.raises(MoulinetteError) as exception: with pytest.raises(MoulinetteError) as exception:
rm(str(test_file)) rm(str(test_file))
translation = m18n.g('error_removing', path=str(test_file), error=error) translation = m18n.g("error_removing", path=str(test_file), error=error)
expected_msg = translation.format(path=str(test_file), error=error) expected_msg = translation.format(path=str(test_file), error=error)
assert expected_msg in str(exception) assert expected_msg in str(exception)
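
Most of the failure-path tests above repeat the same three lines: fetch the translated message with m18n.g, format it, and assert that it appears in the raised MoulinetteError. Should that pattern ever need factoring, a helper along these lines would cover it; this is only a sketch, and no such helper exists in the repository:

import pytest

from moulinette import m18n
from moulinette.core import MoulinetteError
from moulinette.utils.filesystem import read_file


def assert_m18n_error(excinfo, key, **kwargs):
    # The translated, formatted message must appear in the raised exception,
    # mirroring the inline checks used throughout these tests.
    expected_msg = m18n.g(key, **kwargs).format(**kwargs)
    assert expected_msg in str(excinfo.value)


def test_read_file_missing_file_with_helper():
    bad_file = "doesnt-exist"
    with pytest.raises(MoulinetteError) as exception:
        read_file(bad_file)
    assert_m18n_error(exception, "file_not_exist", path=bad_file)
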

@@ -8,19 +8,19 @@ from moulinette.utils.network import download_json, download_text
def test_download(test_url): def test_download(test_url):
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, text='some text') mock.register_uri("GET", test_url, text="some text")
fetched_text = download_text(test_url) fetched_text = download_text(test_url)
assert fetched_text == 'some text' assert fetched_text == "some text"
def test_download_bad_url(): def test_download_bad_url():
with pytest.raises(MoulinetteError): with pytest.raises(MoulinetteError):
download_text('Nowhere') download_text("Nowhere")
def test_download_404(test_url): def test_download_404(test_url):
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, status_code=404) mock.register_uri("GET", test_url, status_code=404)
with pytest.raises(MoulinetteError): with pytest.raises(MoulinetteError):
download_text(test_url) download_text(test_url)
@@ -28,7 +28,7 @@ def test_download_404(test_url):
def test_download_ssl_error(test_url): def test_download_ssl_error(test_url):
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
exception = requests.exceptions.SSLError exception = requests.exceptions.SSLError
mock.register_uri('GET', test_url, exc=exception) mock.register_uri("GET", test_url, exc=exception)
with pytest.raises(MoulinetteError): with pytest.raises(MoulinetteError):
download_text(test_url) download_text(test_url)
@@ -36,21 +36,21 @@ def test_download_ssl_error(test_url):
def test_download_timeout(test_url): def test_download_timeout(test_url):
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
exception = requests.exceptions.ConnectTimeout exception = requests.exceptions.ConnectTimeout
mock.register_uri('GET', test_url, exc=exception) mock.register_uri("GET", test_url, exc=exception)
with pytest.raises(MoulinetteError): with pytest.raises(MoulinetteError):
download_text(test_url) download_text(test_url)
def test_download_json(test_url): def test_download_json(test_url):
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, text='{"foo":"bar"}') mock.register_uri("GET", test_url, text='{"foo":"bar"}')
fetched_json = download_json(test_url) fetched_json = download_json(test_url)
assert 'foo' in fetched_json.keys() assert "foo" in fetched_json.keys()
assert fetched_json['foo'] == 'bar' assert fetched_json["foo"] == "bar"
def test_download_json_bad_json(test_url): def test_download_json_bad_json(test_url):
with requests_mock.Mocker() as mock: with requests_mock.Mocker() as mock:
mock.register_uri('GET', test_url, text='notjsonlol') mock.register_uri("GET", test_url, text="notjsonlol")
with pytest.raises(MoulinetteError): with pytest.raises(MoulinetteError):
download_json(test_url) download_json(test_url)
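
Every network test above takes a test_url fixture that is defined elsewhere (presumably in conftest.py) and therefore never appears in this diff. Since requests_mock intercepts the request, any syntactically valid URL does the job; a hypothetical stand-in, purely so the tests can be read in isolation:

import pytest


@pytest.fixture
def test_url():
    # Hypothetical value: the real fixture lives outside this diff. Nothing is
    # actually fetched, requests_mock answers the call.
    return "https://some.test.url/yolo.txt"
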

@@ -8,10 +8,10 @@ from moulinette.utils.process import run_commands
def test_run_shell_command_list(test_file): def test_run_shell_command_list(test_file):
assert os.path.exists(str(test_file)) assert os.path.exists(str(test_file))
run_commands(['rm -f %s' % str(test_file)]) run_commands(["rm -f %s" % str(test_file)])
assert not os.path.exists(str(test_file)) assert not os.path.exists(str(test_file))
def test_run_shell_bad_cmd(): def test_run_shell_bad_cmd():
with pytest.raises(CalledProcessError): with pytest.raises(CalledProcessError):
run_commands(['yolo swag']) run_commands(["yolo swag"])

@@ -7,8 +7,8 @@ def test_json_extended_encoder(caplog):
assert encoder.default(set([1, 2, 3])) == [1, 2, 3] assert encoder.default(set([1, 2, 3])) == [1, 2, 3]
assert encoder.default(dt(1917, 3, 8)) == '1917-03-08T00:00:00+00:00' assert encoder.default(dt(1917, 3, 8)) == "1917-03-08T00:00:00+00:00"
assert encoder.default(None) == 'None' assert encoder.default(None) == "None"
for message in caplog.messages: for message in caplog.messages:
assert 'cannot properly encode in JSON' in message assert "cannot properly encode in JSON" in message

@@ -2,19 +2,19 @@ from moulinette.utils.text import search, searchf, prependlines, random_ascii
def test_search(): def test_search():
assert search('a', 'a a a') == ['a', 'a', 'a'] assert search("a", "a a a") == ["a", "a", "a"]
assert search('a', 'a a a', count=2) == ['a', 'a'] assert search("a", "a a a", count=2) == ["a", "a"]
assert not search('a', 'c c d') assert not search("a", "c c d")
def test_searchf(test_file): def test_searchf(test_file):
assert searchf('bar', str(test_file)) == ['bar'] assert searchf("bar", str(test_file)) == ["bar"]
assert not searchf('baz', str(test_file)) assert not searchf("baz", str(test_file))
def test_prependlines(): def test_prependlines():
assert prependlines('abc\nedf\nghi', 'XXX') == 'XXXabc\nXXXedf\nXXXghi' assert prependlines("abc\nedf\nghi", "XXX") == "XXXabc\nXXXedf\nXXXghi"
assert prependlines('', 'XXX') == 'XXX' assert prependlines("", "XXX") == "XXX"
def test_random_ascii(): def test_random_ascii():