diff --git a/matrix/api.lua b/matrix/api.lua index fe7cb35..3e6d66a 100644 --- a/matrix/api.lua +++ b/matrix/api.lua @@ -34,7 +34,7 @@ local get_http_factory = function (http_factory) end end -- Try to import supplied HTTP client libraries, in order of preference. - local tries = http_factory and { http_factory } or { "chttp" } + local tries = http_factory and { http_factory } or { "chttp", "luasocket" } local errors = {} for i, http_factory in ipairs(tries) do local ok, factory = pcall(require, "matrix.factory." .. http_factory) diff --git a/matrix/factory/luasocket.lua b/matrix/factory/luasocket.lua new file mode 100644 index 0000000..af0ebca --- /dev/null +++ b/matrix/factory/luasocket.lua @@ -0,0 +1,85 @@ +#! /usr/bin/env lua +-- +-- luasocket.lua +-- Copyright (C) 2016 Adrian Perez +-- +-- Distributed under terms of the MIT license. +-- + +local urlescape = require "socket.url" .escape +local stringsource = require "ltn12" .source.string +local tablesink = require "ltn12" .sink.table + +local request_https = function (...) + request_https = require "ssl.https" .request + return request_https(...) +end + +local request_http = function (...) + request_http = require "socket.http" .request + return request_http(...) +end + +local function make_request(t) + if t.url:sub(1, #"https://") == "https://" then + return request_https(t) + else + return request_http(t) + end +end + +local function dict_to_query(d) + local r, i = {}, 0 + for name, value in pairs(d) do + i = i + 1 + r[i] = urlescape(name) .. "=" .. urlescape(value) + end + return table.concat(r, "&", 1, i) +end + +local httpclient = { + quote = require "socket.url" .escape, + unquote = require "socket.url" .unescape, +} +httpclient.__name = "matrix.factory.luasocket" +httpclient.__index = httpclient + +function httpclient:__tostring() + return self.__name +end + +function httpclient:request(log, method, url, query_args, body, headers) + do + local qs = dict_to_query(query_args) + if #qs > 0 then + url = url .. "?" .. qs + end + end + + log(">~> %s %s", method, url) + log(">>> %s", body) + + local source + if body and #body > 0 then + headers["content-length"] = #body + source = stringsource(body) + end + local result = {} + local r, c, h = make_request { + url = url, + method = method, + headers = headers, + source = source, + sink = tablesink(result), + } + local response = table.concat(result) + + log("<~< %d", c) + log("<<< %s", response) + + return c, h, response +end + +return function () + return setmetatable({}, httpclient) +end