diff --git a/access.lua b/access.lua index 91a3a45..951d8a8 100644 --- a/access.lua +++ b/access.lua @@ -12,9 +12,130 @@ table.insert(conf["skipped_urls"], conf["portal_domain"]..conf["portal_path"]) -- Dummy intructions ngx.header["X-SSO-WAT"] = "You've just been SSOed" +-- +-- Routing +-- + +-- Logging in/out +-- i.e. http://mydomain.org/?ssologin=myuser + +if ngx.var.request_method == "GET" then + local args = ngx.req.get_uri_args() + + -- In login loop + local user = args.ssologin + if user and login[user] then + return login_walkthrough(user) + end + + -- In logout loop + user = args.ssologout + if user and logout[user] then + return logout_walkthrough(user) + end +end + + +-- Portal +-- i.e. http://mydomain.org/ssowat/* + +if ngx.var.host == conf["portal_domain"] + and string.starts(ngx.var.uri, conf["portal_path"]) +then + + if ngx.var.request_method == "GET" then + + uri_args = ngx.req.get_uri_args() + if uri_args.action and uri_args.action == 'logout' then + -- Logout + return do_logout() + + elseif check_cookie() or ngx.var.uri == conf["portal_path"] then + -- Serve normal portal + return serve(ngx.var.uri) + + else + -- Redirect to portal + return redirect(portal_url) + end + + elseif ngx.var.request_method == "POST" then + + if string.starts(ngx.var.http_referer, portal_url) then + -- CSRF protection + return do_login() + else + -- Redirect to portal + return redirect(portal_url) + end +end + + +-- Skipped urls +-- i.e. http://mydomain.org/no_protection/ + +for _, url in ipairs(conf["skipped_urls"]) do + if string.starts(ngx.var.host..ngx.var.uri, url) then + return pass() + end +end + + +-- Unprotected urls +-- i.e. http://mydomain.org/no_protection+headers/ + +for _, url in ipairs(conf["unprotected_urls"]) do + if string.starts(ngx.var.host..ngx.var.uri, url) then + if check_cookie() then + set_headers(ngx.var.cookie_SSOwAuthUser) + end + return pass() + end +end + + +-- Cookie validation +-- + +if check_cookie() then + set_headers(ngx.var.cookie_SSOwAuthUser) + return pass +else + delete_cookie() +end + + +-- Login with HTTP Auth if credentials are brought +-- + +local auth_header = ngx.req.get_headers()["Authorization"] +if auth_header then + _, _, b64_cred = string.find(auth_header, "^Basic%s+(.+)$") + _, _, user, password = string.find(ngx.decode_base64(b64_cred), "^(.+):(.+)$") + if authenticate(user, password) then + set_headers(user) + return pass() + end +end + +-- Else redirect to portal +-- + +local back_url = ngx.var.scheme .. "://" .. ngx.var.http_host .. ngx.var.uri +return redirect(portal_url.."?r="..ngx.encode_base64(back_url)) + + -- -- Useful functions -- +function read_file(file) + local f = io.open(file, "rb") + if not f then return false + local content = f:read("*all") + f:close() + return content +end + function is_in_table (t, v) for key, value in ipairs(t) do if value == v then return key end @@ -136,34 +257,57 @@ function set_headers (user) end -function display_login_form () - local args = ngx.req.get_uri_args() - ngx.req.set_header("Cache-Control", "no-cache") +-- Yo dawg +function serve(uri) + rel_path = string.gsub(uri, conf["portal_path"], "/") - if args.action and args.action == 'logout' then - if check_cookie() then - local redirect_url = portal_url - if args.r then - redirect_url = ngx.decode_base64(args.r) - end - local user = ngx.var.cookie_SSOwAuthUser - logout[user] = {} - logout[user]["redirect_url"] = redirect_url - logout[user]["domains"] = {} - for _, value in ipairs(conf["domains"]) do - table.insert(logout[user]["domains"], value) - end - return redirect(ngx.var.scheme.."://"..ngx.var.http_host.."/?ssologout="..user) - end + -- Load login.html as index + if rel_path == "/" then + rel_path = "/login.html" end - -- Set redirect - if args.r then - set_redirect_cookie(ngx.decode_base64(args.r)) - ngx.header["Set-Cookie"] = cookies + content = read_file(script_path.."portal/"..rel_path) + if not content then + ngx.exit(ngx.HTTP_NOT_FOUND) end + + -- Extract file extension + _, file, ext = string.match(uri, "(.-)([^\\/]-%.?([^%.\\/]*))$") + + -- Associate to MIME type + mime_types = { + html = "text/html", + js = "text/javascript", + css = "text/css", + gif = "image/gif", + jpg = "image/jpeg", + png = "image/png", + svg = "image/svg+xml", + ico = "image/vnd.microsoft.icon", + } + + -- Set Content-Type + if mime_types[ext] then + ngx.header["Content-Type"] = mime_types[ext] + else + ngx.header["Content-Type"] = "text/plain" + end + + -- Render as mustache + if ext == "html" then + data = get_data_for(file) + content = string.gsub(hige.render(content, data), "(%d+)", "") + end + ngx.header["Cache-Control"] = "no-cache" - return + ngx.say(content) + ngx.exit(ngx.HTTP_OK) +end + +function get_data_for(view) + if view == "login.html" then + return { flash = "Meh" } + end end function do_login () @@ -171,27 +315,45 @@ function do_login () local args = ngx.req.get_post_args() local uri_args = ngx.req.get_uri_args() - if string.starts(ngx.var.http_referer, portal_url) then + if authenticate(args.user, args.password) then ngx.status = ngx.HTTP_CREATED - - if authenticate(args.user, args.password) then - local redirect_url = ngx.var.cookie_SSOwAuthRedirect - if uri_args.r then - redirect_url = ngx.decode_base64(uri_args.r) - end - if not redirect_url then redirect_url = portal_url end - login[args.user] = {} - login[args.user]["redirect_url"] = redirect_url - login[args.user]["domains"] = {} - for _, value in ipairs(conf["domains"]) do - table.insert(login[args.user]["domains"], value) - end - - -- Connect to the first domain (self) - return redirect(ngx.var.scheme.."://"..ngx.var.http_host.."/?ssologin="..args.user) + local redirect_url = ngx.var.cookie_SSOwAuthRedirect + if uri_args.r then + redirect_url = ngx.decode_base64(uri_args.r) end + if not redirect_url then redirect_url = portal_url end + login[args.user] = {} + login[args.user]["redirect_url"] = redirect_url + login[args.user]["domains"] = {} + for _, value in ipairs(conf["domains"]) do + table.insert(login[args.user]["domains"], value) + end + + -- Connect to the first domain (self) + return redirect(ngx.var.scheme.."://"..ngx.var.http_host.."/?ssologin="..args.user) + else + ngx.status = ngx.HTTP_UNAUTHORIZED + return redirect(portal_url) end - return redirect(portal_url) +end + +function do_logout() + local args = ngx.req.get_uri_args() + ngx.req.set_header("Cache-Control", "no-cache") + if check_cookie() then + local redirect_url = portal_url + if args.r then + redirect_url = ngx.decode_base64(args.r) + end + local user = ngx.var.cookie_SSOwAuthUser + logout[user] = {} + logout[user]["redirect_url"] = redirect_url + logout[user]["domains"] = {} + for _, value in ipairs(conf["domains"]) do + table.insert(logout[user]["domains"], value) + end + return redirect(ngx.var.scheme.."://"..ngx.var.http_host.."/?ssologout="..user) + end end function login_walkthrough (user) @@ -245,74 +407,3 @@ function pass () return end --- --- Routing --- - --- Logging in/out -if ngx.var.request_method == "GET" then - local args = ngx.req.get_uri_args() - - local user = args.ssologin - if user and login[user] then - return login_walkthrough(user) - end - - user = args.ssologout - if user and logout[user] then - return logout_walkthrough(user) - end -end - --- Portal -if ngx.var.host == conf["portal_domain"] - and string.starts(ngx.var.uri, conf["portal_path"]) -then - if ngx.var.request_method == "GET" then - return display_login_form() - elseif ngx.var.request_method == "POST" then - return do_login() - end -end - --- Skipped urls -for _, url in ipairs(conf["skipped_urls"]) do - if string.starts(ngx.var.host..ngx.var.uri, url) then - return pass - end -end - --- Unprotected urls -for _, url in ipairs(conf["unprotected_urls"]) do - if string.starts(ngx.var.host..ngx.var.uri, url) then - if check_cookie() then - set_headers(ngx.var.cookie_SSOwAuthUser) - end - return pass - end -end - --- Cookie validation -if check_cookie() then - set_headers(ngx.var.cookie_SSOwAuthUser) - return pass -else - delete_cookie() -end - - --- Login with HTTP Auth if credentials are brought -local auth_header = ngx.req.get_headers()["Authorization"] -if auth_header then - _, _, b64_cred = string.find(auth_header, "^Basic%s+(.+)$") - _, _, user, password = string.find(ngx.decode_base64(b64_cred), "^(.+):(.+)$") - if authenticate(user, password) then - set_headers(user) - return pass - end -end - --- Else redirect to portal -local back_url = ngx.var.scheme .. "://" .. ngx.var.http_host .. ngx.var.uri --- From another domain -return redirect(portal_url.."?r="..ngx.encode_base64(back_url)) diff --git a/conf.json b/conf.json index 05ea178..203d6e0 100644 --- a/conf.json +++ b/conf.json @@ -1,8 +1,7 @@ { + "portal_scheme": "https", "portal_domain": "mydomain.com", "portal_path": "/ssowat/", - "portal_port": "443", - "portal_scheme": "https", "domains": [ "mydomain.com", "myotherdomain.com" diff --git a/hige.lua b/hige.lua new file mode 100644 index 0000000..a648462 --- /dev/null +++ b/hige.lua @@ -0,0 +1,151 @@ +module('hige', package.seeall) + +local tags = { open = '{{', close = '}}' } +local r = {} + +local function merge_environment(...) + local numargs, out = select('#', ...), {} + for i = 1, numargs do + local t = select(i, ...) + if type(t) == 'table' then + for k, v in pairs(t) do + if (type(v) == 'function') then + out[k] = setfenv(v, setmetatable(out, { + __index = getmetatable(getfenv()).__index + })) + else + out[k] = v + end + end + end + end + return out +end + +local function escape(str) + return str:gsub('[&"<>\]', function(c) + if c == '&' then return '&' + elseif c == '"' then return '\"' + elseif c == '\\' then return '\\\\' + elseif c == '<' then return '<' + elseif c == '>' then return '>' + else return c end + end) +end + +local function find(name, context) + local value = context[name] + if value == nil then + return '' + elseif type(value) == 'function' then + return merge_environment(context, value)[name]() + else + return value + end +end + +local operators = { + -- comments + ['!'] = function(state, outer, name, context) + return state.tag_open .. '!' .. outer .. state.tag_close + end, + -- the triple hige is unescaped + ['{'] = function(state, outer, name, context) + return find(name, context) + end, + -- render partial + ['<'] = function(state, outer, name, context) + return r.partial(state, name, context) + end, + -- set new delimiters + ['='] = function(state, outer, name, context) + -- FIXME! + error('setting new delimiters in the template is currently broken') + --[[ + return name:gsub('^(.-)%s+(.-)$', function(open, close) + state.tag_open, state.tag_close = open, close + return '' + end) + ]] + end, +} + +function r.partial(state, name, context) + local target_mt = setmetatable(context, { __index = state.lookup_env }) + local target_name = setfenv(loadstring('return ' .. name), target_mt)() + local target_type = type(target_name) + + if target_type == 'string' then + return r.render(state, target_name, context) + elseif target_type == 'table' then + local target_template = setfenv(loadstring('return '..name..'_template'), target_mt)() + return r.render(state, target_template, merge_environment(target_name, context)) + else + error('unknown partial type "' .. tostring(name) .. '"') + end +end + +function r.tags(state, template, context) + local tag_path = state.tag_open..'([=!<{]?)(%s*([^#/]-)%s*)[=}]?%s*'..state.tag_close + + return template:gsub(tag_path, function(op, outer, name) + if operators[op] ~= nil then + return tostring(operators[op](state, outer, name, context)) + else + return escape(tostring((function() + if name ~= '.' then + return find(name, context) + else + return context + end + end)())) + end + end) +end + +function r.section(state, template, context) + for section_name in template:gmatch(state.tag_open..'#%s*([^#/]-)%s*'..state.tag_close) do + local found, value = context[section_name] ~= nil, find(section_name, context) + local section_path = state.tag_open..'#'..section_name..state.tag_close..'%s*(.*)'..state.tag_open..'/'..section_name..state.tag_close..'%s*' + + template = template:gsub(section_path, function(inner) + if found == false then return '' end + + if value == true then + return r.render(state, inner, context) + elseif type(value) == 'table' then + local output = {} + for _, row in pairs(value) do + local new_context + if type(row) == 'table' then + new_context = merge_environment(context, row) + else + new_context = row + end + table.insert(output, (r.render(state, inner, new_context))) + end + return table.concat(output) + else + return '' + end + end) + end + + return template +end + +function r.render(state, template, context) + return r.tags(state, r.section(state, template, context), context) +end + +function render(template, context, env) + if template:find(tags.open) == nil then return template end + + local state = { + lookup_env = env or _G, + tag_open = tags.open, + tag_close = tags.close, + } + + return r.render(state, template, context or {}) +end diff --git a/init.lua b/init.lua index 8ba662e..2168b52 100644 --- a/init.lua +++ b/init.lua @@ -1,7 +1,14 @@ +-- Remove prepending '@' & trailing 'init.lua' +script_path = string.sub(debug.getinfo(1).source, 2, -9) + +-- Include local libs in package.path +package.path = package.path .. ";"..script_path.."?.lua" + -- Load libraries json = require "json" lualdap = require "lualdap" math = require "math" +hige = require "hige" -- Set random key math.randomseed(os.time())