001/*******************************************************************************
002 * Copyright (c) 2017 Red Hat Inc and others.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *     Red Hat Inc - initial API and implementation
011 *******************************************************************************/
012package org.eclipse.kapua.gateway.client.mqtt.paho;
013
014import static java.util.Objects.requireNonNull;
015import static org.eclipse.kapua.gateway.client.utils.Strings.nonEmptyText;
016
017import java.net.URI;
018import java.nio.ByteBuffer;
019import java.util.ArrayList;
020import java.util.HashMap;
021import java.util.List;
022import java.util.Map;
023import java.util.Set;
024import java.util.concurrent.CompletableFuture;
025import java.util.concurrent.CompletionStage;
026import java.util.concurrent.Executors;
027import java.util.concurrent.ScheduledExecutorService;
028import java.util.concurrent.TimeUnit;
029import java.util.function.Supplier;
030
031import org.eclipse.kapua.gateway.client.BinaryPayloadCodec;
032import org.eclipse.kapua.gateway.client.Credentials.UserAndPassword;
033import org.eclipse.kapua.gateway.client.Module;
034import org.eclipse.kapua.gateway.client.mqtt.MqttClient;
035import org.eclipse.kapua.gateway.client.mqtt.MqttMessageHandler;
036import org.eclipse.kapua.gateway.client.mqtt.MqttNamespace;
037import org.eclipse.kapua.gateway.client.mqtt.paho.internal.Listeners;
038import org.eclipse.kapua.gateway.client.utils.Buffers;
039import org.eclipse.paho.client.mqttv3.IMqttActionListener;
040import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken;
041import org.eclipse.paho.client.mqttv3.IMqttToken;
042import org.eclipse.paho.client.mqttv3.MqttAsyncClient;
043import org.eclipse.paho.client.mqttv3.MqttCallback;
044import org.eclipse.paho.client.mqttv3.MqttClientPersistence;
045import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
046import org.eclipse.paho.client.mqttv3.MqttException;
047import org.eclipse.paho.client.mqttv3.MqttMessage;
048import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;
049import org.slf4j.Logger;
050import org.slf4j.LoggerFactory;
051
052public class PahoClient extends MqttClient {
053
054    private static final Logger logger = LoggerFactory.getLogger(PahoClient.class);
055
056    public static class Builder extends MqttClient.Builder<Builder> {
057
058        private Supplier<MqttClientPersistence> persistenceProvider = MemoryPersistence::new;
059
060        @Override
061        protected Builder builder() {
062            return this;
063        }
064
065        public Builder persistentProvider(final Supplier<MqttClientPersistence> provider) {
066            if (provider != null) {
067                this.persistenceProvider = provider;
068            } else {
069                this.persistenceProvider = MemoryPersistence::new;
070            }
071            return builder();
072        }
073
074        public Supplier<MqttClientPersistence> persistentProvider() {
075            return this.persistenceProvider;
076        }
077
078        @Override
079        public PahoClient build() throws Exception {
080
081            final URI broker = requireNonNull(broker(), "Broker must be set");
082            final String clientId = nonEmptyText(clientId(), "clientId");
083
084            final MqttClientPersistence persistence = requireNonNull(this.persistenceProvider.get(), "Persistence provider returned 'null' persistence");
085            final MqttNamespace namespace = requireNonNull(namespace(), "Namespace must be set");
086            final BinaryPayloadCodec codec = requireNonNull(codec(), "Codec must be set");
087
088            MqttAsyncClient client = new MqttAsyncClient(broker.toString(), clientId, persistence);
089            ScheduledExecutorService executor = createExecutor(clientId);
090            try {
091                final PahoClient result = new PahoClient(modules(), clientId, executor, namespace, codec, client, persistence, createConnectOptions(this));
092                client = null;
093                executor = null;
094                return result;
095            } finally {
096                if (executor != null) {
097                    executor.shutdown();
098                }
099                if (client != null) {
100                    try {
101                        client.disconnectForcibly(0);
102                    } finally {
103                        client.close();
104                    }
105                }
106            }
107        }
108    }
109
110    private static ScheduledExecutorService createExecutor(final String clientId) {
111        return Executors.newSingleThreadScheduledExecutor(r -> new Thread(r, clientId));
112    }
113
114    private static MqttConnectOptions createConnectOptions(final Builder builder) {
115        final MqttConnectOptions result = new MqttConnectOptions();
116
117        final Object credentials = builder.credentials();
118        if (credentials instanceof UserAndPassword) {
119            final UserAndPassword userAndPassword = (UserAndPassword) credentials;
120            result.setUserName(userAndPassword.getUsername());
121            result.setPassword(userAndPassword.getPassword());
122        } else if (credentials == null) {
123            // ignore
124        } else {
125            throw new IllegalArgumentException(String.format("Unsupported credentials type: %s", credentials.getClass().getName()));
126        }
127
128        return result;
129    }
130
131    private final MqttConnectOptions connectOptions;
132    private MqttAsyncClient client;
133
134    private final Map<String, MqttMessageHandler> subscriptions = new HashMap<>();
135
136    private PahoClient(final Set<Module> modules, final String clientId, final ScheduledExecutorService executor, final MqttNamespace namespace, final BinaryPayloadCodec codec,
137            final MqttAsyncClient client, final MqttClientPersistence persistence, final MqttConnectOptions connectOptions) {
138
139        super(executor, codec, namespace, clientId, modules);
140
141        this.connectOptions = connectOptions;
142        this.client = client;
143
144        this.client.setCallback(new MqttCallback() {
145
146            @Override
147            public void messageArrived(final String topic, final MqttMessage message) throws Exception {
148                handleMessageArrived(topic, message);
149            }
150
151            @Override
152            public void deliveryComplete(final IMqttDeliveryToken token) {
153            }
154
155            @Override
156            public void connectionLost(final Throwable cause) {
157                handleDisconnected();
158            }
159        });
160
161        this.executor.execute(this::connect);
162    }
163
164    protected void connect() {
165        try {
166            this.client.connect(this.connectOptions, null, new IMqttActionListener() {
167
168                @Override
169                public void onSuccess(final IMqttToken asyncActionToken) {
170                    handleConnected();
171                }
172
173                @Override
174                public void onFailure(final IMqttToken asyncActionToken, final Throwable exception) {
175                    handleDisconnected();
176                }
177            });
178        } catch (final MqttException e) {
179            logger.warn("Failed to call connect", e);
180        }
181    }
182
183    @Override
184    public void close() {
185
186        final MqttAsyncClient client;
187
188        synchronized (this) {
189            client = this.client;
190            if (client == null) {
191                return;
192            }
193            this.client = null;
194        }
195
196        try {
197            // disconnect first
198
199            try {
200                client.disconnect().waitForCompletion();
201            } catch (final MqttException e) {
202            }
203
204            // now try to close (and free the resources)
205
206            try {
207                client.close();
208            } catch (final MqttException e) {
209            }
210        } finally {
211            this.executor.shutdown();
212        }
213    }
214
215    protected void handleConnected() {
216        synchronized (this) {
217            super.handleConnected();
218            handleResubscribe();
219        }
220    }
221
222    private void handleResubscribe() {
223        for (final Map.Entry<String, MqttMessageHandler> entry : this.subscriptions.entrySet()) {
224            try {
225                internalSubscribe(entry.getKey());
226            } catch (final MqttException e) {
227                logger.warn("Failed to re-subscribe to '{}'", entry.getKey());
228            }
229        }
230    }
231
232    protected void handleDisconnected() {
233        synchronized (this) {
234            try {
235                super.handleDisconnected();
236            } finally {
237                this.executor.schedule(this::connect, 1, TimeUnit.SECONDS);
238            }
239        }
240    }
241
242    @Override
243    public void publishMqtt(final String topic, final ByteBuffer payload) throws Exception {
244        publish(topic, payload);
245    }
246
247    protected void publish(final String topic, final ByteBuffer payload) throws MqttException {
248        logger.debug("Publishing {} - {}", topic, payload);
249        this.client.publish(topic, Buffers.toByteArray(payload), 1, false);
250    }
251
252    @Override
253    protected CompletionStage<?> subscribeMqtt(String topic, MqttMessageHandler messageHandler) throws MqttException {
254        synchronized (this) {
255            this.subscriptions.put(topic, messageHandler);
256            return internalSubscribe(topic);
257        }
258    }
259
260    @Override
261    protected void unsubscribeMqtt(final Set<String> mqttTopics) throws MqttException {
262        logger.info("Unsubscribe from: {}", mqttTopics);
263
264        final List<String> topics = new ArrayList<>(mqttTopics.size());
265
266        synchronized (this) {
267            for (String topic : mqttTopics) {
268                if (subscriptions.remove(topic) != null) {
269                    topics.add(topic);
270                }
271            }
272        }
273
274        client.unsubscribe(topics.toArray(new String[topics.size()]));
275    }
276
277    protected void handleMessageArrived(final String topic, final MqttMessage message) throws Exception {
278        final ByteBuffer buffer = Buffers.wrap(message.getPayload());
279        buffer.flip();
280
281        logger.debug("Received message - mqtt-topic: {}, payload: {}", topic, buffer);
282
283        final MqttMessageHandler handler = this.subscriptions.get(topic);
284        if (handler != null) {
285            handler.handleMessage(topic, buffer);
286        }
287    }
288
289    private CompletionStage<?> internalSubscribe(final String topic) throws MqttException {
290        final CompletableFuture<?> future = new CompletableFuture<>();
291        this.client.subscribe(topic, 1, null, Listeners.toListener(future));
292        return future;
293    }
294
295}