MOD = 10**9 + 7  # exact integer literal; avoids building the modulus via a float


def squared_power_sum(n, k, mod=MOD):
    """Return (1**k + 2**k + ... + n**k) ** 2 reduced modulo *mod*.

    Args:
        n: Upper bound of the summation (inclusive). For n <= 0 the sum is
            empty and the result is 0.
        k: Exponent applied to each term, passed to three-argument ``pow``.
        mod: Modulus for all arithmetic; defaults to 1_000_000_007.

    Returns:
        The squared power sum modulo *mod*, as an int in ``[0, mod)``.
    """
    # Reduce the sum before squaring so the product stays below mod**2;
    # (s mod m)^2 mod m == s^2 mod m, so the result is unchanged.
    total = sum(pow(i, k, mod) for i in range(1, n + 1)) % mod
    return total * total % mod


def main():
    """Read "n k" from one line of stdin and print the squared power sum."""
    n, k = map(int, input().split())
    print(squared_power_sum(n, k))


if __name__ == "__main__":
    main()