ConcurrentHashMap not inserting the non null key value pair upon calling the put method

416 Views Asked by At

I am using a ConcurrentHashMap to cache task that I am processing on a SocketChannel. The StreamTask is a Runnable that is use to reschedules itself if the round trip threshold is elapse during client server communication, therefore it will remove itself from the cache if time elapse. Furthermore, StreamWriteTask thread will put it on the cache and the StreaReadTask will attempt to remove it.

The problem is that when I invoke the 'processingCache.put()' it does not always add to the map.

public class ClientServerTest {

    private class StreamTask implements Runnable {
        private final String taskIdentifier;
        private byte[] data;
        private int scheduleAttempts = 1;
        private long startTime;
        private Runnable future;

        private static final long ROND_TRIP_THRESHOLD = 15000L;
        private static final int MAX_SCHEDULE_ATTEMPTS = 3;

        public StreamTask(String taskIdentifier, byte[] data) {
            super();
            this.taskIdentifier = taskIdentifier;
            this.data = data;
        }

        @Override
        public void run() {
            if (scheduleAttempts < MAX_SCHEDULE_ATTEMPTS) {
                StreamTask task = null;
                processingCacheLock.writeLock().lock(); 
                try{
                    task = processingCache.remove(taskIdentifier);
                }finally{
                    processingCacheLock.writeLock().unlock();
                }

                if (task == null) {
                    return;
                }

                scheduleStreamTask(task);
                scheduleAttempts++;
            } else {
                failedTasks.add(this);
            }

        }

        @Override
        public int hashCode() {
            return taskIdentifier == null ? 0 : super.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == null) {
                return false;
            }

            if (!(obj instanceof StreamTask)) {
                return false;
            }
            StreamTask task = (StreamTask) obj;
            boolean equals = false;
            if (this.taskIdentifier != null
                    && this.taskIdentifier.equals(task.taskIdentifier)) {
                equals = true;
            }

            if (this.hashCode() == task.hashCode()) {
                equals = true;
            }

