Fix crashes when logid starts as root

If logid scans every device, it will either SIGSEGV or not work at all.
This commit should fix bug #100.
This commit is contained in:
pixl 2020-07-13 01:14:15 -04:00
parent 8f073d66c3
commit dde2993223
No known key found for this signature in database
GPG Key ID: 1866C148CD593B6E
5 changed files with 119 additions and 41 deletions

View File

@ -19,6 +19,8 @@
#include "DeviceMonitor.h" #include "DeviceMonitor.h"
#include "../../util/task.h" #include "../../util/task.h"
#include "../../util/log.h" #include "../../util/log.h"
#include "RawDevice.h"
#include "../hidpp/Device.h"
#include <thread> #include <thread>
#include <system_error> #include <system_error>
@ -99,7 +101,13 @@ void DeviceMonitor::run()
if (action == "add") if (action == "add")
task::spawn([this, name=devnode]() { task::spawn([this, name=devnode]() {
auto supported_reports = backend::hidpp::getSupportedReports(
RawDevice::getReportDescriptor(name));
if(supported_reports)
this->addDevice(name); this->addDevice(name);
else
logPrintf(DEBUG, "Unsupported device %s ignored",
name.c_str());
}, [name=devnode](std::exception& e){ }, [name=devnode](std::exception& e){
logPrintf(WARN, "Error adding device %s: %s", logPrintf(WARN, "Error adding device %s: %s",
name.c_str(), e.what()); name.c_str(), e.what());
@ -158,7 +166,13 @@ void DeviceMonitor::enumerate()
udev_device_unref(device); udev_device_unref(device);
task::spawn([this, name=devnode]() { task::spawn([this, name=devnode]() {
auto supported_reports = backend::hidpp::getSupportedReports(
RawDevice::getReportDescriptor(name));
if(supported_reports)
this->addDevice(name); this->addDevice(name);
else
logPrintf(DEBUG, "Unsupported device %s ignored",
name.c_str());
}, [name=devnode](std::exception& e){ }, [name=devnode](std::exception& e){
logPrintf(WARN, "Error adding device %s: %s", logPrintf(WARN, "Error adding device %s: %s",
name.c_str(), e.what()); name.c_str(), e.what());

View File

@ -94,20 +94,7 @@ RawDevice::RawDevice(std::string path) : _path (std::move(path)),
} }
_name.assign(name_buf, ret - 1); _name.assign(name_buf, ret - 1);
hidraw_report_descriptor _rdesc{}; _rdesc = getReportDescriptor(_fd);
if (-1 == ::ioctl(_fd, HIDIOCGRDESCSIZE, &_rdesc.size)) {
int err = errno;
::close(_fd);
throw std::system_error(err, std::system_category(),
"RawDevice HIDIOCGRDESCSIZE failed");
}
if (-1 == ::ioctl(_fd, HIDIOCGRDESC, &_rdesc)) {
int err = errno;
::close(_fd);
throw std::system_error(err, std::system_category(),
"RawDevice HIDIOCGRDESC failed");
}
rdesc = std::vector<uint8_t>(_rdesc.value, _rdesc.value + _rdesc.size);
if (-1 == ::pipe(_pipe)) { if (-1 == ::pipe(_pipe)) {
int err = errno; int err = errno;
@ -148,6 +135,41 @@ uint16_t RawDevice::productId() const
return _pid; return _pid;
} }
std::vector<uint8_t> RawDevice::getReportDescriptor(std::string path)
{
int fd = ::open(path.c_str(), O_RDWR);
if (fd == -1)
throw std::system_error(errno, std::system_category(),
"open failed");
auto rdesc = getReportDescriptor(fd);
::close(fd);
return rdesc;
}
std::vector<uint8_t> RawDevice::getReportDescriptor(int fd)
{
hidraw_report_descriptor rdesc{};
if (-1 == ::ioctl(fd, HIDIOCGRDESCSIZE, &rdesc.size)) {
int err = errno;
::close(fd);
throw std::system_error(err, std::system_category(),
"RawDevice HIDIOCGRDESCSIZE failed");
}
if (-1 == ::ioctl(fd, HIDIOCGRDESC, &rdesc)) {
int err = errno;
::close(fd);
throw std::system_error(err, std::system_category(),
"RawDevice HIDIOCGRDESC failed");
}
return std::vector<uint8_t>(rdesc.value, rdesc.value + rdesc.size);
}
std::vector<uint8_t> RawDevice::reportDescriptor() const
{
return _rdesc;
}
std::vector<uint8_t> RawDevice::sendReport(const std::vector<uint8_t>& report) std::vector<uint8_t> RawDevice::sendReport(const std::vector<uint8_t>& report)
{ {
/* If the listener will stop, handle I/O manually. /* If the listener will stop, handle I/O manually.
@ -157,33 +179,47 @@ std::vector<uint8_t> RawDevice::sendReport(const std::vector<uint8_t>& report)
std::unique_lock<std::mutex> lock(send_report); std::unique_lock<std::mutex> lock(send_report);
std::condition_variable cv; std::condition_variable cv;
bool top_of_queue = false; bool top_of_queue = false;
std::packaged_task<std::vector<uint8_t>()> task( [this, report, &cv, auto task = std::make_shared<std::packaged_task<std::vector<uint8_t>()>>
&top_of_queue] () { ( [this, report, &cv, &top_of_queue] () {
top_of_queue = true; top_of_queue = true;
cv.notify_all(); cv.notify_all();
return this->_respondToReport(report); return this->_respondToReport(report);
}); });
auto f = task.get_future(); auto f = task->get_future();
_io_queue.push(&task); _io_queue.push(task);
cv.wait(lock, [&top_of_queue]{ return top_of_queue; }); cv.wait(lock, [&top_of_queue]{ return top_of_queue; });
auto status = f.wait_for(global_config->ioTimeout()); auto status = f.wait_for(global_config->ioTimeout());
if(status == std::future_status::timeout) { if(status == std::future_status::timeout) {
_continue_respond = false; _continue_respond = false;
throw TimeoutError(); interruptRead();
return f.get(); // Expecting an error, but it could work
} }
return f.get(); return f.get();
} }
else { else {
std::vector<uint8_t> response; std::vector<uint8_t> response;
std::exception_ptr _exception;
std::shared_ptr<task> t = std::make_shared<task>( std::shared_ptr<task> t = std::make_shared<task>(
[this, report, &response]() { [this, report, &response]() {
response = _respondToReport(report); response = _respondToReport(report);
}, [&_exception](std::exception& e) {
try {
throw e;
} catch(std::exception& e) {
_exception = std::make_exception_ptr(e);
}
}); });
global_workqueue->queue(t); global_workqueue->queue(t);
t->waitStart(); t->waitStart();
auto status = t->waitFor(global_config->ioTimeout()); auto status = t->waitFor(global_config->ioTimeout());
if(_exception)
std::rethrow_exception(_exception);
if(status == std::future_status::timeout) { if(status == std::future_status::timeout) {
_continue_respond = false; _continue_respond = false;
interruptRead();
t->wait();
if(_exception)
std::rethrow_exception(_exception);
throw TimeoutError(); throw TimeoutError();
} else } else
return response; return response;
@ -196,12 +232,13 @@ void RawDevice::sendReportNoResponse(const std::vector<uint8_t>& report)
/* If the listener will stop, handle I/O manually. /* If the listener will stop, handle I/O manually.
* Otherwise, push to queue and wait for result. */ * Otherwise, push to queue and wait for result. */
if(_continue_listen) { if(_continue_listen) {
std::packaged_task<std::vector<uint8_t>()> task([this, report]() { auto task = std::make_shared<std::packaged_task<std::vector<uint8_t>()>>
([this, report]() {
this->_sendReport(report); this->_sendReport(report);
return std::vector<uint8_t>(); return std::vector<uint8_t>();
}); });
auto f = task.get_future(); auto f = task->get_future();
_io_queue.push(&task); _io_queue.push(task);
f.get(); f.get();
} }
else else
@ -213,9 +250,17 @@ std::vector<uint8_t> RawDevice::_respondToReport
{ {
_sendReport(request); _sendReport(request);
_continue_respond = true; _continue_respond = true;
auto start_point = std::chrono::steady_clock::now();
while(_continue_respond) { while(_continue_respond) {
std::vector<uint8_t> response; std::vector<uint8_t> response;
_readReport(response, MAX_DATA_LENGTH); auto current_point = std::chrono::steady_clock::now();
auto timeout = global_config->ioTimeout() - std::chrono::duration_cast
<std::chrono::milliseconds>(current_point - start_point);
if(timeout.count() <= 0)
throw TimeoutError();
_readReport(response, MAX_DATA_LENGTH, timeout);
// All reports have the device index at byte 2 // All reports have the device index at byte 2
if(response[1] != request[1]) { if(response[1] != request[1]) {
@ -285,21 +330,27 @@ int RawDevice::_sendReport(const std::vector<uint8_t>& report)
return ret; return ret;
} }
int RawDevice::_readReport(std::vector<uint8_t>& report, std::size_t maxDataLength) int RawDevice::_readReport(std::vector<uint8_t> &report,
std::size_t maxDataLength)
{
return _readReport(report, maxDataLength, global_config->ioTimeout());
}
int RawDevice::_readReport(std::vector<uint8_t> &report,
std::size_t maxDataLength, std::chrono::milliseconds timeout)
{ {
std::lock_guard<std::mutex> lock(_dev_io); std::lock_guard<std::mutex> lock(_dev_io);
int ret; int ret;
report.resize(maxDataLength); report.resize(maxDataLength);
timeval timeout{}; timeval timeout_tv{};
timeout.tv_sec = duration_cast<seconds>(global_config->ioTimeout()) timeout_tv.tv_sec = duration_cast<seconds>(global_config->ioTimeout())
.count(); .count();
timeout.tv_usec = duration_cast<microseconds>( timeout_tv.tv_usec = duration_cast<microseconds>(
global_config->ioTimeout()).count() % global_config->ioTimeout()).count() %
duration_cast<microseconds>(seconds(1)).count(); duration_cast<microseconds>(seconds(1)).count();
auto timeout_ms = duration_cast<milliseconds>( auto timeout_ms = duration_cast<milliseconds>(timeout).count();
global_config->ioTimeout()).count();
fd_set fds; fd_set fds;
do { do {
@ -309,7 +360,7 @@ int RawDevice::_readReport(std::vector<uint8_t>& report, std::size_t maxDataLeng
ret = select(std::max(_fd, _pipe[0]) + 1, ret = select(std::max(_fd, _pipe[0]) + 1,
&fds, nullptr, nullptr, &fds, nullptr, nullptr,
(timeout_ms > 0 ? nullptr : &timeout)); (timeout_ms > 0 ? nullptr : &timeout_tv));
} while(ret == -1 && errno == EINTR); } while(ret == -1 && errno == EINTR);
if(ret == -1) if(ret == -1)

View File

@ -47,7 +47,9 @@ namespace raw
uint16_t vendorId() const; uint16_t vendorId() const;
uint16_t productId() const; uint16_t productId() const;
std::vector<uint8_t> reportDescriptor() const { return rdesc; } static std::vector<uint8_t> getReportDescriptor(std::string path);
static std::vector<uint8_t> getReportDescriptor(int fd);
std::vector<uint8_t> reportDescriptor() const;
std::vector<uint8_t> sendReport(const std::vector<uint8_t>& report); std::vector<uint8_t> sendReport(const std::vector<uint8_t>& report);
void sendReportNoResponse(const std::vector<uint8_t>& report); void sendReportNoResponse(const std::vector<uint8_t>& report);
@ -72,7 +74,7 @@ namespace raw
uint16_t _vid; uint16_t _vid;
uint16_t _pid; uint16_t _pid;
std::string _name; std::string _name;
std::vector<uint8_t> rdesc; std::vector<uint8_t> _rdesc;
std::atomic<bool> _continue_listen; std::atomic<bool> _continue_listen;
std::atomic<bool> _continue_respond; std::atomic<bool> _continue_respond;
@ -86,11 +88,14 @@ namespace raw
/* These will only be used internally and processed with a queue */ /* These will only be used internally and processed with a queue */
int _sendReport(const std::vector<uint8_t>& report); int _sendReport(const std::vector<uint8_t>& report);
int _readReport(std::vector<uint8_t>& report, std::size_t maxDataLength); int _readReport(std::vector<uint8_t>& report, std::size_t maxDataLength);
int _readReport(std::vector<uint8_t>& report, std::size_t maxDataLength,
std::chrono::milliseconds timeout);
std::vector<uint8_t> _respondToReport(const std::vector<uint8_t>& std::vector<uint8_t> _respondToReport(const std::vector<uint8_t>&
request); request);
mutex_queue<std::packaged_task<std::vector<uint8_t>()>*> _io_queue; mutex_queue<std::shared_ptr<std::packaged_task<std::vector<uint8_t>()>>>
_io_queue;
}; };
}}} }}}

View File

@ -31,7 +31,7 @@ task::task(const std::function<void()>& function,
} catch(std::exception& e) { } catch(std::exception& e) {
(*_exception_handler)(e); (*_exception_handler)(e);
} }
}) }), _future (_task_pkg.get_future())
{ {
} }
@ -41,6 +41,7 @@ void task::run()
_status_cv.notify_all(); _status_cv.notify_all();
_task_pkg(); _task_pkg();
_status = Completed; _status = Completed;
_status_cv.notify_all();
} }
task::Status task::getStatus() task::Status task::getStatus()
@ -50,7 +51,13 @@ task::Status task::getStatus()
void task::wait() void task::wait()
{ {
_task_pkg.get_future().wait(); if(_future.valid())
_future.wait();
else {
std::mutex wait_start;
std::unique_lock<std::mutex> lock(wait_start);
_status_cv.wait(lock, [this](){ return _status == Completed; });
}
} }
void task::waitStart() void task::waitStart()
@ -62,7 +69,7 @@ void task::waitStart()
std::future_status task::waitFor(std::chrono::milliseconds ms) std::future_status task::waitFor(std::chrono::milliseconds ms)
{ {
return _task_pkg.get_future().wait_for(ms); return _future.wait_for(ms);
} }
void task::spawn(const std::function<void ()>& function, void task::spawn(const std::function<void ()>& function,

View File

@ -62,6 +62,7 @@ namespace logid
std::atomic<Status> _status; std::atomic<Status> _status;
std::condition_variable _status_cv; std::condition_variable _status_cv;
std::packaged_task<void()> _task_pkg; std::packaged_task<void()> _task_pkg;
std::future<void> _future;
}; };
} }