The solution might be more readable if we move two common operations into helper functions:
class DLL:
    """Doubly linked list node holding a key/value pair (the node class the cache relies on)."""
    def __init__(self, key, val):
        self.key = key
        self.val = val
        self.prev = None
        self.next = None

class LRU:
    def __init__(self, capacity):
        self.m = {}
        self.head = DLL(0, 0)
        self.tail = DLL(0, 0)
        self.head.next = self.tail
        self.tail.prev = self.head
        self.size = 0
        self.capacity = capacity

    def get(self, key):
        if key in self.m:
            node = self.m[key]
            self._bubble(node)
            return node.val
        else:
            return -1

    def _stitch(self, n1, n2):
        """Link two nodes together, tolerating None on either side."""
        if n1: n1.next = n2
        if n2: n2.prev = n1

    def _bubble(self, node):
        """Unlink the node from its current position and move it to the front (MRU)."""
        self._stitch(node.prev, node.next)
        cur_top = self.head.next
        self._stitch(self.head, node)
        self._stitch(node, cur_top)

    def put(self, key, val):
        if key in self.m:
            self.m[key].val = val
            self._bubble(self.m[key])
        else:
            self.size += 1
            if self.size > self.capacity:
                lru = self.tail.prev
                self._stitch(lru.prev, self.tail)
                del self.m[lru.key]
                self.size -= 1
            new_node = DLL(key, val)
            self._bubble(new_node)
            self.m[key] = new_node
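As a quick sanity check, here is a hypothetical driver for the class above (the capacity of 2 and the keys are arbitrary choices, not part of the solution):
cache = LRU(2)
cache.put(1, 1)
cache.put(2, 2)
print(cache.get(1))   # 1; key 1 is now the most recently used
cache.put(3, 3)       # over capacity, so key 2 (the LRU) is evicted
print(cache.get(2))   # -1, key 2 was evicted
print(cache.get(3))   # 3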
Simple C++ LRU implementation
#include <list>
#include <unordered_map>
#include <utility>

using namespace std;
using Pair = pair<int, int>;

class LRUCache {
    // A std::list is used instead of std::deque: deque iterators are invalidated by
    // push_front/pop_back, so they cannot be stored safely in the map.
    unordered_map<int, list<Pair>::iterator> cache;
    list<Pair> items;
    int capacity;
public:
    LRUCache(int size) {
        capacity = size;
    }
    void _resize() {
        // Evict the least recently used entry (the back of the list) when over capacity.
        if (static_cast<int>(items.size()) > capacity) {
            Pair last = items.back();
            items.pop_back();
            cache.erase(last.first);
        }
    }
    bool put(int key, int value) {
        if (cache.find(key) == cache.end()) {
            items.push_front(Pair(key, value));
            cache[key] = items.begin();
            _resize();
            return true;     // a new entry was inserted
        }
        // Key already present: update its value and move it to the front.
        items.erase(cache[key]);
        items.push_front(Pair(key, value));
        cache[key] = items.begin();
        return false;        // an existing entry was updated
    }
    int get(int key) {
        if (cache.find(key) != cache.end()) {
            list<Pair>::iterator it = cache[key];
            int value = it->second;
            // Move the accessed entry to the front so it becomes the most recently used.
            items.erase(it);
            items.push_front(Pair(key, value));
            cache[key] = items.begin();
            return value;
        }
        return -1;
    }
};
Here’s my implementation. I just break the code up into helper functions so it’s a lot easier to understand.
public static class LRUCache {
    public static class Node {
        int key;
        int val;
        Node next;
        Node prev;

        public Node(int key, int val) {
            this.key = key;
            this.val = val;
        }
    }

    final int capacity;
    final HashMap<Integer, Node> cache;
    public int size;
    public Node head;
    public Node tail;

    public LRUCache(int capacity) {
        this.capacity = capacity;
        this.cache = new HashMap<>(capacity);
        // Sentinel head and tail nodes keep remove/insert free of null checks
        this.head = new Node(0, 0);
        this.tail = new Node(0, 0);
        this.head.next = tail;
        this.tail.prev = head;
        this.size = 0;
    }

    private void remove(Node node) {
        node.prev.next = node.next;
        node.next.prev = node.prev;
    }

    private void insertToHead(Node node) {
        this.head.next.prev = node;
        node.next = this.head.next;
        this.head.next = node;
        node.prev = this.head;
    }

    private void updateMostRecent(int key) {
        Node node = this.cache.get(key);
        remove(node);
        insertToHead(node);
    }

    public Integer get(int key) {
        if (!this.cache.containsKey(key)) {
            return -1;
        }
        updateMostRecent(key);
        return this.cache.get(key).val;
    }

    public void put(int key, int value) {
        if (this.cache.containsKey(key)) {
            updateMostRecent(key);
            this.cache.get(key).val = value;
            return;
        }
        if (size + 1 > capacity) {
            // Evict the least recently used node and keep the size counter in sync.
            Node lru = this.tail.prev;
            remove(lru);
            this.cache.remove(lru.key);
            this.size--;
        }
        this.size++;
        Node newNode = new Node(key, value);
        insertToHead(newNode);
        this.cache.put(key, newNode);
    }
}
This part of the C++ solution is wrong: it causes a memory leak, because the node holding the evicted key is repurposed as the new tail while the old tail node is abandoned without ever being deleted. With small tests the leak goes unnoticed.
if (size > capacity) {
    DLL* lru = tail->prev;
    cache.erase(lru->key);
    tail->prev->val = tail->val;
    tail->prev->next = nullptr;
    tail = tail->prev;    // the old tail node is orphaned here and never freed
    size--;
}
The correct code for that part is below:
if (size > capacity) {
    DLL* lru = tail->prev;      // least recently used node
    cache.erase(lru->key);
    tail->prev = lru->prev;     // unlink lru from the list
    lru->prev->next = tail;
    delete lru;                 // free the evicted node so nothing leaks
    size--;
}
My solution might be more readable. I use two helper functions, _delete() and _insert(), which remove a node from the list and insert it right after the head, respectively.
Pro tip: assigning prev_node and next_node to local variables, as I have done in my _delete() and _insert() functions, keeps the pointer re-wiring easy to follow.
class Node:
    def __init__(self, key=None, val=None, next_node=None, prev_node=None):
        self.next = next_node
        self.prev = prev_node
        self.val = val
        self.key = key

class LruCache:
    def __init__(self, size):
        self.size = int(size)
        self.cache = {}
        self.head = Node()
        self.tail = Node()
        # connect the head and the tail, as there are no other nodes in the beginning
        self.head.next = self.tail
        self.tail.prev = self.head

    def get(self, key):
        """Get the value of the node and mark it as most recently used."""
        if key not in self.cache:
            return -1
        node = self.cache[key]
        self._delete(node)
        self._insert(node)
        return node.val

    def _delete(self, node):
        """Remove the node from the doubly linked list."""
        prev_node = node.prev
        next_node = node.next
        # connect the prev with the next
        prev_node.next = next_node
        next_node.prev = prev_node

    def _insert(self, node):
        """Insert the node right after head, making it the MRU."""
        prev_node = self.head
        next_node = self.head.next
        # put node after head (prev_node) and before next_node
        prev_node.next = node
        node.next = next_node
        next_node.prev = node
        node.prev = prev_node

    def put(self, key, val):
        """Add the key with its value and update the recency list."""
        if key in self.cache:
            # key already exists, so just update the value and make it the MRU
            self.cache[key].val = val
            # remove the node and re-insert it right after the head
            self._delete(self.cache[key])
            self._insert(self.cache[key])
        else:
            # key does not exist: if at capacity, make room by evicting the LRU (last node before tail)
            if len(self.cache) == self.size:
                # delete the key first, as self.tail.prev will change after self._delete()
                del self.cache[self.tail.prev.key]
                self._delete(self.tail.prev)
            node = Node(key, val)
            self._insert(node)
            self.cache[key] = node
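As a rough usage sketch (the capacity of 2 and the keys are assumptions for illustration, not part of the solution), this exercises both the update path and the eviction path:
cache = LruCache(2)
cache.put("a", 1)
cache.put("b", 2)
cache.put("a", 10)        # existing key: value updated, "a" becomes the MRU
cache.put("c", 3)         # at capacity, so "b" (the LRU) is evicted
print(cache.get("b"))     # -1
print(cache.get("a"))     # 10
print(cache.get("c"))     # 3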