1
0
Fork 0
mirror of https://github.com/notwa/mm synced 2024-11-05 04:29:03 -08:00
mm/patch/argparse.lua

1181 lines
28 KiB
Lua
Raw Normal View History

-- The MIT License (MIT)
-- Copyright (c) 2013 - 2015 Peter Melnichenko
-- Permission is hereby granted, free of charge, to any person obtaining a copy of
-- this software and associated documentation files (the "Software"), to deal in
-- the Software without restriction, including without limitation the rights to
-- use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
-- the Software, and to permit persons to whom the Software is furnished to do so,
-- subject to the following conditions:
-- The above copyright notice and this permission notice shall be included in all
-- copies or substantial portions of the Software.
-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
-- FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
-- COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
-- IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-- CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
local function deep_update(t1, t2)
for k, v in pairs(t2) do
if type(v) == "table" then
v = deep_update({}, v)
end
t1[k] = v
end
return t1
end
-- A property is a tuple {name, callback}.
-- properties.args is number of properties that can be set as arguments
-- when calling an object.
local function class(prototype, properties, parent)
-- Class is the metatable of its instances.
local cl = {}
cl.__index = cl
if parent then
cl.__prototype = deep_update(deep_update({}, parent.__prototype), prototype)
else
cl.__prototype = prototype
end
if properties then
local names = {}
-- Create setter methods and fill set of property names.
for _, property in ipairs(properties) do
local name, callback = property[1], property[2]
cl[name] = function(self, value)
if not callback(self, value) then
self["_" .. name] = value
end
return self
end
names[name] = true
end
function cl.__call(self, ...)
-- When calling an object, if the first argument is a table,
-- interpret keys as property names, else delegate arguments
-- to corresponding setters in order.
if type((...)) == "table" then
for name, value in pairs((...)) do
if names[name] then
self[name](self, value)
end
end
else
local nargs = select("#", ...)
for i, property in ipairs(properties) do
if i > nargs or i > properties.args then
break
end
local arg = select(i, ...)
if arg ~= nil then
self[property[1]](self, arg)
end
end
end
return self
end
end
-- If indexing class fails, fallback to its parent.
local class_metatable = {}
class_metatable.__index = parent
function class_metatable.__call(self, ...)
-- Calling a class returns its instance.
-- Arguments are delegated to the instance.
local object = deep_update({}, self.__prototype)
setmetatable(object, self)
return object(...)
end
return setmetatable(cl, class_metatable)
end
local function typecheck(name, types, value)
for _, type_ in ipairs(types) do
if type(value) == type_ then
return true
end
end
error(("bad property '%s' (%s expected, got %s)"):format(name, table.concat(types, " or "), type(value)))
end
local function typechecked(name, ...)
local types = {...}
return {name, function(_, value) typecheck(name, types, value) end}
end
local multiname = {"name", function(self, value)
typecheck("name", {"string"}, value)
for alias in value:gmatch("%S+") do
self._name = self._name or alias
table.insert(self._aliases, alias)
end
-- Do not set _name as with other properties.
return true
end}
local function parse_boundaries(str)
if tonumber(str) then
return tonumber(str), tonumber(str)
end
if str == "*" then
return 0, math.huge
end
if str == "+" then
return 1, math.huge
end
if str == "?" then
return 0, 1
end
if str:match "^%d+%-%d+$" then
local min, max = str:match "^(%d+)%-(%d+)$"
return tonumber(min), tonumber(max)
end
if str:match "^%d+%+$" then
local min = str:match "^(%d+)%+$"
return tonumber(min), math.huge
end
end
local function boundaries(name)
return {name, function(self, value)
typecheck(name, {"number", "string"}, value)
local min, max = parse_boundaries(value)
if not min then
error(("bad property '%s'"):format(name))
end
self["_min" .. name], self["_max" .. name] = min, max
end}
end
local actions = {}
local option_action = {"action", function(_, value)
typecheck("action", {"function", "string"}, value)
if type(value) == "string" and not actions[value] then
error(("unknown action '%s'"):format(value))
end
end}
local option_init = {"init", function(self)
self._has_init = true
end}
local option_default = {"default", function(self, value)
if type(value) ~= "string" then
self._init = value
self._has_init = true
return true
end
end}
local add_help = {"add_help", function(self, value)
typecheck("add_help", {"boolean", "string", "table"}, value)
if self._has_help then
table.remove(self._options)
self._has_help = false
end
if value then
local help = self:flag()
:description "Show this help message and exit."
:action(function()
print(self:get_help())
os.exit(0)
end)
if value ~= true then
help = help(value)
end
if not help._name then
help "-h" "--help"
end
self._has_help = true
end
end}
local Parser = class({
_arguments = {},
_options = {},
_commands = {},
_mutexes = {},
_require_command = true,
_handle_options = true
}, {
args = 3,
typechecked("name", "string"),
typechecked("description", "string"),
typechecked("epilog", "string"),
typechecked("usage", "string"),
typechecked("help", "string"),
typechecked("require_command", "boolean"),
typechecked("handle_options", "boolean"),
typechecked("action", "function"),
typechecked("command_target", "string"),
add_help
})
local Command = class({
_aliases = {}
}, {
args = 3,
multiname,
typechecked("description", "string"),
typechecked("epilog", "string"),
typechecked("target", "string"),
typechecked("usage", "string"),
typechecked("help", "string"),
typechecked("require_command", "boolean"),
typechecked("handle_options", "boolean"),
typechecked("action", "function"),
typechecked("command_target", "string"),
add_help
}, Parser)
local Argument = class({
_minargs = 1,
_maxargs = 1,
_mincount = 1,
_maxcount = 1,
_defmode = "unused",
_show_default = true
}, {
args = 5,
typechecked("name", "string"),
typechecked("description", "string"),
option_default,
typechecked("convert", "function", "table"),
boundaries("args"),
typechecked("target", "string"),
typechecked("defmode", "string"),
typechecked("show_default", "boolean"),
typechecked("argname", "string", "table"),
option_action,
option_init
})
local Option = class({
_aliases = {},
_mincount = 0,
_overwrite = true
}, {
args = 6,
multiname,
typechecked("description", "string"),
option_default,
typechecked("convert", "function", "table"),
boundaries("args"),
boundaries("count"),
typechecked("target", "string"),
typechecked("defmode", "string"),
typechecked("show_default", "boolean"),
typechecked("overwrite", "boolean"),
typechecked("argname", "string", "table"),
option_action,
option_init
}, Argument)
function Argument:_get_argument_list()
local buf = {}
local i = 1
while i <= math.min(self._minargs, 3) do
local argname = self:_get_argname(i)
if self._default and self._defmode:find "a" then
argname = "[" .. argname .. "]"
end
table.insert(buf, argname)
i = i+1
end
while i <= math.min(self._maxargs, 3) do
table.insert(buf, "[" .. self:_get_argname(i) .. "]")
i = i+1
if self._maxargs == math.huge then
break
end
end
if i < self._maxargs then
table.insert(buf, "...")
end
return buf
end
function Argument:_get_usage()
local usage = table.concat(self:_get_argument_list(), " ")
if self._default and self._defmode:find "u" then
if self._maxargs > 1 or (self._minargs == 1 and not self._defmode:find "a") then
usage = "[" .. usage .. "]"
end
end
return usage
end
function actions.store_true(result, target)
result[target] = true
end
function actions.store_false(result, target)
result[target] = false
end
function actions.store(result, target, argument)
result[target] = argument
end
function actions.count(result, target, _, overwrite)
if not overwrite then
result[target] = result[target] + 1
end
end
function actions.append(result, target, argument, overwrite)
result[target] = result[target] or {}
table.insert(result[target], argument)
if overwrite then
table.remove(result[target], 1)
end
end
function actions.concat(result, target, arguments, overwrite)
if overwrite then
error("'concat' action can't handle too many invocations")
end
result[target] = result[target] or {}
for _, argument in ipairs(arguments) do
table.insert(result[target], argument)
end
end
function Argument:_get_action()
local action, init
if self._maxcount == 1 then
if self._maxargs == 0 then
action, init = "store_true", nil
else
action, init = "store", nil
end
else
if self._maxargs == 0 then
action, init = "count", 0
else
action, init = "append", {}
end
end
if self._action then
action = self._action
end
if self._has_init then
init = self._init
end
if type(action) == "string" then
action = actions[action]
end
return action, init
end
-- Returns placeholder for `narg`-th argument.
function Argument:_get_argname(narg)
local argname = self._argname or self:_get_default_argname()
if type(argname) == "table" then
return argname[narg]
else
return argname
end
end
function Argument:_get_default_argname()
return "<" .. self._name .. ">"
end
function Option:_get_default_argname()
return "<" .. self:_get_default_target() .. ">"
end
-- Returns label to be shown in the help message.
function Argument:_get_label()
return self._name
end
function Option:_get_label()
local variants = {}
local argument_list = self:_get_argument_list()
table.insert(argument_list, 1, nil)
for _, alias in ipairs(self._aliases) do
argument_list[1] = alias
table.insert(variants, table.concat(argument_list, " "))
end
return table.concat(variants, ", ")
end
function Command:_get_label()
return table.concat(self._aliases, ", ")
end
function Argument:_get_description()
if self._default and self._show_default then
if self._description then
return ("%s (default: %s)"):format(self._description, self._default)
else
return ("default: %s"):format(self._default)
end
else
return self._description or ""
end
end
function Command:_get_description()
return self._description or ""
end
function Option:_get_usage()
local usage = self:_get_argument_list()
table.insert(usage, 1, self._name)
usage = table.concat(usage, " ")
if self._mincount == 0 or self._default then
usage = "[" .. usage .. "]"
end
return usage
end
function Argument:_get_default_target()
return self._name
end
function Option:_get_default_target()
local res
for _, alias in ipairs(self._aliases) do
if alias:sub(1, 1) == alias:sub(2, 2) then
res = alias:sub(3)
break
end
end
res = res or self._name:sub(2)
return (res:gsub("-", "_"))
end
function Option:_is_vararg()
return self._maxargs ~= self._minargs
end
function Parser:_get_fullname()
local parent = self._parent
local buf = {self._name}
while parent do
table.insert(buf, 1, parent._name)
parent = parent._parent
end
return table.concat(buf, " ")
end
function Parser:_update_charset(charset)
charset = charset or {}
for _, command in ipairs(self._commands) do
command:_update_charset(charset)
end
for _, option in ipairs(self._options) do
for _, alias in ipairs(option._aliases) do
charset[alias:sub(1, 1)] = true
end
end
return charset
end
function Parser:argument(...)
local argument = Argument(...)
table.insert(self._arguments, argument)
return argument
end
function Parser:option(...)
local option = Option(...)
if self._has_help then
table.insert(self._options, #self._options, option)
else
table.insert(self._options, option)
end
return option
end
function Parser:flag(...)
return self:option():args(0)(...)
end
function Parser:command(...)
local command = Command():add_help(true)(...)
command._parent = self
table.insert(self._commands, command)
return command
end
function Parser:mutex(...)
local options = {...}
for i, option in ipairs(options) do
assert(getmetatable(option) == Option, ("bad argument #%d to 'mutex' (Option expected)"):format(i))
end
table.insert(self._mutexes, options)
return self
end
local max_usage_width = 70
local usage_welcome = "Usage: "
function Parser:get_usage()
if self._usage then
return self._usage
end
local lines = {usage_welcome .. self:_get_fullname()}
local function add(s)
if #lines[#lines]+1+#s <= max_usage_width then
lines[#lines] = lines[#lines] .. " " .. s
else
lines[#lines+1] = (" "):rep(#usage_welcome) .. s
end
end
-- This can definitely be refactored into something cleaner
local mutex_options = {}
local vararg_mutexes = {}
-- First, put mutexes which do not contain vararg options and remember those which do
for _, mutex in ipairs(self._mutexes) do
local buf = {}
local is_vararg = false
for _, option in ipairs(mutex) do
if option:_is_vararg() then
is_vararg = true
end
table.insert(buf, option:_get_usage())
mutex_options[option] = true
end
local repr = "(" .. table.concat(buf, " | ") .. ")"
if is_vararg then
table.insert(vararg_mutexes, repr)
else
add(repr)
end
end
-- Second, put regular options
for _, option in ipairs(self._options) do
if not mutex_options[option] and not option:_is_vararg() then
add(option:_get_usage())
end
end
-- Put positional arguments
for _, argument in ipairs(self._arguments) do
add(argument:_get_usage())
end
-- Put mutexes containing vararg options
for _, mutex_repr in ipairs(vararg_mutexes) do
add(mutex_repr)
end
for _, option in ipairs(self._options) do
if not mutex_options[option] and option:_is_vararg() then
add(option:_get_usage())
end
end
if #self._commands > 0 then
if self._require_command then
add("<command>")
else
add("[<command>]")
end
add("...")
end
return table.concat(lines, "\n")
end
local margin_len = 3
local margin_len2 = 25
local margin = (" "):rep(margin_len)
local margin2 = (" "):rep(margin_len2)
local function make_two_columns(s1, s2)
if s2 == "" then
return margin .. s1
end
s2 = s2:gsub("\n", "\n" .. margin2)
if #s1 < (margin_len2-margin_len) then
return margin .. s1 .. (" "):rep(margin_len2-margin_len-#s1) .. s2
else
return margin .. s1 .. "\n" .. margin2 .. s2
end
end
function Parser:get_help()
if self._help then
return self._help
end
local blocks = {self:get_usage()}
if self._description then
table.insert(blocks, self._description)
end
local labels = {"Arguments:", "Options:", "Commands:"}
for i, elements in ipairs{self._arguments, self._options, self._commands} do
if #elements > 0 then
local buf = {labels[i]}
for _, element in ipairs(elements) do
table.insert(buf, make_two_columns(element:_get_label(), element:_get_description()))
end
table.insert(blocks, table.concat(buf, "\n"))
end
end
if self._epilog then
table.insert(blocks, self._epilog)
end
return table.concat(blocks, "\n\n")
end
local function get_tip(context, wrong_name)
local context_pool = {}
local possible_name
local possible_names = {}
for name in pairs(context) do
if type(name) == "string" then
for i = 1, #name do
possible_name = name:sub(1, i - 1) .. name:sub(i + 1)
if not context_pool[possible_name] then
context_pool[possible_name] = {}
end
table.insert(context_pool[possible_name], name)
end
end
end
for i = 1, #wrong_name + 1 do
possible_name = wrong_name:sub(1, i - 1) .. wrong_name:sub(i + 1)
if context[possible_name] then
possible_names[possible_name] = true
elseif context_pool[possible_name] then
for _, name in ipairs(context_pool[possible_name]) do
possible_names[name] = true
end
end
end
local first = next(possible_names)
if first then
if next(possible_names, first) then
local possible_names_arr = {}
for name in pairs(possible_names) do
table.insert(possible_names_arr, "'" .. name .. "'")
end
table.sort(possible_names_arr)
return "\nDid you mean one of these: " .. table.concat(possible_names_arr, " ") .. "?"
else
return "\nDid you mean '" .. first .. "'?"
end
else
return ""
end
end
local ElementState = class({
invocations = 0
})
function ElementState:__call(state, element)
self.state = state
self.result = state.result
self.element = element
self.target = element._target or element:_get_default_target()
self.action, self.result[self.target] = element:_get_action()
return self
end
function ElementState:error(fmt, ...)
self.state:error(fmt, ...)
end
function ElementState:convert(argument)
local converter = self.element._convert
if converter then
local ok, err
if type(converter) == "function" then
ok, err = converter(argument)
else
ok = converter[argument]
end
if ok == nil then
self:error(err and "%s" or "malformed argument '%s'", err or argument)
end
argument = ok
end
return argument
end
function ElementState:default(mode)
return self.element._defmode:find(mode) and self.element._default
end
local function bound(noun, min, max, is_max)
local res = ""
if min ~= max then
res = "at " .. (is_max and "most" or "least") .. " "
end
local number = is_max and max or min
return res .. tostring(number) .. " " .. noun .. (number == 1 and "" or "s")
end
function ElementState:invoke(alias)
self.open = true
self.name = ("%s '%s'"):format(alias and "option" or "argument", alias or self.element._name)
self.overwrite = false
if self.invocations >= self.element._maxcount then
if self.element._overwrite then
self.overwrite = true
else
self:error("%s must be used %s", self.name, bound("time", self.element._mincount, self.element._maxcount, true))
end
else
self.invocations = self.invocations + 1
end
self.args = {}
if self.element._maxargs <= 0 then
self:close()
end
return self.open
end
function ElementState:pass(argument)
argument = self:convert(argument)
table.insert(self.args, argument)
if #self.args >= self.element._maxargs then
self:close()
end
return self.open
end
function ElementState:complete_invocation()
while #self.args < self.element._minargs do
self:pass(self.element._default)
end
end
function ElementState:close()
if self.open then
self.open = false
if #self.args < self.element._minargs then
if self:default("a") then
self:complete_invocation()
else
if #self.args == 0 then
if getmetatable(self.element) == Argument then
self:error("missing %s", self.name)
elseif self.element._maxargs == 1 then
self:error("%s requires an argument", self.name)
end
end
self:error("%s requires %s", self.name, bound("argument", self.element._minargs, self.element._maxargs))
end
end
local args = self.args
if self.element._maxargs <= 1 then
args = args[1]
end
if self.element._maxargs == 1 and self.element._minargs == 0 and self.element._mincount ~= self.element._maxcount then
args = self.args
end
self.action(self.result, self.target, args, self.overwrite)
end
end
local ParseState = class({
result = {},
options = {},
arguments = {},
argument_i = 1,
element_to_mutexes = {},
mutex_to_used_option = {},
command_actions = {}
})
function ParseState:__call(parser, error_handler)
self.parser = parser
self.error_handler = error_handler
self.charset = parser:_update_charset()
self:switch(parser)
return self
end
function ParseState:error(fmt, ...)
self.error_handler(self.parser, fmt:format(...))
end
function ParseState:switch(parser)
self.parser = parser
if parser._action then
table.insert(self.command_actions, {action = parser._action, name = parser._name})
end
for _, option in ipairs(parser._options) do
option = ElementState(self, option)
table.insert(self.options, option)
for _, alias in ipairs(option.element._aliases) do
self.options[alias] = option
end
end
for _, mutex in ipairs(parser._mutexes) do
for _, option in ipairs(mutex) do
if not self.element_to_mutexes[option] then
self.element_to_mutexes[option] = {}
end
table.insert(self.element_to_mutexes[option], mutex)
end
end
for _, argument in ipairs(parser._arguments) do
argument = ElementState(self, argument)
table.insert(self.arguments, argument)
argument:invoke()
end
self.handle_options = parser._handle_options
self.argument = self.arguments[self.argument_i]
self.commands = parser._commands
for _, command in ipairs(self.commands) do
for _, alias in ipairs(command._aliases) do
self.commands[alias] = command
end
end
end
function ParseState:get_option(name)
local option = self.options[name]
if not option then
self:error("unknown option '%s'%s", name, get_tip(self.options, name))
else
return option
end
end
function ParseState:get_command(name)
local command = self.commands[name]
if not command then
if #self.commands > 0 then
self:error("unknown command '%s'%s", name, get_tip(self.commands, name))
else
self:error("too many arguments")
end
else
return command
end
end
function ParseState:invoke(option, name)
self:close()
if self.element_to_mutexes[option.element] then
for _, mutex in ipairs(self.element_to_mutexes[option.element]) do
local used_option = self.mutex_to_used_option[mutex]
if used_option and used_option ~= option then
self:error("option '%s' can not be used together with %s", name, used_option.name)
else
self.mutex_to_used_option[mutex] = option
end
end
end
if option:invoke(name) then
self.option = option
end
end
function ParseState:pass(arg)
if self.option then
if not self.option:pass(arg) then
self.option = nil
end
elseif self.argument then
if not self.argument:pass(arg) then
self.argument_i = self.argument_i + 1
self.argument = self.arguments[self.argument_i]
end
else
local command = self:get_command(arg)
self.result[command._target or command._name] = true
if self.parser._command_target then
self.result[self.parser._command_target] = command._name
end
self:switch(command)
end
end
function ParseState:close()
if self.option then
self.option:close()
self.option = nil
end
end
function ParseState:finalize()
self:close()
for i = self.argument_i, #self.arguments do
local argument = self.arguments[i]
if #argument.args == 0 and argument:default("u") then
argument:complete_invocation()
else
argument:close()
end
end
if self.parser._require_command and #self.commands > 0 then
self:error("a command is required")
end
for _, option in ipairs(self.options) do
local name = option.name or ("option '%s'"):format(option.element._name)
if option.invocations == 0 then
if option:default("u") then
option:invoke(name)
option:complete_invocation()
option:close()
end
end
local mincount = option.element._mincount
if option.invocations < mincount then
if option:default("a") then
while option.invocations < mincount do
option:invoke(name)
option:close()
end
elseif option.invocations == 0 then
self:error("missing %s", name)
else
self:error("%s must be used %s", name, bound("time", mincount, option.element._maxcount))
end
end
end
for i = #self.command_actions, 1, -1 do
self.command_actions[i].action(self.result, self.command_actions[i].name)
end
end
function ParseState:parse(args)
for _, arg in ipairs(args) do
local plain = true
if self.handle_options then
local first = arg:sub(1, 1)
if self.charset[first] then
if #arg > 1 then
plain = false
if arg:sub(2, 2) == first then
if #arg == 2 then
self:close()
self.handle_options = false
else
local equals = arg:find "="
if equals then
local name = arg:sub(1, equals - 1)
local option = self:get_option(name)
if option.element._maxargs <= 0 then
self:error("option '%s' does not take arguments", name)
end
self:invoke(option, name)
self:pass(arg:sub(equals + 1))
else
local option = self:get_option(arg)
self:invoke(option, arg)
end
end
else
for i = 2, #arg do
local name = first .. arg:sub(i, i)
local option = self:get_option(name)
self:invoke(option, name)
if i ~= #arg and option.element._maxargs > 0 then
self:pass(arg:sub(i + 1))
break
end
end
end
end
end
end
if plain then
self:pass(arg)
end
end
self:finalize()
return self.result
end
function Parser:error(msg)
io.stderr:write(("%s\n\nError: %s\n"):format(self:get_usage(), msg))
os.exit(1)
end
-- Compatibility with strict.lua and other checkers:
local default_cmdline = rawget(_G, "arg") or {}
function Parser:_parse(args, error_handler)
return ParseState(self, error_handler):parse(args or default_cmdline)
end
function Parser:parse(args)
return self:_parse(args, self.error)
end
local function xpcall_error_handler(err)
return tostring(err) .. "\noriginal " .. debug.traceback("", 2):sub(2)
end
function Parser:pparse(args)
local parse_error
local ok, result = xpcall(function()
return self:_parse(args, function(_, err)
parse_error = err
error(err, 0)
end)
end, xpcall_error_handler)
if ok then
return true, result
elseif not parse_error then
error(result, 0)
else
return false, parse_error
end
end
return function(...)
return Parser(default_cmdline[0]):add_help(true)(...)
end