-- Module table. Public functions are attached to it further down and it is
-- returned at the end of the file.
local M = {}

-- Words that mathlite renders as the LaTeX command of the same name followed
-- by a space (e.g. "alpha" -> "\alpha "). Besides the Greek letters (lower
-- case, and the upper-case ones that differ from Latin capitals) the set
-- also holds "partial" and "nabla".
local GREEK = {
  alpha=true, beta=true, gamma=true, delta=true, epsilon=true, zeta=true,
  eta=true, theta=true, iota=true, kappa=true, lambda=true, mu=true, nu=true,
  xi=true, pi=true, rho=true, sigma=true, tau=true, upsilon=true, phi=true,
  chi=true, psi=true, omega=true,
  Gamma=true, Delta=true, Theta=true, Lambda=true, Xi=true, Pi=true,
  Sigma=true, Upsilon=true, Phi=true, Psi=true, Omega=true,
  partial=true, nabla=true,
}

-- Big operators written as `name(lower, upper) body`; the value is the LaTeX
-- command. mathlite checks INTOP before BIGOP, so the `int` entry here is
-- shadowed by INTOP.int whenever both apply.
local BIGOP = { sum="\\sum", prod="\\prod", int="\\int" }

-- Operators whose parenthesised argument is typeset underneath via \limits.
local UNDEROP = { lim="\\lim" }

-- Function names rendered upright as the LaTeX command of the same name
-- (e.g. "sin" -> "\sin").
local FUNC = {
  sin=true, cos=true, tan=true, cot=true, sec=true, csc=true,
  arcsin=true, arccos=true, arctan=true,
  sinh=true, cosh=true, tanh=true, coth=true,
  ln=true, log=true, exp=true, det=true, dim=true, gcd=true,
  deg=true, ker=true, arg=true, hom=true, max=true, min=true, sup=true,
}

-- Integral flavours written as `name(spec) body`.
--   pre     : LaTeX prefix for the operator; "" means the integral signs are
--             built per domain from the spec (plain `int`).
--   contour : when true the whole spec becomes the subscript of `pre`.
--   multi   : not read by the code visible in this file.
local INTOP = {
  int        = { pre="",                multi=true,  contour=false },
  contourint = { pre="\\oint",          multi=false, contour=true  },
  pvint      = { pre="\\mathrm{p.v.}\\!\\int", multi=false, contour=false },
  meanint    = { pre="\\fint",          multi=false, contour=false },
}

--- Typeset the "d" of an ordinary derivative upright.
-- When `num` looks like a differential ("d", "dy", "d\theta", "d^{2}y", ...)
-- and `den` looks like one too ("dx", "d\theta", ...), the leading "d" of each
-- is replaced by "\mathrm{d}". Every other pair, including \partial
-- quotients, is returned unchanged.
-- @param num string  LaTeX for the numerator
-- @param den string  LaTeX for the denominator
-- @return string, string  the (possibly rewritten) numerator and denominator
function M.differential(num, den)
  -- A differential is "d" immediately followed by a variable: either a plain
  -- letter or a backslash command such as \theta.
  local function is_differential(x)
    return x:match("^d%a") or x:match("^d\\%a")
  end

  local function roman_d(x)
    -- Parentheses drop gsub's second result (the substitution count).
    return (x:gsub("^d", "\\mathrm{d}"))
  end

  local num_is_d = num == "d" or is_differential(num) or num:match("^d%^")
  if num_is_d and is_differential(den) then
    return roman_d(num), roman_d(den)
  end
  return num, den
end

--- Index of the delimiter that closes the group opening at `start`.
-- `str:sub(start, start)` is expected to be `open`; nested groups are
-- honoured. Returns `#str + 1` when the group is never closed, so callers
-- treat the remainder of the string as the group body.
local function find_group_end(str, start, open, close)
  local depth, j, len = 0, start, #str
  while j <= len do
    local ch = str:sub(j, j)
    if ch == open then
      depth = depth + 1
    elseif ch == close then
      depth = depth - 1
      if depth == 0 then break end
    end
    j = j + 1
  end
  return j
end

