题目链接

LRU Cache

题目描述

设计并实现一个 LRU (Least Recently Used - 最近最少使用) 缓存机制。它应该支持以下操作: getput

  • int get(int key): 如果键存在于缓存中,则获取键的值(总是正数),否则返回 -1。
  • void put(int key, int value): 如果键不存在,则写入键值对。如果键已存在,则更新其值。当缓存容量达到上限时,它应该在写入新数据之前删除最近最少使用的数据。

特殊规则:

  • get(key)put 一个key,都算作一次“使用”。
  • put 一个已存在key 来更新其值,不算作一次“使用”。

所有操作都需要在 时间复杂度内完成。

解题思路

为了在 时间内完成所有操作,我们需要组合使用两种数据结构:哈希表和双向链表。

  • 哈希表 (Hash Map): 用于存储 key 到链表节点的映射,这样我们就可以在 时间内通过 key 查找到缓存中对应的数据节点。
  • 双向链表 (Doubly Linked List): 用于维护数据的使用顺序。我们将最近使用的节点放在链表头部,最久未使用的节点放在链表尾部。双向链表的特性使得在头部插入和删除任意节点的操作都可以在 时间内完成。

1. 数据结构设计

  • 双向链表节点 DLinkedNode: 每个节点包含 key, value, prev 指针, next 指针。存储 key 是为了在删除尾部节点时,能够知道要从哈希表中删除哪个 key
  • 哈希表 cache: map<int, DLinkedNode*>,键是 key,值是对应的链表节点的指针。
  • 辅助成员: capacity 记录缓存容量,size 记录当前大小,以及两个哨兵节点 headtail 来简化链表操作。head 指向最近使用的节点,tail 指向最久未使用的节点。

2. 操作实现

  • moveToHead(node): 将一个节点移动到链表头部。这包括两步:先将节点从当前位置删除,再将其添加到头部。

  • get(key):

    1. 通过哈希表查找 key。如果不存在,返回 -1。
    2. 如果存在,获取对应的 node
    3. 这是一次“使用”,所以调用 moveToHead(node) 将该节点移动到链表头部。
    4. 返回 node->value
  • put(key, value):

    1. 通过哈希表查找 key
    2. 如果 key 存在:
      • 获取对应的 node
      • 更新 node->value = value
      • 根据特殊规则,不移动节点,因为这不算一次“使用”。
    3. 如果 key 不存在:
      • 创建一个新的 DLinkedNode
      • 这是一次“使用”,所以将新节点添加到链表头部。
      • {key, newNode} 存入哈希表。
      • 缓存大小 size 加一。
      • 检查容量: 如果 size > capacity,说明需要淘汰一个节点。
        • 获取链表尾部的节点(tail->prev),这是最近最少使用的节点。
        • 从链表中删除该节点。
        • 从哈希表中删除该节点的 key
        • 缓存大小 size 减一。

代码

#include <iostream>
#include <string>
#include <vector>
#include <unordered_map>

using namespace std;

struct DLinkedNode {
    int key, value;
    DLinkedNode* prev;
    DLinkedNode* next;
    DLinkedNode(): key(0), value(0), prev(nullptr), next(nullptr) {}
    DLinkedNode(int _key, int _value): key(_key), value(_value), prev(nullptr), next(nullptr) {}
};

class LRUCache {
private:
    unordered_map<int, DLinkedNode*> cache;
    DLinkedNode* head;
    DLinkedNode* tail;
    int size;
    int capacity;

public:
    LRUCache(int _capacity) : capacity(_capacity), size(0) {
        head = new DLinkedNode();
        tail = new DLinkedNode();
        head->next = tail;
        tail->prev = head;
    }
    
    ~LRUCache() {
        DLinkedNode* curr = head;
        while(curr) {
            DLinkedNode* next = curr->next;
            delete curr;
            curr = next;
        }
    }

    int get(int key) {
        if (!cache.count(key)) {
            return -1;
        }
        DLinkedNode* node = cache[key];
        moveToHead(node);
        return node->value;
    }

    void put(int key, int value) {
        if (!cache.count(key)) {
            // 新增
            DLinkedNode* newNode = new DLinkedNode(key, value);
            cache[key] = newNode;
            addToHead(newNode);
            ++size;
            if (size > capacity) {
                DLinkedNode* removed = removeTail();
                cache.erase(removed->key);
                delete removed;
                --size;
            }
        } else {
            // 更新
            DLinkedNode* node = cache[key];
            node->value = value;
            // 更新不算使用,不移动
        }
    }

private:
    void addToHead(DLinkedNode* node) {
        node->prev = head;
        node->next = head->next;
        head->next->prev = node;
        head->next = node;
    }

    void removeNode(DLinkedNode* node) {
        node->prev->next = node->next;
        node->next->prev = node->prev;
    }

