diff --git a/src/device_db_manager/device_db_manager.cpp b/src/device_db_manager/device_db_manager.cpp index 1a6ef76..994f4de 100644 --- a/src/device_db_manager/device_db_manager.cpp +++ b/src/device_db_manager/device_db_manager.cpp @@ -10,34 +10,56 @@ #include "log.h" -DeviceDBManager::DeviceDBManager(const std::string& dbPath) : db(nullptr) { - if (sqlite3_open(dbPath.c_str(), &db) != SQLITE_OK) { - LOG_ERROR("Failed to open database: {} with error msg {}", dbPath, - sqlite3_errmsg(db)); +DeviceDBManager::DeviceDBManager(const std::string& db_path) : db_(nullptr) { + if (sqlite3_open(db_path.c_str(), &db_) != SQLITE_OK) { + LOG_ERROR("Failed to open database, {}", sqlite3_errmsg(db_)); } - initDB(); + InitDB(); } DeviceDBManager::~DeviceDBManager() { - if (db) sqlite3_close(db); + if (db_) sqlite3_close(db_); } -void DeviceDBManager::initDB() { - const char* sql = +void DeviceDBManager::InitDB() { + const char* sql_devices = "CREATE TABLE IF NOT EXISTS devices (" "id INTEGER PRIMARY KEY AUTOINCREMENT," "device_id TEXT UNIQUE NOT NULL," + "password_salt TEXT NOT NULL," "password_hash TEXT NOT NULL);"; - char* errMsg = nullptr; - int rc = sqlite3_exec(db, sql, nullptr, nullptr, &errMsg); - if (rc != SQLITE_OK) { - LOG_ERROR("Failed to initialize DB: {}", errMsg); - sqlite3_free(errMsg); + const char* sql_seq = + "CREATE TABLE IF NOT EXISTS device_id_seq (" + "next_id INTEGER NOT NULL);"; + + const char* sql_seq_init = + "INSERT INTO device_id_seq (next_id) " + "SELECT 1 WHERE NOT EXISTS (SELECT 1 FROM device_id_seq);"; + + char* err_msg = nullptr; + + if (sqlite3_exec(db_, sql_devices, nullptr, nullptr, &err_msg) != SQLITE_OK) { + LOG_ERROR("Failed to create devices table: {}", err_msg); + sqlite3_free(err_msg); + return; + } + + if (sqlite3_exec(db_, sql_seq, nullptr, nullptr, &err_msg) != SQLITE_OK) { + LOG_ERROR("Failed to create device_id_seq table: {}", err_msg); + sqlite3_free(err_msg); + return; + } + + if (sqlite3_exec(db_, sql_seq_init, nullptr, nullptr, &err_msg) != + SQLITE_OK) { + LOG_ERROR("Failed to initialize device_id_seq: {}", err_msg); + sqlite3_free(err_msg); + return; } } -std::string DeviceDBManager::sha256(const std::string& str) { +std::string DeviceDBManager::Sha256(const std::string& str) { unsigned char hash[SHA256_DIGEST_LENGTH]; SHA256(reinterpret_cast(str.c_str()), str.size(), hash); @@ -48,90 +70,235 @@ std::string DeviceDBManager::sha256(const std::string& str) { return ss.str(); } -std::string DeviceDBManager::generateDeviceId() { +std::string DeviceDBManager::GenerateSalt() { + static const char charset[] = "0123456789ABCDEF"; static std::mt19937 rng(static_cast( std::chrono::steady_clock::now().time_since_epoch().count())); - std::uniform_int_distribution dist(0, 9); + std::uniform_int_distribution dist(0, 15); - std::string id; - for (int i = 0; i < 9; ++i) { - id += '0' + dist(rng); + std::string salt; + for (int i = 0; i < 16; ++i) { + salt += charset[dist(rng)]; } - return id; + return salt; } -std::string DeviceDBManager::addDevice(const std::string& password) { - std::string hash = sha256(password); +std::string DeviceDBManager::HashPasswordWithSalt(const std::string& salt, + const std::string& password) { + return Sha256(salt + password); +} - const int maxTry = 10; - for (int i = 0; i < maxTry; ++i) { - std::string deviceId = generateDeviceId(); +std::string DeviceDBManager::GenerateDeviceId() { + sqlite3_exec(db_, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); - const char* sql = - "INSERT INTO devices (device_id, password_hash) VALUES (?, ?);"; + sqlite3_stmt* stmt = nullptr; + int rc = sqlite3_prepare_v2(db_, "SELECT next_id FROM device_id_seq;", -1, + &stmt, nullptr); + if (rc != SQLITE_OK) { + sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr); + return {}; + } + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + sqlite3_finalize(stmt); + sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr); + return {}; + } + int next_id = sqlite3_column_int(stmt, 0); + sqlite3_finalize(stmt); + + const int MIN_ID = 100000000; + const int MAX_ID = 999999999; + + std::mt19937 rng(next_id); + std::uniform_int_distribution dist(MIN_ID, MAX_ID); + int obfuscated_id = dist(rng); + + rc = sqlite3_prepare_v2(db_, "UPDATE device_id_seq SET next_id = ?;", -1, + &stmt, nullptr); + if (rc != SQLITE_OK) { + sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr); + return {}; + } + sqlite3_bind_int(stmt, 1, next_id + 1); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + + if (rc != SQLITE_DONE) { + sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr); + return {}; + } + + sqlite3_exec(db_, "COMMIT;", nullptr, nullptr, nullptr); + + char buf[10] = {0}; + snprintf(buf, sizeof(buf), "%09d", obfuscated_id); + + return std::string(buf); +} + +std::string DeviceDBManager::GeneratePassword() { + static const char charset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + static std::mt19937 rng(static_cast( + std::chrono::steady_clock::now().time_since_epoch().count())); + std::uniform_int_distribution dist(0, + sizeof(charset) - 2); // exclude '\0' + + std::string pwd; + for (int i = 0; i < 6; ++i) { + pwd += charset[dist(rng)]; + } + return pwd; +} + +DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id, + const std::string& password) { + std::string hash = Sha256(password); + + if (!device_id.empty()) { + const char* select_sql = + "SELECT password_hash FROM devices WHERE device_id = ?;"; sqlite3_stmt* stmt = nullptr; - if (sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr) != SQLITE_OK) { - LOG_ERROR("Failed to prepare insert statement"); - return ""; + int rc = sqlite3_prepare_v2(db_, select_sql, -1, &stmt, nullptr); + if (rc != SQLITE_OK) { + LOG_ERROR("Failed to prepare select statement."); + return {}; } - sqlite3_bind_text(stmt, 1, deviceId.c_str(), -1, SQLITE_TRANSIENT); - sqlite3_bind_text(stmt, 2, hash.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 1, device_id.c_str(), -1, SQLITE_TRANSIENT); + rc = sqlite3_step(stmt); + if (rc == SQLITE_ROW) { + // Device exists + std::string old_hash = + reinterpret_cast(sqlite3_column_text(stmt, 0)); + sqlite3_finalize(stmt); + if (old_hash != hash) { + // Update password + const char* update_sql = + "UPDATE devices SET password_hash = ? WHERE device_id = ?;"; + if (sqlite3_prepare_v2(db_, update_sql, -1, &stmt, nullptr) != + SQLITE_OK) { + LOG_ERROR("Failed to prepare update statement."); + return {}; + } + sqlite3_bind_text(stmt, 1, hash.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 2, device_id.c_str(), -1, SQLITE_TRANSIENT); + rc = sqlite3_step(stmt); + sqlite3_finalize(stmt); + if (rc != SQLITE_DONE) { + LOG_ERROR("Failed to update password."); + return {}; + } + } + return {device_id, password, true}; // Same password or updated + } + sqlite3_finalize(stmt); + } + // Device not exists or device_id is empty — generate new + for (int i = 0; i < 10; ++i) { + std::string new_id = GenerateDeviceId(); + std::string new_pwd = GeneratePassword(); + + std::string salt = GenerateSalt(); + std::string hash = HashPasswordWithSalt(salt, new_pwd); + const char* insert_sql = + "INSERT INTO devices (device_id, password_hash, password_salt) VALUES " + "(?, ?, ?);"; + + sqlite3_stmt* stmt = nullptr; + if (sqlite3_prepare_v2(db_, insert_sql, -1, &stmt, nullptr) != SQLITE_OK) { + LOG_ERROR("Failed to prepare insert statement."); + return {}; + } + + sqlite3_bind_text(stmt, 1, new_id.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 2, hash.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 3, salt.c_str(), -1, SQLITE_TRANSIENT); int rc = sqlite3_step(stmt); sqlite3_finalize(stmt); if (rc == SQLITE_DONE) { - return deviceId; + return {new_id, new_pwd, false}; } else if (rc == SQLITE_CONSTRAINT) { + LOG_ERROR("{}:{} Insert failed: rc={}, err={}", new_id, new_pwd, rc, + sqlite3_errmsg(db_)); continue; } else { - LOG_ERROR("Failed to insert device: {}", sqlite3_errmsg(db)); - return ""; + LOG_ERROR("Insert device failed: {}", sqlite3_errmsg(db_)); + return {}; } } - LOG_ERROR("Failed to generate unique device ID after {} attempts.", maxTry); - return ""; + LOG_ERROR("Failed to generate unique device_id after multiple attempts."); + return {}; } -bool DeviceDBManager::verifyDevice(const std::string& deviceId, +bool DeviceDBManager::VerifyDevice(const std::string& device_id, const std::string& password) { - std::string hash = sha256(password); const char* sql = - "SELECT COUNT(*) FROM devices WHERE device_id = ? AND password_hash = ?;"; + "SELECT password_salt, password_hash FROM devices WHERE device_id = ?;"; sqlite3_stmt* stmt = nullptr; - if (sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr) != SQLITE_OK) { - LOG_ERROR("Failed to prepare verify statement."); + if (sqlite3_prepare_v2(db_, sql, -1, &stmt, nullptr) != SQLITE_OK) { return false; } - sqlite3_bind_text(stmt, 1, deviceId.c_str(), -1, SQLITE_TRANSIENT); - sqlite3_bind_text(stmt, 2, hash.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 1, device_id.c_str(), -1, SQLITE_TRANSIENT); - bool found = false; + bool result = false; if (sqlite3_step(stmt) == SQLITE_ROW) { - int count = sqlite3_column_int(stmt, 0); - found = (count > 0); + std::string salt( + reinterpret_cast(sqlite3_column_text(stmt, 0))); + std::string stored_hash( + reinterpret_cast(sqlite3_column_text(stmt, 1))); + + std::string hash = HashPasswordWithSalt(salt, password); + if (hash == stored_hash) { + result = true; + } } sqlite3_finalize(stmt); - return found; + return result; } -bool DeviceDBManager::removeDevice(const std::string& deviceId) { - const char* sql = "DELETE FROM devices WHERE device_id = ?;"; +bool DeviceDBManager::UpdatePassword(const std::string& device_id, + const std::string& new_password) { + std::string salt = GenerateSalt(); + std::string hash = HashPasswordWithSalt(salt, new_password); + + const char* sql = + "UPDATE devices SET password_salt = ?, password_hash = ? WHERE device_id " + "= ?;"; sqlite3_stmt* stmt = nullptr; - if (sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr) != SQLITE_OK) { - LOG_ERROR("Failed to prepare delete statement."); + if (sqlite3_prepare_v2(db_, sql, -1, &stmt, nullptr) != SQLITE_OK) { return false; } - sqlite3_bind_text(stmt, 1, deviceId.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 1, salt.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 2, hash.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 3, device_id.c_str(), -1, SQLITE_TRANSIENT); bool success = (sqlite3_step(stmt) == SQLITE_DONE); sqlite3_finalize(stmt); return success; } + +bool DeviceDBManager::RemoveDevice(const std::string& device_id) { + const char* sql = "DELETE FROM devices WHERE device_id = ?;"; + + sqlite3_stmt* stmt = nullptr; + if (sqlite3_prepare_v2(db_, sql, -1, &stmt, nullptr) != SQLITE_OK) { + return false; + } + + sqlite3_bind_text(stmt, 1, device_id.c_str(), -1, SQLITE_TRANSIENT); + bool success = (sqlite3_step(stmt) == SQLITE_DONE); + sqlite3_finalize(stmt); + return success; +} diff --git a/src/device_db_manager/device_db_manager.h b/src/device_db_manager/device_db_manager.h index 71b31cd..add32d3 100644 --- a/src/device_db_manager/device_db_manager.h +++ b/src/device_db_manager/device_db_manager.h @@ -11,27 +11,41 @@ #include +struct DeviceCredential { + std::string device_id; + std::string password; + bool update; +}; + class DeviceDBManager { public: - explicit DeviceDBManager(const std::string& dbPath); + explicit DeviceDBManager(const std::string& db_path); ~DeviceDBManager(); DeviceDBManager(const DeviceDBManager&) = delete; DeviceDBManager& operator=(const DeviceDBManager&) = delete; - public: - std::string addDevice(const std::string& password); + DeviceCredential AddDevice(const std::string& device_id, + const std::string& password); - bool verifyDevice(const std::string& deviceId, const std::string& password); - bool removeDevice(const std::string& deviceId); + bool UpdatePassword(const std::string& device_id, + const std::string& new_password); + + bool VerifyDevice(const std::string& device_id, const std::string& password); + bool RemoveDevice(const std::string& device_id); private: - std::string sha256(const std::string& str); - void initDB(); - std::string generateDeviceId(); + void InitDB(); + std::string Sha256(const std::string& str); + std::string GenerateDeviceId(); + std::string GeneratePassword(); + std::string GenerateSalt(); + + std::string HashPasswordWithSalt(const std::string& salt, + const std::string& password); private: - sqlite3* db; + sqlite3* db_; }; #endif // _DEVICE_DB_MANAGER_H_ diff --git a/src/signal_server.cpp b/src/signal_server.cpp index b93d257..9006669 100644 --- a/src/signal_server.cpp +++ b/src/signal_server.cpp @@ -48,7 +48,7 @@ SignalServer::~SignalServer() {} bool SignalServer::on_open(websocketpp::connection_hdl hdl) { ws_connections_[hdl] = ws_connection_id_++; - device_db_manager_ = std::make_unique(""); + device_db_manager_ = std::make_unique("devices.db"); return true; } @@ -156,20 +156,44 @@ void SignalServer::on_message(websocketpp::connection_hdl hdl, switch (HASH_STRING_PIECE(type.c_str())) { case "login"_H: { std::string host_id = j["user_id"].get(); - if (host_id.empty()) { - host_id = ""; // todo - LOG_INFO("New client, assign id [{}] to it", host_id); + std::string password = j["password"].get(); + + DeviceCredential dev_cred = + device_db_manager_->AddDevice(host_id, password); + + std::string ret_host_id = dev_cred.device_id; + std::string ret_password = dev_cred.password; + bool update_password = dev_cred.update; + + bool update_success = + (!ret_host_id.empty() && !ret_password.empty()) && update_password; + bool login_success = + (!ret_host_id.empty() && !ret_password.empty()) && !update_password; + bool register_success = (!ret_host_id.empty() && !ret_password.empty()) && + (ret_host_id != host_id); + + bool success = true; + if (register_success) { + LOG_INFO("New client, assign id [{}] to it", ret_host_id); + success = transmission_manager_.BindUserToWsHandle(ret_host_id, hdl); + } else if (login_success) { + LOG_INFO("Receive login request with id [{}]", host_id); + success = transmission_manager_.BindUserToWsHandle(ret_host_id, hdl); + } else if (update_success) { + LOG_INFO("Client [{}] update password", ret_host_id); } - LOG_INFO("Receive login request with id [{}]", host_id); - bool success = transmission_manager_.BindUserToWsHandle(host_id, hdl); if (success) { - json message = { - {"type", "login"}, {"user_id", host_id}, {"status", "success"}}; + json message = {{"type", "login"}, + {"user_id", ret_host_id}, + {"pasword", ret_password}, + {"status", "success"}}; send_msg(hdl, message); } else { - json message = { - {"type", "login"}, {"user_id", host_id}, {"status", "fail"}}; + json message = {{"type", "login"}, + {"user_id", ret_host_id}, + {"pasword", ret_password}, + {"status", "fail"}}; send_msg(hdl, message); }