Split MQTT socket handling from BulbManager

This commit is contained in:
2023-11-11 17:57:04 +01:00
parent 0c8f09e825
commit 4294d5c79e
10 changed files with 726 additions and 345 deletions

View File

@ -4,16 +4,17 @@ version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1"
serde = { version = "1", features = ["derive"] }
clap = { version = "3.2", features = ["derive"] }
log = "0.4"
mqtt-protocol = { version = "0.11", features = ["tokio"] }
pretty_env_logger = "0.4.0"
serde_json = "1"
tokio = { version = "1.19.2", features = ["net", "rt-multi-thread", "macros", "sync", "io-std", "io-util", "time"] }
tokio = { version = "1.19.2", features = ["net", "rt-multi-thread", "macros", "sync", "io-std", "io-util", "time", "tracing"] }
toml = "0.5"
uuid = { version = "1.1", features = ["v4"] }
async-trait = "0.1.74"
eyre = "0.6.8"
[dependencies.lighter_lib]
path = "../lib"

View File

@ -3,3 +3,4 @@ extern crate log;
pub mod manager;
pub mod mqtt_conf;
pub mod provider;

View File

@ -5,6 +5,9 @@ use clap::Parser;
use lighter_lib::BulbColor;
use lighter_manager::manager::{BulbCommand, BulbManager, BulbSelector, BulbsConfig};
use lighter_manager::mqtt_conf::MqttConfig;
use lighter_manager::provider::mock::BulbsMock;
use lighter_manager::provider::mqtt::BulbsMqtt;
use lighter_manager::provider::BulbProvider;
use log::LevelFilter;
use serde::Deserialize;
use std::error::Error;
@ -15,7 +18,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tokio::io::{stdin, AsyncBufReadExt, BufReader};
async fn ask<T>(for_what: &str) -> anyhow::Result<T>
async fn ask<T>(for_what: &str) -> eyre::Result<T>
where
T: FromStr,
<T as FromStr>::Err: Display + Error + Send + Sync + 'static,
@ -40,6 +43,9 @@ struct Opt {
#[clap(short, long)]
config: PathBuf,
#[clap(short, long)]
mock: bool,
}
#[derive(Deserialize)]
@ -51,7 +57,7 @@ struct Config {
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
async fn main() -> eyre::Result<()> {
let opt = Opt::parse();
let log_level = match opt.verbose {
@ -69,7 +75,12 @@ async fn main() -> anyhow::Result<()> {
let config: String = fs::read_to_string(&opt.config)?;
let config: Config = toml::from_str(&config)?;
let (commands, state) = BulbManager::launch(config.bulbs, config.mqtt).await?;
let provider: Box<dyn BulbProvider + Send + Sync> = if opt.mock {
Box::new(BulbsMock::new(config.bulbs.clone()))
} else {
Box::new(BulbsMqtt::new(config.bulbs.clone(), config.mqtt.clone()))
};
let manager = BulbManager::launch(config.bulbs, provider).await?;
loop {
let command: String = ask("command").await?;
@ -77,13 +88,13 @@ async fn main() -> anyhow::Result<()> {
let command = match command.as_str() {
"power" => {
let power: bool = ask("on").await?;
commands.send(BulbCommand::SetPower(BulbSelector::All, power))
manager.send_command(BulbCommand::SetPower(BulbSelector::All, power))
}
"kelvin" => {
let t: f32 = ask("temperature").await?;
let b: f32 = ask("brightness").await?;
commands.send(BulbCommand::SetColor(
manager.send_command(BulbCommand::SetColor(
BulbSelector::All,
BulbColor::kelvin(t, b),
))
@ -93,7 +104,7 @@ async fn main() -> anyhow::Result<()> {
let s: f32 = ask("saturation").await?;
let b: f32 = ask("brightness").await?;
commands.send(BulbCommand::SetColor(
manager.send_command(BulbCommand::SetColor(
BulbSelector::All,
BulbColor::hsb(h, s, b),
))
@ -104,11 +115,14 @@ async fn main() -> anyhow::Result<()> {
}
};
let notify = state.notify_on_change();
command.await?;
let notify = manager.notify_on_change();
info!("1");
command.await;
info!("2");
notify.await;
info!("3");
let bulbs = state.bulbs().await;
let bulbs = manager.bulbs().await;
info!("bulbs: {bulbs:?}");
}
}

View File

@ -1,22 +1,14 @@
use crate::mqtt_conf::MqttConfig;
use crate::provider::{BulbProvider, BulbUpdate};
use lighter_lib::{BulbColor, BulbId, BulbMode};
use mqtt::{
packet::{PublishPacket, QoSWithPacketIdentifier, SubscribePacket, VariablePacket},
Encodable, QualityOfService, TopicFilter, TopicName,
};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::str::from_utf8;
use std::sync::Arc;
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
sync::{futures::Notified, mpsc, Notify, RwLock, RwLockReadGuard},
task,
time::{sleep, Duration},
};
#[derive(Debug)]
#[derive(Clone, Debug)]
pub enum BulbSelector {
All,
Id(BulbId),
@ -32,54 +24,48 @@ pub struct BulbConfig {
pub id: BulbId,
}
pub struct BulbManager {
config: BulbsConfig,
mqtt: MqttConfig,
command_rx: mpsc::Receiver<BulbCommand>,
socket: TcpStream,
state: Arc<BulbsState>,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BulbCommand {
SetPower(BulbSelector, bool),
SetColor(BulbSelector, BulbColor),
}
pub struct BulbsState {
#[derive(Clone)]
pub struct BulbManager {
state: Arc<BulbsState>,
}
struct BulbsState {
/// Notify on any change to the bulbs
notify: Notify,
/// Send commands to the background task
command: mpsc::Sender<BulbCommand>,
/// State of all bulbs
bulbs: RwLock<BTreeMap<BulbId, BulbMode>>,
}
enum Loop {
Break,
Continue,
}
#[derive(Deserialize)]
struct BulbResult {
#[serde(rename(deserialize = "POWER"))]
power: Option<String>,
#[serde(rename(deserialize = "Color"))]
color: Option<String>,
struct ManagerState<P> {
#[allow(dead_code)]
config: BulbsConfig,
provider: P,
command_rx: mpsc::Receiver<BulbCommand>,
state: Arc<BulbsState>,
}
impl BulbManager {
pub async fn launch(
bulbs: BulbsConfig,
mqtt: MqttConfig,
) -> anyhow::Result<(mpsc::Sender<BulbCommand>, Arc<BulbsState>)> {
info!("launching");
let socket = mqtt.connect().await?;
provider: impl BulbProvider + Send + 'static,
) -> eyre::Result<BulbManager> {
info!("launching BulbManager");
let (command_tx, command_rx) = mpsc::channel(100);
let bulbs_state = BulbsState {
let bulbs_state = Arc::new(BulbsState {
notify: Notify::new(),
command: command_tx,
bulbs: RwLock::new(
bulbs
.bulbs
@ -88,198 +74,84 @@ impl BulbManager {
.map(|id| (id, Default::default()))
.collect(),
),
};
let bulbs_state = Arc::new(bulbs_state);
});
let mut manager = BulbManager {
let state = ManagerState {
config: bulbs,
mqtt,
provider,
command_rx,
socket,
state: Arc::clone(&bulbs_state),
};
manager.subscribe().await?;
task::spawn(manager.run());
let manager = BulbManager { state: bulbs_state };
Ok((command_tx, bulbs_state))
task::spawn(run(state));
Ok(manager)
}
async fn subscribe(&mut self) -> anyhow::Result<()> {
let packet = SubscribePacket::new(
1,
vec![(TopicFilter::new("+/lampa/#")?, QualityOfService::Level0)],
);
let mut buf = vec![];
packet.encode(&mut buf)?;
self.socket.write_all(&buf).await?;
Ok(())
}
async fn run(mut self) -> anyhow::Result<()> {
loop {
match self.run_loop().await {
Ok(Loop::Continue) => {}
Ok(Loop::Break) => break,
Err(e) => {
const ERROR_TIMEOUT: u64 = 10;
error!("{e}");
loop {
info!("waiting for {ERROR_TIMEOUT} seconds before trying again");
sleep(Duration::from_secs(ERROR_TIMEOUT)).await;
match self.mqtt.connect().await {
Ok(new_socket) => {
self.socket = new_socket;
self.subscribe().await?;
break;
}
Err(e) => {
error!("failed to re-establish connections: {e}");
}
}
}
}
}
}
info!("exiting");
Ok(())
}
async fn run_loop(&mut self) -> anyhow::Result<Loop> {
let receive_packet = VariablePacket::parse(&mut self.socket);
let receive_command = self.command_rx.recv();
struct Publish<'a, P: ToString> {
pub topic_prefix: &'static str,
pub topic_suffix: &'static str,
pub payload: P,
pub socket: &'a mut TcpStream,
}
impl<P: ToString> Publish<'_, P> {
async fn send(&mut self, id: &BulbId) -> anyhow::Result<()> {
let topic_name = TopicName::new(format!(
"{}/{}/{}",
self.topic_prefix, id, self.topic_suffix
))?;
let qos = QoSWithPacketIdentifier::Level0;
let packet = PublishPacket::new(topic_name, qos, self.payload.to_string());
let mut buf = vec![];
packet.encode(&mut buf)?;
self.socket.write_all(&buf).await?;
anyhow::Ok(())
}
}
tokio::select!(
packet = receive_packet => {
debug!("packet received: {packet:?}");
match packet? {
VariablePacket::PublishPacket(publish) => {
let topic_name = publish.topic_name();
let topic_segments: Vec<&str> = topic_name.split('/').collect();
match &topic_segments[..] {
[prefix, id@.., suffix] => {
let id = BulbId(id.join("/"));
let mut bulbs = self.state.bulbs.write().await;
let bulb = match bulbs.get_mut(&id) {
None => {
warn!("unknown bulb: {id}");
return Ok(Loop::Continue);
}
Some(bulb) => bulb,
};
let payload = from_utf8(publish.payload())?;
match (*prefix, *suffix) {
("cmnd", _) => {}
("stat", "POWER") => {
bulb.power = payload == "ON";
}
("stat", "RESULT") => {
let result: BulbResult = serde_json::from_str(payload)?;
if let Some(power) = result.power {
bulb.power = power == "ON";
}
if let Some(color) = result.color {
bulb.color = color.parse()?;
}
}
("tele", "STATE") => {},
_ => {
warn!("unrecognized topic: {topic_name}");
return Ok(Loop::Continue);
}
}
}
_ => {
warn!("unrecognized topic: {topic_name}");
return Ok(Loop::Continue);
}
}
}
packet => warn!("unhandled packet: {packet:?}"),
}
self.state.notify.notify_waiters();
}
command = receive_command => {
info!("command received: {command:?}");
async fn send<P: ToString>(config: &BulbsConfig, selector: BulbSelector, publish: &mut Publish<'_, P>) -> anyhow::Result<()>{
match selector {
BulbSelector::All => {
for bulb in &config.bulbs {
publish.send(&bulb.id).await?;
}
}
BulbSelector::Id(id) =>publish.send(&id).await?,
}
Ok(())
}
match command {
Some(BulbCommand::SetPower(selector, power)) => {
let payload = if power { "ON" } else { "OFF" };
let mut publish = Publish {
topic_prefix: "cmnd",
topic_suffix: "POWER",
payload,
socket: &mut self.socket,
};
send(&self.config, selector, &mut publish).await?;
}
Some(BulbCommand::SetColor(selector, color)) => {
let mut publish = Publish {
topic_prefix: "cmnd",
topic_suffix: "COLOR",
payload: color.color_string(),
socket: &mut self.socket,
};
send(&self.config, selector, &mut publish).await?;
}
None => return Ok(Loop::Break),
}
}
);
Ok(Loop::Continue)
}
}
impl BulbsState {
pub fn notify_on_change(&self) -> Notified {
self.notify.notified()
self.state.notify.notified()
}
pub async fn send_command(&self, command: BulbCommand) {
info!("sending command {command:?}");
if let Err(e) = self.state.command.send(command).await {
error!("error sending bulb command: {e:#}");
}
info!("sent command");
}
pub async fn bulbs(&self) -> RwLockReadGuard<'_, BTreeMap<BulbId, BulbMode>> {
self.bulbs.read().await
self.state.bulbs.read().await
}
}
async fn run<P>(state: ManagerState<P>)
where
P: BulbProvider + Send,
{
debug!("manager task running");
if let Err(e) = run_inner(state).await {
error!("bulb manage exited with error: {e:#}");
}
info!("manager task exited");
}
async fn run_inner<P>(mut state: ManagerState<P>) -> eyre::Result<()>
where
P: BulbProvider + Send,
{
loop {
tokio::select!(
command = state.command_rx.recv() => {
let Some(command) = command else {
info!("handle closed, shutting down");
return Ok(());
};
info!("command received: {command:?}");
state.provider.send_command(command.clone()).await?;
}
update = state.provider.listen() => {
let (id, update) = update?;
info!("update received: {id:?} {update:?}");
let mut bulbs = state.state.bulbs.write().await;
let Some(bulb) = bulbs.get_mut(&id) else {
continue;
};
match update {
BulbUpdate::Power(power) => {
bulb.power = power;
}
BulbUpdate::Color(color) => {
bulb.color = color;
}
}
state.state.notify.notify_waiters();
}
)
}
}

View File

@ -17,7 +17,7 @@ pub struct MqttConfig {
}
impl MqttConfig {
pub async fn connect(&self) -> anyhow::Result<TcpStream> {
pub async fn connect(&self) -> eyre::Result<TcpStream> {
let mut socket =
TcpStream::connect((self.address.as_str(), self.port.unwrap_or(1883))).await?;
@ -35,9 +35,9 @@ impl MqttConfig {
match VariablePacket::parse(&mut socket).await? {
VariablePacket::ConnackPacket(ack) => match ack.connect_return_code() {
ConnectReturnCode::ConnectionAccepted => Ok(socket),
return_code => anyhow::bail!("connection refused: {return_code:?}"),
return_code => eyre::bail!("connection refused: {return_code:?}"),
},
response => anyhow::bail!("mqtt connect, unexpected response: {response:?}"),
response => eyre::bail!("mqtt connect, unexpected response: {response:?}"),
}
}
}

34
manager/src/provider.rs Normal file
View File

@ -0,0 +1,34 @@
use std::ops::DerefMut;
use crate::manager::BulbCommand;
use async_trait::async_trait;
use lighter_lib::{BulbColor, BulbId};
pub mod mock;
pub mod mqtt;
#[derive(Debug, Clone)]
pub enum BulbUpdate {
Power(bool),
Color(BulbColor),
}
// An interface that allows communication with bulbs.
#[async_trait]
pub trait BulbProvider {
// Send a command to some bulbs to update their state
async fn send_command(&mut self, cmd: BulbCommand) -> eyre::Result<()>;
// Wait for any bulb to send an update
async fn listen(&mut self) -> eyre::Result<(BulbId, BulbUpdate)>;
}
#[async_trait]
impl<P: BulbProvider + Send + Sync + ?Sized> BulbProvider for Box<P> {
async fn send_command(&mut self, cmd: BulbCommand) -> eyre::Result<()> {
self.deref_mut().send_command(cmd).await
}
async fn listen(&mut self) -> eyre::Result<(BulbId, BulbUpdate)> {
self.deref_mut().listen().await
}
}

View File

@ -0,0 +1,65 @@
use std::{collections::VecDeque, future::pending};
use async_trait::async_trait;
use lighter_lib::BulbId;
use crate::{
manager::{BulbCommand, BulbSelector, BulbsConfig},
provider::{BulbProvider, BulbUpdate},
};
pub struct BulbsMock {
config: BulbsConfig,
queue: VecDeque<(BulbId, BulbUpdate)>,
}
impl BulbsMock {
pub fn new(config: BulbsConfig) -> Self {
BulbsMock {
config,
queue: VecDeque::new(),
}
}
}
#[async_trait]
impl BulbProvider for BulbsMock {
async fn send_command(&mut self, command: BulbCommand) -> eyre::Result<()> {
info!("mock: sending command {command:?}");
let selector = match &command {
BulbCommand::SetPower(selector, _) => selector,
BulbCommand::SetColor(selector, _) => selector,
};
let bulbs = match selector {
BulbSelector::All => self.config.bulbs.iter().map(|b| b.id.clone()).collect(),
BulbSelector::Id(id) => vec![id.clone()],
};
let update = match command {
BulbCommand::SetPower(_, power) => BulbUpdate::Power(power),
BulbCommand::SetColor(_, color) => BulbUpdate::Color(color),
};
for bulb in bulbs {
info!("mock: updating bulb {bulb} {update:?}");
self.queue.push_back((bulb, update.clone()));
}
Ok(())
}
async fn listen(&mut self) -> eyre::Result<(BulbId, super::BulbUpdate)> {
info!("mock: listening for updates");
let Some(update) = self.queue.pop_front() else {
info!("mock: no updates in queue");
pending().await
};
info!("mock: returning update {update:?}");
Ok(update)
}
}

View File

@ -0,0 +1,234 @@
use std::{collections::HashSet, str, time::Duration};
use async_trait::async_trait;
use lighter_lib::BulbId;
use mqtt::{
packet::{PublishPacket, QoSWithPacketIdentifier, SubscribePacket, VariablePacket},
Encodable, QualityOfService, TopicFilter, TopicName,
};
use serde::Deserialize;
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
time::{sleep_until, Instant},
};
use crate::{
manager::{BulbCommand, BulbSelector, BulbsConfig},
mqtt_conf::MqttConfig,
provider::{BulbProvider, BulbUpdate},
};
const RECONNECT_DELAYS: &[u64] = &[0, 1, 2, 5, 10, 10, 10, 20];
pub struct BulbsMqtt {
known_bulbs: HashSet<BulbId>,
socket: SocketState,
}
struct SocketState {
mqtt_config: MqttConfig,
last_connection_attempt: Instant,
failed_connect_attempts: usize,
socket: Option<TcpStream>,
}
impl BulbsMqtt {
pub fn new(bulbs_config: BulbsConfig, mqtt_config: MqttConfig) -> BulbsMqtt {
BulbsMqtt {
known_bulbs: bulbs_config.bulbs.into_iter().map(|b| b.id).collect(),
socket: SocketState {
mqtt_config,
last_connection_attempt: Instant::now(),
failed_connect_attempts: 0,
socket: None,
},
}
}
}
impl SocketState {
async fn get_connection(&mut self) -> eyre::Result<&mut TcpStream> {
let socket = &mut self.socket;
if let Some(socket) = socket {
return Ok(socket);
}
let attempt = self.failed_connect_attempts;
let wait_for = Duration::from_secs(RECONNECT_DELAYS[attempt]);
sleep_until(self.last_connection_attempt + wait_for).await;
self.failed_connect_attempts += 1;
self.last_connection_attempt = Instant::now();
info!("connecting to MQTT (attempt {attempt})");
let mut new_socket = self.mqtt_config.connect().await?;
subscribe(&mut new_socket).await?;
info!("connected to MQTT");
self.failed_connect_attempts = 0;
Ok(socket.insert(new_socket))
}
}
#[async_trait]
impl BulbProvider for BulbsMqtt {
async fn send_command(&mut self, command: BulbCommand) -> eyre::Result<()> {
debug!("mqtt sending command {command:?}");
let socket = self.socket.get_connection().await?;
async fn send<P: ToString>(
all_bulbs: &HashSet<BulbId>,
selector: BulbSelector,
publish: &mut Publish<'_, P>,
) -> eyre::Result<()> {
match selector {
BulbSelector::Id(id) => publish.send(&id).await?,
BulbSelector::All => {
for id in all_bulbs {
publish.send(id).await?;
}
}
}
Ok(())
}
match command {
BulbCommand::SetPower(selector, power) => {
let payload = if power { "ON" } else { "OFF" };
let mut publish = Publish {
topic_prefix: "cmnd",
topic_suffix: "POWER",
payload,
socket,
};
send(&self.known_bulbs, selector, &mut publish).await?;
}
BulbCommand::SetColor(selector, color) => {
let mut publish = Publish {
topic_prefix: "cmnd",
topic_suffix: "COLOR",
payload: color.color_string(),
socket,
};
send(&self.known_bulbs, selector, &mut publish).await?;
}
}
Ok(())
}
async fn listen(&mut self) -> eyre::Result<(BulbId, BulbUpdate)> {
debug!("mqtt listening for updates");
let socket = self.socket.get_connection().await?;
loop {
let packet = VariablePacket::parse(socket).await?;
let VariablePacket::PublishPacket(publish) = &packet else {
continue;
};
let topic_name = publish.topic_name();
let topic_segments: Vec<&str> = topic_name.split('/').collect();
match &topic_segments[..] {
[prefix, id @ .., suffix] => {
let id = BulbId(id.join("/"));
if !self.known_bulbs.contains(&id) {
warn!("ignoring publish from unknown bulb {id}");
continue;
}
let payload = str::from_utf8(publish.payload())?;
let update = match (*prefix, *suffix) {
("stat", "POWER") => BulbUpdate::Power(payload == "ON"),
("stat", "RESULT") => {
let result: BulbResult = serde_json::from_str(payload)?;
// TODO: color and power can be updated at the same time?
if let Some(color) = result.color {
BulbUpdate::Color(color.parse()?)
} else if let Some(power) = result.power {
BulbUpdate::Power(power == "ON")
} else {
continue;
}
}
// TODO: handle STATE message
//("tele", "STATE") => todo!(),
// ignore known useless messages
("cmnd", _) => continue,
("tele", "LWT") => continue,
_ => {
warn!("unrecognized topic: {topic_name} payload={payload:?}");
continue;
}
};
return Ok((id, update));
}
_ => {
warn!("unrecognized topic: {topic_name}");
continue;
}
}
}
}
}
struct Publish<'a, P: ToString> {
pub topic_prefix: &'static str,
pub topic_suffix: &'static str,
pub payload: P,
pub socket: &'a mut TcpStream,
}
impl<P: ToString> Publish<'_, P> {
async fn send(&mut self, id: &BulbId) -> eyre::Result<()> {
let topic_name = TopicName::new(format!(
"{}/{}/{}",
self.topic_prefix, id, self.topic_suffix
))?;
let qos = QoSWithPacketIdentifier::Level0;
let payload = self.payload.to_string();
debug!("publishing {topic_name:?} {payload}");
let packet = PublishPacket::new(topic_name, qos, payload);
let mut buf = vec![];
packet.encode(&mut buf)?;
self.socket.write_all(&buf).await?;
debug!("published");
Ok(())
}
}
/// Send the MQTT subscribe packet to subscribe to bulb updates
async fn subscribe(socket: &mut TcpStream) -> eyre::Result<()> {
debug!("subscribing to mqtt bulb updates");
let packet = SubscribePacket::new(
1,
vec![(TopicFilter::new("+/lampa/#")?, QualityOfService::Level0)],
);
let mut buf = vec![];
packet.encode(&mut buf)?;
socket.write_all(&buf).await?;
debug!("subscribed to mqtt bulb updates");
Ok(())
}
#[derive(Deserialize)]
struct BulbResult {
#[serde(rename(deserialize = "POWER"))]
power: Option<String>,
#[serde(rename(deserialize = "Color"))]
color: Option<String>,
}