MOD = 10**9 + 7  # exact integer modulus; avoids the float round-trip of int(1e9 + 7)


def sum_of_powers_squared(n: int, k: int, mod: int = MOD) -> int:
    """Return ``(1**k + 2**k + ... + n**k) ** 2`` reduced modulo *mod*.

    Each power is computed with three-argument ``pow`` so intermediate values
    stay below *mod*. The running sum is reduced once before squaring, which
    avoids squaring an unreduced big integer.

    Args:
        n: Upper bound of the summation (inclusive). ``n <= 0`` gives an
            empty sum and therefore a result of 0.
        k: Non-negative exponent applied to each term.
        mod: Positive modulus; defaults to ``10**9 + 7``.

    Returns:
        The squared power sum modulo *mod*, in ``range(mod)``.
    """
    total = sum(pow(i, k, mod) for i in range(1, n + 1)) % mod
    return pow(total, 2, mod)


def main() -> None:
    """Read ``n k`` from one line of stdin and print the squared power sum."""
    n, k = map(int, input().split())
    print(sum_of_powers_squared(n, k))


if __name__ == "__main__":
    main()