            return equals;
        }

    }

    private class StreamWriteTask implements Runnable {
        private ByteBuffer buffer;
        private SelectionKey key;

        private StreamWriteTask(ByteBuffer buffer, SelectionKey key) {
            this.buffer = buffer;
            this.key = key;
        }

        private byte[] getData() {
            byte[] data;
            if (key.attachment() != null) {
                data = (byte[]) key.attachment();
                System.out.println("StreamWriteTask continuation.....");
            } else {
                StreamTask task = getStreamTask();
                if (task == null) {
                    return null;
                }
                System.out.println("Processing New Task ~~~~~ "
                        + task.taskIdentifier);
                processingCacheLock.readLock().lock();
                try {
                    task = processingCache.put(task.taskIdentifier, task);
                    boolean cached = processingCache.containsKey(task.taskIdentifier);
                    System.out.println("Has task been cached? " + cached);
                } finally {
                    processingCacheLock.readLock().unlock();
                }

                task.startTime = System.currentTimeMillis();
                data = task.data;
            }

            return data;
        }

        @Override
        public void run() {
            byte[] data = getData();
            if (data != null) {
                SocketChannel sc = (SocketChannel) key.channel();
                buffer.clear();
                buffer.put(data);
                buffer.flip();
                int results = 0;
                while (buffer.hasRemaining()) {
                    try {
                        results = sc.write(buffer);
                    } catch (IOException e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    }

                    if (results == 0) {
                        buffer.compact();
                        buffer.flip();
                        data = new byte[buffer.remaining()];
                        buffer.get(data);
                        key.interestOps(SelectionKey.OP_WRITE);
                        key.attach(data);
                        System.out
                                .println("Partial write to socket channel....");
                        selector.wakeup();
                        return;
                    }
                }
            }

            System.out
                    .println("Write to socket channel complete for client...");
            key.interestOps(SelectionKey.OP_READ);
            key.attach(null);
            returnBuffer(buffer);
            selector.wakeup();
        }

    }

    private class StreamReadTask implements Runnable {
        private ByteBuffer buffer;
        private SelectionKey key;

        private StreamReadTask(ByteBuffer buffer, SelectionKey key) {
            this.buffer = buffer;
            this.key = key;
        }

        @Override
        public void run() {
            long endTime = System.currentTimeMillis();
            SocketChannel sc = (SocketChannel) key.channel();
            buffer.clear();
            byte[] data = (byte[]) key.attachment();
            if (data != null) {
                buffer.put(data);
            }
            int count = 0;
            int readAttempts = 0;
            try {
                while ((count = sc.read(buffer)) > 0) {
                    readAttempts++;
                }
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }

            if (count == 0) {
                buffer.flip();
                data = new byte[buffer.limit()];
                buffer.get(data);
                String uuid = new String(data);
                System.out.println("Client Read - uuid ~~~~ " + uuid);
                boolean success = finalizeStreamTask(uuid, endTime);
                key.interestOps(SelectionKey.OP_WRITE);
                key.attach(null);
                System.out.println("Did task finalize correctly ~~~~ "
                        + success);
            }

            if (count == -1) {
                try {
                    sc.close();
                } catch (IOException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }

            returnBuffer(buffer);
            selector.wakeup();
        }

    }

    private class ClientWorker implements Runnable {

        @Override
        public void run() {
            try {

                while (selector.isOpen()) {
                    int count = selector.select(500);

                    if (count == 0) {
                        continue;
                    }

                    Iterator<SelectionKey> it = selector.selectedKeys()
                            .iterator();

                    while (it.hasNext()) {
                        final SelectionKey key = it.next();
                        it.remove();
                        if (!key.isValid()) {
                            continue;
                        }

                        if (key.isConnectable()) {
                            SocketChannel sc = (SocketChannel) key.channel();
                            if (!sc.finishConnect()) {
                                continue;
                            }
                            sc.register(selector, SelectionKey.OP_WRITE);
                        }

                        if (key.isReadable()) {
                            ByteBuffer buffer = borrowBuffer();
                            if (buffer != null) {
                                key.interestOps(0);
                                executor.execute(new StreamReadTask(buffer, key));
                            }
                        }
                        if (key.isWritable()) {
                            ByteBuffer buffer = borrowBuffer();
                            if (buffer != null) {
                                key.interestOps(0);
                                executor.execute(new StreamWriteTask(buffer,
                                        key));
                            }
                        }
                    }
                }
            } catch (IOException ex) {
                // Handle Exception
            }

        }
    }

    private class ServerWorker implements Runnable {
        @Override
        public void run() {
            try {
                Selector selector = Selector.open();
                ServerSocketChannel ssc = ServerSocketChannel.open();
                ServerSocket socket = ssc.socket();
                socket.bind(new InetSocketAddress(9001));
                ssc.configureBlocking(false);
                ssc.register(selector, SelectionKey.OP_ACCEPT);
                ByteBuffer buffer = ByteBuffer.allocateDirect(65535);
                DataHandler handler = new DataHandler();

                while (selector.isOpen()) {
                    int count = selector.select(500);

                    if (count == 0) {
                        continue;
                    }

                    Iterator<SelectionKey> it = selector.selectedKeys()
                            .iterator();

                    while (it.hasNext()) {
                        final SelectionKey key = it.next();
                        it.remove();
                        if (!key.isValid()) {
                            continue;
                        }

                        if (key.isAcceptable()) {
                            ssc = (ServerSocketChannel) key.channel();
                            SocketChannel sc = ssc.accept();
                            sc.configureBlocking(false);
                            sc.register(selector, SelectionKey.OP_READ);
                        }
                        if (key.isReadable()) {
                            handler.readSocket(buffer, key);
                        }
                        if (key.isWritable()) {
                            handler.writeToSocket(buffer, key);
                        }
                    }
                }

            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
        }

    }

    private class DataHandler {

        private JsonObject parseData(StringBuilder builder) {
            if (!builder.toString().endsWith("}")) {
                return null;
            }

            JsonParser parser = new JsonParser();
            JsonObject obj = (JsonObject) parser.parse(builder.toString());
            return obj;
        }

        private void readSocket(ByteBuffer buffer, SelectionKey key)
                throws IOException {
            SocketChannel sc = (SocketChannel) key.channel();
            buffer.clear();
            int count = Integer.MAX_VALUE;
            int readAttempts = 0;
            try {
                while ((count = sc.read(buffer)) > 0) {
                    readAttempts++;
                }
            } catch (IOException e) {
                e.printStackTrace();
            }

            if (count == 0) {
                buffer.flip();
                StringBuilder builder = key.attachment() instanceof StringBuilder ? (StringBuilder) key
                        .attachment() : new StringBuilder();
                Charset charset = Charset.forName("UTF-8");
                CharsetDecoder decoder = charset.newDecoder();
                decoder.onMalformedInput(CodingErrorAction.IGNORE);
                CharBuffer charBuffer = decoder.decode(buffer);
                String content = charBuffer.toString();
                charBuffer = null;
                builder.append(content);
                JsonObject obj = parseData(builder);
                if (obj == null) {
                    // System.out.println("Server processed partial read for task");
                    key.attach(builder);
                    key.interestOps(SelectionKey.OP_READ);
                } else {
                    JsonPrimitive uuid = obj.get("uuid").getAsJsonPrimitive();
                    System.out
                            .println("Server read complete for task  ~~~~~~~ "
                                    + uuid);
                    key.attach(uuid.toString().getBytes());
                    key.interestOps(SelectionKey.OP_WRITE);
                }
            }

            if (count == -1) {
                key.attach(null);
                sc.close();
            }
        }

        private void writeToSocket(ByteBuffer buffer, SelectionKey key)
                throws IOException {
            SocketChannel sc = (SocketChannel) key.channel();
            byte[] data = (byte[]) key.attachment();
            buffer.clear();
            buffer.put(data);
            buffer.flip();
            int writeAttempts = 0;
            while (buffer.hasRemaining()) {
                int results = sc.write(buffer);
                writeAttempts++;
                // System.out.println("Write Attempt #" + writeAttempts);
                if (results == 0) {
                    System.out.println("Server process partial write....");
                    buffer.compact();
                    buffer.flip();
                    data = new byte[buffer.remaining()];
                    buffer.get(data);
                    key.attach(data);
                    key.interestOps(SelectionKey.OP_WRITE);
                    return;
                }
            }

            System.out.println("Server write complete for task ~~~~~ "
                    + new String(data));
            key.interestOps(SelectionKey.OP_READ);
            key.attach(null);
        }
    }

    public ClientServerTest() throws IOException {
        selector = Selector.open();
        processingCache = new ConcurrentHashMap<String, StreamTask>(
                MAX_DATA_LOAD, 2);
        for (int index = 0; index < MAX_DATA_LOAD; index++) {
            JsonObject obj = new JsonObject();
            String uuid = UUID.randomUUID().toString();
            obj.addProperty("uuid", uuid);
            String data = RandomStringUtils.randomAlphanumeric(12800000);
            obj.addProperty("event", data);
            StreamTask task = new StreamTask(uuid, obj.toString().getBytes());
            taskQueue.add(task);
        }

        for (int index = 0; index < CLIENT_SOCKET_CONNECTIONS; index++) {
            ByteBuffer bf = ByteBuffer.allocate(2 << 23);
            bufferQueue.add(bf);
            SocketChannel sc = SocketChannel.open();
            sc.configureBlocking(false);
            sc.connect(new InetSocketAddress("127.0.0.1", 9001));
            sc.register(selector, SelectionKey.OP_CONNECT);
        }

        Thread serverWorker = new Thread(new ServerWorker());
        serverWorker.start();

        Thread clientWorker = new Thread(new ClientWorker());
        clientWorker.start();

    }

    private void start() {
        long startTime = System.currentTimeMillis();
        for (;;) {
            if (taskQueue.isEmpty() && processingCache.isEmpty()) {
                long endTime = System.currentTimeMillis();
                System.out.println("Overall Processing time ~~~~ "
                        + (endTime - startTime) + "ms");
                break;
            }
        }
    }

    private ByteBuffer borrowBuffer() {
        ByteBuffer buffer = null;

        try {
            buffer = bufferQueue.poll(5000L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

        return buffer;
    }

    private boolean returnBuffer(ByteBuffer buffer) {
        boolean success = true;
        try {
            buffer.clear();
            bufferQueue.offer(buffer, 5000L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // TODO Auto-generated catch block
            success = false;
            e.printStackTrace();
        }
        return success;
    }

    private StreamTask getStreamTask() {
        StreamTask task = null;
        taskQueueAddLock.lock();
        try {
            task = taskQueue.take();
        } catch (InterruptedException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        } finally {
            taskQueueAddLock.unlock();
        }

        return task;
    }

    private boolean scheduleStreamTask(StreamTask task) {
        boolean success = true;
        taskQueueRemoveLock.lock();
        try {
            taskQueue.offer(task, 5000L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            success = false;
            e.printStackTrace();
        } finally {
            taskQueueRemoveLock.unlock();
        }

        return success;
    }

    private boolean finalizeStreamTask(String uuid, long endTime) {
        boolean success = true;
        StreamTask task;
        processingCacheLock.writeLock().lock();
        try {
            task = processingCache.remove(uuid);
            success = task != null;
        } finally {
            processingCacheLock.writeLock().unlock();
        }

        if (success) {
            success = executor.remove(task.future);
            executor.purge();
        }

        if (!success) {
            taskQueueAddLock.lock();
            taskQueueRemoveLock.lock();
            try {
                Iterator<StreamTask> it = taskQueue.iterator();
                while (it.hasNext()) {
                    task = it.next();
                    if (task.taskIdentifier == uuid) {
                        it.remove();
                        success = true;
                    }
                }
            } finally {
                taskQueueAddLock.unlock();
                taskQueueRemoveLock.unlock();
            }
            success = !taskQueue.contains(task);
        }

        System.out.println("Processing time ~~~~~~ "
                + (endTime - task.startTime) + "ms");
        return success;
    }

    /**
     * @param args
     */
    public static void main(String[] args) {
        try {
            ClientServerTest test = new ClientServerTest();
            test.start();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
    }

    private static final int CLIENT_SOCKET_CONNECTIONS = 1;
    private static final int MAX_DATA_LOAD = 2;

    private volatile ConcurrentHashMap<String, StreamTask> processingCache;
    private volatile LinkedBlockingQueue<StreamTask> taskQueue = new LinkedBlockingQueue<StreamTask>();
    private volatile ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(
            CLIENT_SOCKET_CONNECTIONS);
    private volatile LinkedBlockingQueue<ByteBuffer> bufferQueue = new LinkedBlockingQueue<ByteBuffer>();
    private volatile List<StreamTask> failedTasks = new ArrayList<StreamTask>();
    private volatile Selector selector;
    private final ReentrantLock taskQueueAddLock = new ReentrantLock();
    private final ReentrantLock taskQueueRemoveLock = new ReentrantLock();
    private final ReentrantReadWriteLock processingCacheLock = new ReentrantReadWriteLock();
}
1

There are 1 best solutions below

3
On

Your problem is probably a misunderstanding of what put() returns. After this line:

task = processingCache.put(task.taskIdentifier, task);

task is equal to the previous value stored in the map for that key if any or null otherwise. If the map did not have a key task.taskIdentifier before that call, then put() returns null and the next line:

boolean cached = processingCache.containsKey(task.taskIdentifier);

will throw a NullPointerException.

From ConcurrenMap#put javadoc (emphasis mine):

Returns the previous value associated with key, or null if there was no mapping for key