LRU Cache - Advanced Data Structures / Data Structure Design

https://algo.monster/problems/lru_cache

The solution might be more readable if we move two common operations — linking two adjacent nodes and moving a node to the front of the list — into helper functions:

class LRU:
    """Fixed-capacity least-recently-used cache.

    Entries live in a doubly linked list between two sentinel nodes: the
    node right after ``self.head`` is the most recently used one and the
    node right before ``self.tail`` is the least recently used one.
    ``self.m`` maps each key to its node, so lookups, promotions and
    evictions are O(1).

    Nodes are ``DLL`` objects (defined elsewhere). This class builds them
    as ``DLL(key, val)`` and reads/writes their ``key``, ``val``, ``prev``
    and ``next`` attributes.
    """

    def __init__(self, capacity):
        self.m = {}
        self.head = DLL(0, 0)
        self.tail = DLL(0, 0)
        self.head.next = self.tail
        self.tail.prev = self.head
        self.size = 0
        self.capacity = capacity

    def get(self, key):
        """Return the value stored for *key* and mark it most recently used.

        Returns -1 when *key* is not cached.
        """
        node = self.m.get(key)
        if node is None:
            return -1
        self._bubble(node)
        return node.val

    def _stitch(self, n1, n2):
        """Link ``n1 -> n2``; a None endpoint is simply skipped."""
        if n1 is not None:
            n1.next = n2
        if n2 is not None:
            n2.prev = n1

    def _push_front(self, node):
        """Insert *node* right after the head sentinel (most recent slot).

        Only writes *node*'s links, so it is safe for a freshly built node
        whose ``prev``/``next`` have never been set.
        """
        first = self.head.next
        self._stitch(self.head, node)
        self._stitch(node, first)

    def _bubble(self, node):
        """Move a node that is already linked into the list to the front."""
        self._stitch(node.prev, node.next)  # unlink from its current spot
        self._push_front(node)

    def put(self, key, val):
        """Insert or update *key*, evicting the least recently used entry when full.

        A capacity of zero or less stores nothing.
        """
        node = self.m.get(key)
        if node is not None:
            node.val = val
            self._bubble(node)
            return

        if self.capacity <= 0:
            # With no room at all, tail.prev would be the head sentinel;
            # evicting it would corrupt the list.
            return

        if self.size >= self.capacity:
            lru = self.tail.prev
            self._stitch(lru.prev, self.tail)
            del self.m[lru.key]
            self.size -= 1

        new_node = DLL(key, val)
        self._push_front(new_node)
        self.m[key] = new_node
        self.size += 1
            

Simple C++ LRU implementation

#include <list>
#include <unordered_map>
#include <utility>

using namespace std;

using Pair = pair<int, int>;

// Fixed-capacity least-recently-used cache mapping int keys to int values.
//
// `items` keeps the entries in recency order (front = most recently used,
// back = least recently used). `cache` maps each key to its position in
// `items`. std::list is used because its iterators stay valid when other
// elements are inserted or erased and when an element is spliced, so the
// iterators stored in `cache` never dangle. std::deque gives no such
// guarantee: push_front and erase from the middle invalidate its iterators.
class LRUCache{
    unordered_map<int, list<Pair>::iterator> cache;
    list<Pair> items;
    int capacity;

    // Move the entry at `it` to the front (most recently used) in O(1).
    void touch(list<Pair>::iterator it) {
        items.splice(items.begin(), items, it);
    }

    public:
        LRUCache(int size) {
            capacity = size;
        }

        // Evict least recently used entries until the cache fits its capacity.
        // The comparison is done in a signed type so that a negative capacity
        // does not turn into a huge unsigned bound.
        void _resize(){
            while (!items.empty() &&
                   static_cast<long long>(items.size()) > capacity) {
                cache.erase(items.back().first);
                items.pop_back();
            }
        }

        // Store `key` -> `value` as the most recently used entry.
        // Returns true if `key` was not cached before. If it was, its value is
        // overwritten, it is marked most recently used, and false is returned.
        bool put(int key, int value){
            auto found = cache.find(key);
            if (found != cache.end()) {
                found->second->second = value;
                touch(found->second);
                return false;
            }
            items.push_front(Pair(key, value));
            cache[key] = items.begin();
            _resize();
            return true;
        }

        // Return the value for `key` and mark it most recently used,
        // or -1 if `key` is not cached.
        int get(int key) {
            auto found = cache.find(key);
            if (found == cache.end()) {
                return -1;
            }
            touch(found->second);
            return found->second->second;
        }
};

Here’s my implementation. I just break the code up into helper functions so it’s a lot easier to understand.

public static class LRUCache {
    /** Doubly linked list node holding one cache entry. */
    public static class Node {
        int key;
        int val;
        Node next;
        Node prev;

        public Node(int key, int val) {
            this.key = key;
            this.val = val;
        }
    }

    final int capacity;
    final HashMap<Integer, Node> cache;

    /** Number of entries currently stored; always equal to cache.size(). */
    public int size;
    /** Sentinel placed before the most recently used node. */
    public Node head;
    /** Sentinel placed after the least recently used node. */
    public Node tail;

    public LRUCache(int capacity) {
        this.capacity = capacity;
        this.cache = new HashMap<>(capacity);

        // Empty list: the two sentinels point at each other.
        this.head = new Node(0, 0);
        this.tail = new Node(0, 0);
        this.head.next = tail;
        this.tail.prev = head;

        this.size = 0;
    }

    /** Unlink {@code node} from the list. */
    private void remove(Node node) {
        node.prev.next = node.next;
        node.next.prev = node.prev;
    }

    /** Insert {@code node} right after the head sentinel (most recent position). */
    private void insertToHead(Node node) {
        this.head.next.prev = node;
        node.next = this.head.next;
        this.head.next = node;
        node.prev = this.head;
    }

    /** Mark an already-linked node as the most recently used one. */
    private void moveToHead(Node node) {
        remove(node);
        insertToHead(node);
    }

    /**
     * Returns the value stored for {@code key} and marks it most recently used,
     * or -1 if the key is not cached.
     */
    public Integer get(int key) {
        Node node = this.cache.get(key);
        if (node == null) {
            return -1;
        }
        moveToHead(node);
        return node.val;
    }

    /**
     * Inserts or updates {@code key}. When a new key arrives and the cache is
     * full, the least recently used entry is evicted first. A capacity of zero
     * or less stores nothing.
     */
    public void put(int key, int value) {
        Node existing = this.cache.get(key);
        if (existing != null) {
            existing.val = value;
            moveToHead(existing);
            return;
        }

        if (capacity <= 0) {
            // Nothing can be stored; tail.prev would be the head sentinel.
            return;
        }

        if (size >= capacity) {
            Node lru = this.tail.prev;
            remove(lru);
            this.cache.remove(lru.key);
            this.size--; // keep size equal to the number of stored entries
        }

        Node newNode = new Node(key, value);
        insertToHead(newNode);
        this.cache.put(key, newNode);
        this.size++;
    }
}

https://leetcode.com/problems/lru-cache/solution/

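This part of the C++ solution is wrong: it leaks memory. Instead of unlinking and freeing the evicted node, it turns that node into the new tail and leaves the old tail node unreachable without ever deleting it. With small tests the leak goes unnoticed.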

        if (size > capacity) {
            DLL* lru = tail->prev;
            cache.erase(lru->key);
            // BUG (memory leak): rather than unlinking and freeing `lru`, the
            // next three lines make `lru` the new `tail`. The old tail node is
            // left unreachable and is never deleted.
            tail->prev->val = tail->val;
            tail->prev->next = nullptr;
            tail = tail->prev;
            size--;
        }

The correct code for that part is below:

        if (size > capacity) {
            // Evict `lru`, the node immediately before `tail`.
            DLL* lru = tail->prev;
            cache.erase(lru->key);
            // Unlink `lru`: its predecessor and `tail` now point at each other.
            tail->prev = lru->prev;
            lru->prev->next = tail;
            // `lru` is no longer referenced by the list, so free it here.
            delete lru;
            size--;
        }

My solution might be more readable. I use two helper functions, _delete() and _insert(): the first unlinks a node from the list, and the second inserts a node right after the head.

Pro tip: always bind the neighbouring nodes to local variables (prev_node and next_node), as I do in my _delete() and _insert() functions, to keep the relinking logic clear.

from typing import List

class Node:
    """One entry of the doubly linked recency list.

    Args:
        key: cache key stored in the node (left as None for the sentinels).
        val: cached value (left as None for the sentinels).
        next_node: initial successor link.
        prev_node: initial predecessor link.
    """

    def __init__(self, key = None, val=None, next_node=None, prev_node=None):
        self.key, self.val = key, val
        self.prev, self.next = prev_node, next_node
        
        
class LruCache:
    """Fixed-capacity LRU cache backed by a dict and a doubly linked list.

    ``self.cache`` maps each key to its ``Node``. Nodes sit between the
    sentinel ``self.head`` (most recently used side) and ``self.tail``
    (least recently used side). Despite its name, ``self.size`` holds the
    capacity, not the current number of entries.
    """

    def __init__(self, size):
        self.size = int(size)
        self.cache = {}
        self.head = Node()
        self.tail = Node()

        # connect the head and the tail as there are no other nodes in the beginning
        self.head.next = self.tail
        self.tail.prev = self.head

    def get(self, key):
        """Return the value for *key* and mark it most recently used; -1 if absent."""
        node = self.cache.get(key)
        if node is None:
            return -1
        self._move_to_front(node)
        return node.val

    def _delete(self, node):
        """Removes the node from the doubly linked list."""
        prev_node = node.prev
        next_node = node.next

        # connect the prev with the next
        prev_node.next = next_node
        next_node.prev = prev_node

    def _insert(self, node):
        """Insert the node after head making it MRU."""
        prev_node = self.head
        next_node = self.head.next

        # put node after head (prev_node) and before next_node
        prev_node.next = node
        node.next = next_node
        next_node.prev = node
        node.prev = prev_node

    def _move_to_front(self, node):
        """Mark a node that is already linked as most recently used."""
        self._delete(node)
        self._insert(node)

    def put(self, key, val):
        """Add or update *key*, evicting the least recently used entry when full.

        A capacity (``self.size``) of zero or less stores nothing.
        """
        node = self.cache.get(key)
        if node is not None:
            # key already exists: update the value and make it MRU
            node.val = val
            self._move_to_front(node)
            return

        if self.size <= 0:
            # No room at all: tail.prev would be the head sentinel.
            return

        if len(self.cache) >= self.size:
            # make room by removing the LRU node (last node before tail)
            lru = self.tail.prev
            del self.cache[lru.key]
            self._delete(lru)

        node = Node(key, val)
        self._insert(node)
        self.cache[key] = node