    void moveToHead(DLinkedNode* node) {
        removeNode(node);
        addToHead(node);
    }

    DLinkedNode* removeTail() {
        DLinkedNode* node = tail->prev;
        removeNode(node);
        return node;
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;
    LRUCache cache(n);

    char op;
    while (cin >> op) {
        if (op == 'p') {
            int key, value;
            cin >> key >> value;
            cache.put(key, value);
        } else if (op == 'g') {
            int key;
            cin >> key;
            cout << cache.get(key) << endl;
        }
    }

    return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;

class DLinkedNode {
    int key;
    int value;
    DLinkedNode prev;
    DLinkedNode next;

    public DLinkedNode() {}

    public DLinkedNode(int _key, int _value) {
        this.key = _key;
        this.value = _value;
    }
}

class LRUCache {
    private Map<Integer, DLinkedNode> cache = new HashMap<>();
    private int size;
    private int capacity;
    private DLinkedNode head, tail;

    public LRUCache(int capacity) {
        this.size = 0;
        this.capacity = capacity;
        head = new DLinkedNode();
        tail = new DLinkedNode();
        head.next = tail;
        tail.prev = head;
    }

    public int get(int key) {
        DLinkedNode node = cache.get(key);
        if (node == null) {
            return -1;
        }
        moveToHead(node);
        return node.value;
    }

    public void put(int key, int value) {
        DLinkedNode node = cache.get(key);
        if (node == null) {
            DLinkedNode newNode = new DLinkedNode(key, value);
            cache.put(key, newNode);
            addToHead(newNode);
            ++size;
            if (size > capacity) {
                DLinkedNode tailNode = removeTail();
                cache.remove(tailNode.key);
                --size;
            }
        } else {
            node.value = value;
            // 更新不算使用,不移动
        }
    }

    private void addToHead(DLinkedNode node) {
        node.prev = head;
        node.next = head.next;
        head.next.prev = node;
        head.next = node;
    }

    private void removeNode(DLinkedNode node) {
        node.prev.next = node.next;
        node.next.prev = node.prev;
    }

    private void moveToHead(DLinkedNode node) {
        removeNode(node);
        addToHead(node);
    }

    private DLinkedNode removeTail() {
        DLinkedNode res = tail.prev;
        removeNode(res);
        return res;
    }
}

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        LRUCache cache = new LRUCache(n);

        while (sc.hasNext()) {
            String op = sc.next();
            if (op.equals("p")) {
                int key = sc.nextInt();
                int value = sc.nextInt();
                cache.put(key, value);
            } else if (op.equals("g")) {
                int key = sc.nextInt();
                System.out.println(cache.get(key));
            }
        }
    }
}
import sys

class DLinkedNode:
    def __init__(self, key=0, value=0):
        self.key = key
        self.value = value
        self.prev = None
        self.next = None

class LRUCache:
    def __init__(self, capacity: int):
        self.cache = {}
        self.head = DLinkedNode()
        self.tail = DLinkedNode()
        self.head.next = self.tail
        self.tail.prev = self.head
        self.capacity = capacity
        self.size = 0

    def get(self, key: int) -> int:
        if key not in self.cache:
            return -1
        node = self.cache[key]
        self._move_to_head(node)
        return node.value

    def put(self, key: int, value: int) -> None:
        if key not in self.cache:
            # 新增
            node = DLinkedNode(key, value)
            self.cache[key] = node
            self._add_to_head(node)
            self.size += 1
            if self.size > self.capacity:
                removed = self._remove_tail()
                del self.cache[removed.key]
                self.size -= 1
        else:
            # 更新
            node = self.cache[key]
            node.value = value
            # 更新不算使用,不移动

    def _add_to_head(self, node):
        node.prev = self.head
        node.next = self.head.next
        self.head.next.prev = node
        self.head.next = node

    def _remove_node(self, node):
        node.prev.next = node.next
        node.next.prev = node.prev

    def _move_to_head(self, node):
        self._remove_node(node)
        self._add_to_head(node)

    def _remove_tail(self):
        node = self.tail.prev
        self._remove_node(node)
        return node

def main():
    try:
        n_str = sys.stdin.readline()
        if not n_str: return
        n = int(n_str)
        cache = LRUCache(n)
        
        for line in sys.stdin:
            parts = line.split()
            op = parts[0]
            if op == 'p':
                key, value = int(parts[1]), int(parts[2])
                cache.put(key, value)
            elif op == 'g':
                key = int(parts[1])
                print(cache.get(key))

    except (IOError, ValueError):
        return

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:哈希表 + 双向链表

  • 时间复杂度: getput 操作中的所有步骤,包括哈希表查找、链表节点的增删和移动,都是常数时间操作。

  • 空间复杂度: ,其中 C 是缓存的容量。哈希表和双向链表最多存储 C 个元素。