--- Remove one pair of outer delimiters, but only when the delimiter at
-- position 1 is closed by the final character. "{a}^{b}" is therefore left
-- alone, whereas a first/last-character check would turn it into "a}^{b".
local function strip_outer(str, open, close)
  if #str >= 2 and str:sub(1, 1) == open
      and find_group_end(str, 1, open, close) == #str then
    return str:sub(2, -2)
  end
  return str
end

--- Strip redundant outer parentheses, then redundant outer braces.
local function strip_group(str)
  return strip_outer(strip_outer(str, "(", ")"), "{", "}")
end

--- Trim leading and trailing whitespace (returns exactly one value).
local function trim(str)
  return (str:gsub("^%s+", ""):gsub("%s+$", ""))
end

--- Split `text` into an operator body and a tail starting at "=".
-- Only an "=" outside parentheses/braces counts, and the "=" of "<=", ">="
-- and "!=" is skipped, so nested operator limits such as "(j=1,m)" and
-- two-character relations stay intact. Returns `text, ""` when no such "="
-- exists.
local function split_at_equals(text)
  local depth = 0
  for i = 1, #text do
    local ch = text:sub(i, i)
    if ch == "(" or ch == "{" then
      depth = depth + 1
    elseif ch == ")" or ch == "}" then
      if depth > 0 then depth = depth - 1 end
    elseif ch == "=" and depth == 0 then
      local prev = text:sub(i - 1, i - 1)
      if prev ~= "<" and prev ~= ">" and prev ~= "!" then
        return text:sub(1, i - 1), text:sub(i)
      end
    end
  end
  return text, ""
end

-- Matches an integration domain of the form "var = lo, hi".
local DOMAIN_RANGE = "^(%S+)%s*=%s*(.-)%s*,%s*(.-)%s*$"

--- Build the LaTeX for an integral.
-- @param op   entry of INTOP describing the integral flavour
-- @param spec converted text of the parenthesised domain specification
-- @param body converted integrand
-- @param tail "" or " " followed by the converted "= ..." remainder
local function render_integral(op, spec, body, tail)
  local head, diffs
  if op.contour then
    -- The whole spec names the contour; integrate over it if it is a single
    -- lower-case letter, otherwise over z.
    head = op.pre .. "_{" .. spec .. "}"
    diffs = { "\\,\\mathrm{d}" .. (spec:match("^%l$") and spec or "z") }
  elseif op.pre ~= "" then
    local var, lo, hi = spec:match(DOMAIN_RANGE)
    if var then
      head = op.pre .. "_{" .. lo .. "}^{" .. hi .. "}"
      diffs = { "\\,\\mathrm{d}" .. var }
    else
      local single = spec:match("^%l$")
      head = op.pre .. (single and "" or ("_{" .. spec .. "}"))
      diffs = { "\\,\\mathrm{d}" .. (single and spec or "\\omega") }
    end
  else
    -- Plain `int`: one integral sign per ";"-separated domain. Differentials
    -- are prepended so the innermost variable is written first.
    local symbols = {}
    diffs = {}
    for piece in (spec .. ";"):gmatch("(.-);") do
      local dom = trim(piece)
      local var, lo, hi = dom:match(DOMAIN_RANGE)
      if var then
        symbols[#symbols + 1] = "\\int_{" .. lo .. "}^{" .. hi .. "}"
        table.insert(diffs, 1, "\\,\\mathrm{d}" .. var)
      elseif dom:match("^%l$") then
        symbols[#symbols + 1] = "\\int"
        table.insert(diffs, 1, "\\,\\mathrm{d}" .. dom)
      else
        symbols[#symbols + 1] = "\\iint_{" .. dom .. "}"
        table.insert(diffs, 1, "\\,\\mathrm{d}\\omega")
      end
    end
    head = table.concat(symbols)
  end
  return "{\\displaystyle " .. head .. " " .. body .. table.concat(diffs)
    .. "}" .. tail
end

--- Convert a lightweight plain-text math expression to LaTeX.
-- Recognised input includes:
--   * `a/b` (chained left to right) -> \frac, with M.differential applied to
--     the first quotient;
--   * `^` / `_` scripts: a `{...}` argument is copied verbatim, anything else
--     is converted and wrapped in braces (an optional leading sign is kept);
--   * `*`, `+-`, `<=`, `>=`, `!=` -> \times, \pm, \leq, \geq, \neq;
--   * `sqrt(..)`, `abs(..)`, `norm(..)`, `vec(..)`;
--   * `lim(..)`, `sum(..)`, `prod(..)`, `int(..)` and the other INTOP names,
--     each of which takes the rest of the input (up to a top-level "=") as
--     its body;
--   * names in FUNC and GREEK, and `inf` for \infty;
--   * backslash commands, which pass through unchanged;
--   * numbers, whose "." is replaced by M.decsep when that is set and is not
--     ".".
-- Whitespace between tokens is dropped.
-- @param s string  expression in the lightweight notation
-- @return string   LaTeX source
function M.mathlite(s)
  local n = #s
  local pos = 1

  local function skipws()
    while pos <= n and s:sub(pos, pos):match("%s") do pos = pos + 1 end
  end

  local read_atom

  -- With `pos` on "(", consume the group and return its converted contents
  -- without the surrounding parentheses.
  local function read_paren_group()
    local close = find_group_end(s, pos, "(", ")")
    local inner = s:sub(pos + 1, close - 1)
    pos = close + 1
    return M.mathlite(inner)
  end

  -- Read one atom; a parenthesised atom is returned without its parentheses.
  local function read_bare_atom()
    skipws()
    if s:sub(pos, pos) == "(" then return read_paren_group() end
    return read_atom()
  end

  -- Read the argument of a script whose "^" or "_" was just consumed and
  -- return it including braces.
  local function read_script_arg()
    skipws()
    if s:sub(pos, pos) == "{" then
      -- Braced arguments are taken verbatim from the input.
      local close = find_group_end(s, pos, "{", "}")
      local raw = s:sub(pos, close)
      pos = close + 1
      return raw
    end
    local sign = ""
    local sc = s:sub(pos, pos)
    if sc == "-" or sc == "+" then
      sign = sc
      pos = pos + 1
    end
    return "{" .. sign .. read_bare_atom() .. "}"
  end

  -- Append every script that directly follows to `base`.
  local function read_scripts(base)
    while true do
      skipws()
      local c = s:sub(pos, pos)
      if c ~= "^" and c ~= "_" then break end
      pos = pos + 1
      base = base .. c .. read_script_arg()
    end
    return base
  end

  -- Consume the rest of the input as the body of a big operator. Returns the
  -- converted body and either "" or " " followed by the converted "= ..."
  -- tail.
  local function read_body_and_tail()
    local body_raw, tail_raw = split_at_equals(s:sub(pos))
    pos = n + 1
    local body = M.mathlite(trim(body_raw))
    local tail = ""
    if tail_raw ~= "" then tail = " " .. M.mathlite(tail_raw) end
    return body, tail
  end

  -- Read an alphabetic word starting at `pos` and translate it.
  local function read_word()
    -- Anchored match from `pos`; avoids copying the rest of the string.
    local word = s:match("^%a+", pos)
    local after = pos + #word

    if s:sub(after, after) == "(" then
      if word == "sqrt" then
        pos = after
        return "\\sqrt{" .. read_paren_group() .. "}"
      elseif word == "abs" then
        pos = after
        return "\\left|" .. read_paren_group() .. "\\right|"
      elseif word == "norm" then
        pos = after
        return "\\left\\|" .. read_paren_group() .. "\\right\\|"
      elseif word == "vec" then
        pos = after
        return "\\overrightarrow{" .. read_paren_group() .. "}"
      elseif UNDEROP[word] then
        pos = after
        local grp = read_paren_group():gsub("%->", "\\to ")
        local body, tail = read_body_and_tail()
        return "{\\displaystyle " .. UNDEROP[word] .. "\\limits_{"
          .. trim(grp) .. "} " .. body .. "}" .. tail
      elseif INTOP[word] then
        -- Checked before BIGOP, so `int(...)` is always handled here.
        pos = after
        local spec = read_paren_group()
        local body, tail = read_body_and_tail()
        return render_integral(INTOP[word], spec, body, tail)
      elseif BIGOP[word] then
        pos = after
        local grp = read_paren_group()
        -- "lower, upper" split at the first comma; no comma means the whole
        -- group is the subscript.
        local lo, hi = grp:match("^(.-),(.*)$")
        local sub
        if lo then
          sub = BIGOP[word] .. "_{" .. trim(lo) .. "}^{" .. trim(hi) .. "}"
        else
          sub = BIGOP[word] .. "_{" .. grp .. "}"
        end
        local body, tail = read_body_and_tail()
        return "{\\displaystyle " .. sub .. " " .. body .. "}" .. tail
      end
    end

    pos = after
    if FUNC[word] then
      if s:sub(pos, pos) == "(" then
        return "\\" .. word .. read_atom()
      end
      return "\\" .. word .. " "
    elseif word == "inf" then
      return "\\infty "
    elseif GREEK[word] then
      return "\\" .. word .. " "
    end
    return word
  end

  read_atom = function()
    skipws()
    if pos > n then return "" end
    local c = s:sub(pos, pos)

    if c == "(" then
      return "(" .. read_paren_group() .. ")"
    end

    if c == "{" then
      local close = find_group_end(s, pos, "{", "}")
      local inner = s:sub(pos + 1, close - 1)
      pos = close + 1
      return "{" .. M.mathlite(inner) .. "}"
    end

    if c == "\\" then
      -- A command is a backslash plus a run of letters, or a backslash plus
      -- one other character.
      local j = pos + 1
      if s:sub(j, j):match("%a") then
        while j <= n and s:sub(j, j):match("%a") do j = j + 1 end
      else
        j = pos + 2
      end
      local cmd = s:sub(pos, j - 1)
      pos = j
      return cmd
    end

    if c:match("%a") then
      return read_word()
    end

    if c:match("%d") then
      local num = s:match("^[%d.]+", pos)
      pos = pos + #num
      if M.decsep and M.decsep ~= "." then
        num = num:gsub("%.", M.decsep)
      end
      return num
    end

    -- Any other character passes through unchanged.
    pos = pos + 1
    return c
  end

  local out = {}
  while pos <= n do
    skipws()
    if pos > n then break end
    local c = s:sub(pos, pos)
    local nxt = s:sub(pos + 1, pos + 1)

    if c == "*" then
      out[#out + 1] = " \\times "
      pos = pos + 1
    elseif c == "+" and nxt == "-" then
      out[#out + 1] = " \\pm "
      pos = pos + 2
    elseif c == "<" and nxt == "=" then
      out[#out + 1] = " \\leq "
      pos = pos + 2
    elseif c == ">" and nxt == "=" then
      out[#out + 1] = " \\geq "
      pos = pos + 2
    elseif c == "!" and nxt == "=" then
      out[#out + 1] = " \\neq "
      pos = pos + 2
    elseif c == "/" then
      -- The numerator is the token emitted just before the slash.
      local num = strip_group(table.remove(out) or "")
      pos = pos + 1
      local den = strip_group(read_scripts(read_atom()))
      num, den = M.differential(num, den)
      local frac = "\\frac{" .. num .. "}{" .. den .. "}"
      skipws()
      -- a/b/c nests to the left: \frac{\frac{a}{b}}{c}.
      while s:sub(pos, pos) == "/" do
        pos = pos + 1
        local operand = strip_group(read_scripts(read_atom()))
        frac = "\\frac{" .. frac .. "}{" .. operand .. "}"
        skipws()
      end
      out[#out + 1] = frac
    elseif c == "^" or c == "_" then
      -- A script that does not directly follow an atom (e.g. after a
      -- fraction) is emitted as its own token.
      pos = pos + 1
      out[#out + 1] = c .. read_script_arg()
    elseif c == "+" or c == "-" or c == "=" or c == "<" or c == ">"
        or c == "," or c == ")" then
      out[#out + 1] = c
      pos = pos + 1
    else
      out[#out + 1] = read_scripts(read_atom())
    end
  end
  return table.concat(out)
end

-- Export the module table.
return M

