Mahout推荐引擎
什么是推荐算法?
乍一听这个词汇,我想大家最能联想到的就是抖音短视频,还有淘宝拼多多的一些商品推荐。
推荐算法最早在1992年就提出来了,但是火起来实际上是最近这些年的事情,因为互联网的爆发,有了更大的数据量可以供我们使用,推荐算法才有了很大的用武之地。
最开始,在网上找资料,其实都是进的yahoo,然后分门别类的点进去,找到你想要的东西,这是一个人工过程,到后来,我们用google,直接搜索自己需要的内容,这些都可以比较精准的找到你想要的东西,但是,如果我自己都不知道自己要找什么肿么办?最典型的例子就是,如果我打开豆瓣找电影,或者我去买说,我实际上不知道我想要买什么或者看什么,这时候推荐系统就可以派上用场了。
推荐算法的概念:
推荐算法是计算机专业中的一种算法,主要运用于网络中,所谓推荐算法就是利用用户的一些行为,通过一些数学算法,推测出用户可能喜欢的东西。
Mahout推荐算法中的协同过滤算法 包括两种 基于物品的协同过滤和基于用户的协同过滤(物以类聚,人以群分),个人主要研究了基于用户的协同过滤的推荐引擎的使用。
相似度算法
协同过滤算法中最为关键的就是如何去计算物品的相似度和用户的(喜好)相似度呢?
他们的算法原理有以下几种,我在后面的mahout调用选择的是皮尔逊相关系数的算法接口。
一、基于item(商品)的协同过滤(ItemCF)
基于item的协同过滤,通过用户对不同item的评分来评测item之间的相似性,基于item之间的相似性做出推荐。简单来讲就是:给用户推荐和他之前喜欢的物品相似的物品。
举个例子 小强购买了硬盘,那么在基于item的协同过滤中就很有可能推荐U盘给他,因为U盘和硬盘相似度很高。
二、基于用户的协同过滤
假设我们要对用户小强进行推荐,首先要找到与用户石强峰相似的用户,然后看看这些用户购买的什么商品,将这些商品推荐给小强。
根据与用户u相似的其他用户对商品i的评分,来推断用户u对商品i的评分。
Ps(这里的评分也同上简化为 购买商品后订单的评分,实际大家可以当作是浏览时长、是否点赞、是否收藏、购买后评分这些然后根据公式算出来的的一个综合评分也就是喜好度) 假设喜好度满分为100分
举个例子 小强喜好度(铅笔95、毛笔90)(小强可能爱好写字书法)
小刘1号喜好度(铅笔94、毛笔91、水笔90)(小刘1号可能爱好书法鞋子)
小刘2号喜好度(铅笔90、毛笔92、耳机60)(小刘2号可能爱好书法?并且不太喜欢耳机)
小刘3号喜好度(铅笔30、毛笔15、滑板95)(小刘3号可能不爱书法,可能偏爱运动?)
这个时候 首先利用第一个公式可以得出,小强和小刘1号和小刘2号是最相似的(因为对铅笔毛笔的喜好度最为相似),而小强跟小刘3号的相似度并不高。
这个时候就可以利用上图第二个公式,来计算出 石强峰可能对水笔的喜好度是比较高的,对耳机的喜好度可能是比较低的。但是在真正的推荐算法中,数据量是非常非常多的,我这里只做一个简单描述。
demo整合
我将我的demo步骤分为如下:
1 引入maven依赖
2 推荐算法肯定是需要大量数据的,所以需要建立数据模型(本地生成模拟数据/数据库中拿出)
3 将数据导入数据库
3 从数据库中构建file数据模型
4 构建mahout推荐接口
5 测试输出结果
6 渲染前端页面
涉及的技术点包括:我对ID生成器的选择、对mahout中的相似度算法的选择、数据放入数据库和从数据库中取出等。
首先在idea中搭建环境配置,引入maven依赖。 这是mahout要用到的依赖
<dependency>
<groupId>org.apache.mahout</groupId>
<artifactId>mahout-core</artifactId>
<exclusions>
<exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
</exclusion>
</exclusions>
<version>0.9</version>
</dependency>
接下来就是计算用户相似度了,想要计算用户的相似度,肯定是需要大量数据的,我这里用的id生成器+for循环随机生成并且在本地写入成txt文件。然后一开始我用的是UUID,但是做到后面使用mahout皮尔逊算法的接口的时候发现ID必须得是long类型,不能是String,为了解决这个问题,我一开始想的是将string转换long,发现UUID生成的大量字母,不可取。然后又去了mahout的皮尔逊算法接口那里,想要看看能不能修改他这个有参方法的id参数类型,去了发现是只读状态,无法修改,很无奈,只能换一种能生成long类型的ID生成器了,网上查了资料,最后我选择的是基于snowflake算法的ID生成器。那么接下来开始建立数据模型了。(Snowflake我就不详细讲了,网上都是能搜到的。)
下面是注入snowflake的代码
package com.chixing.util;
import org.springframework.stereotype.Component;
@Component
public class SnowflakeIdGenerator {
// ==============================Fields===========================================
/** 开始时间截 (2018-04-08) */
private static final long TWEPOCH = 1523116800000L;
/** 机器id所占的位数 */
private static final long WORKERID_BITS = 5L;
/** 数据标识id所占的位数 */
private static final long DATACENTERID_BITS = 5L;
/** 支持的最大机器id,结果是31 (这个移位算法可以很快的计算出几位二进制数所能表示的最大十进制数) */
private static final long MAX_WORKER_ID = -1L ^ (-1L << WORKERID_BITS);
/** 支持的最大数据标识id,结果是31 */
private static final long MAX_DATACENTER_ID = -1L ^ (-1L << DATACENTERID_BITS);
/** 序列在id中占的位数 */
private static final long SEQUENCE_BITS = 12L;
/** 机器ID向左移12位 */
private static final long WORKERID_SHIFT = SEQUENCE_BITS;
/** 数据标识id向左移17位(12+5) */
private static final long DATACENTERID_SHIFT = SEQUENCE_BITS + WORKERID_BITS;
/** 时间截向左移22位(5+5+12) */
private static final long TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKERID_BITS + DATACENTERID_BITS;
/** 生成序列的掩码,这里为4095 (0b111111111111=0xfff=4095) */
private static final long SEQUENCE_MASK = -1L ^ (-1L << SEQUENCE_BITS);
/** 工作机器ID(0~31) */
private long workerId;
/** 数据中心ID(0~31) */
private long datacenterId;
/** 毫秒内序列(0~4095) */
private long sequence = 0L;
/** 上次生成ID的时间截 */
private long lastTimestamp = -1L;
//==============================Constructors=====================================
/**
* 构造函数
*/
public SnowflakeIdGenerator() {
}
/**
* 构造函数
* @param workerId 工作ID (0~31)
* @param datacenterId 数据中心ID (0~31)
*/
public SnowflakeIdGenerator(long workerId, long datacenterId) {
if (workerId > MAX_WORKER_ID || workerId < 0) {
throw new IllegalArgumentException(String.format("worker Id can't be greater than %d or less than 0", MAX_WORKER_ID));
}
if (datacenterId > MAX_DATACENTER_ID || datacenterId < 0) {
throw new IllegalArgumentException(String.format("datacenter Id can't be greater than %d or less than 0", MAX_DATACENTER_ID));
}
this.workerId = workerId;
this.datacenterId = datacenterId;
}
// ==============================Methods==========================================
/**
* 获得下一个ID (该方法是线程安全的)
* @return SnowflakeId
*/
public synchronized long nextId() {
long timestamp = timeGen();
//如果当前时间小于上一次ID生成的时间戳,说明系统时钟回退过这个时候应当抛出异常
if (timestamp < lastTimestamp) {
throw new IllegalStateException(
String.format("Clock moved backwards. Refusing to generate id for %d milliseconds", lastTimestamp - timestamp));
}
//如果是同一时间生成的,则进行毫秒内序列
if (lastTimestamp == timestamp) {
sequence = (sequence + 1) & SEQUENCE_MASK;
//毫秒内序列溢出
if (sequence == 0) {
//阻塞到下一个毫秒,获得新的时间戳
timestamp = tilNextMillis(lastTimestamp);
}
}
//时间戳改变,毫秒内序列重置
else {
sequence = 0L;
}
//上次生成ID的时间截
lastTimestamp = timestamp;
//移位并通过或运算拼到一起组成64位的ID
return ((timestamp - TW***TAMP_LEFT_SHIFT) //
| (datacenterId << DATACENTERID_SHIFT) //
| (workerId << WORKERID_SHIFT) //
| sequence;
}
/**
* 阻塞到下一个毫秒,直到获得新的时间戳
* @param lastTimestamp 上次生成ID的时间截
* @return 当前时间戳
*/
protected long tilNextMillis(long lastTimestamp) {
long timestamp = timeGen();
while (timestamp <= lastTimestamp) {
timestamp = timeGen();
}
return timestamp;
}
/**
* 返回以毫秒为单位的当前时间
* @return 当前时间(毫秒)
*/
protected long timeGen() {
return System.currentTimeMillis();
}
//==============================Test=============================================
/** 测试 */
/**public static void main(String[] args) {
SnowflakeIdGenerator idWorker = new SnowflakeIdGenerator(0, 0);
for (int i = 0; i < 1000; i++) {
long id = idWorker.nextId();
System.out.println(Long.toBinaryString(id));
System.out.println(id);
}
}*/
}
本地生成模拟数据
public class Data {
//数据格式
public static void main(String[] args) throws FileNotFoundException {
Random random = new Random();
LinkedList<Long> userId = new LinkedList<>();
LinkedList<Long> itemId = new LinkedList<>();
//拼接用户关联Item的数据
StringBuffer data = new StringBuffer();
//拼接用户
StringBuffer user = new StringBuffer();
//生成一千个Item
SnowflakeIdGenerator idWorker = new SnowflakeIdGenerator(0, 0);
for (int i = 0; i < 1000; i++) {
Long along = idWorker.nextId();
itemId.add(along);
}
//生成一百个用户
for (int i = 0; i < 100; i++) {
Long uId = idWorker.nextId();
userId.add(uId);
//拷贝一份Item列表
ArrayList<Long> itemArray = new ArrayList<>();
itemArray.addAll(itemId);
//给这个用户 随机关联一些Item 随机数量为200以内
int count = random.nextInt(200);
for (int j = 0; j < count; j++) {
//随机取一个角标
int randomIndex = random.nextInt(itemArray.size());
Long id = itemArray.get(randomIndex);
itemArray.remove(randomIndex);
data.append(uId).append(",").append(id).append(",").append(
//随机生成100以内的喜好程度
Float.valueOf(random.nextInt(100)
)).append("\r\n");
}
user.append(uId).append(" count:").append(count).append("\r\n");
}
//存储用户关联Item的数据
FileOutputStream dataOut = null;
try {
dataOut = new FileOutputStream(new File("D:/data1/data.txt"));
} catch (FileNotFoundException e) {
e.printStackTrace();
}
//存储用户ID 和 他关联的数量 实际开发中 这个不需要存储 我这里存储只是为了后面方便测试
FileOutputStream userOut = new FileOutputStream(new File("D:/data1/user.txt"));
try {
dataOut.write(data.toString().getBytes());
} catch (IOException e) {
e.printStackTrace();
}
try {
userOut.write(data.toString().getBytes());
} catch (IOException e) {
e.printStackTrace();
}
}
}
controller层
@RestController
@RequestMapping("recommend")
public class MahoutController {
@GetMapping("/{id}") //http://localhost:8081/demo/recommend/386611047793951216
public List<RecommendedItem> recommend(@PathVariable("id") long id) {
/* //假设这是数据库
ClassPathResource classPathResource = new ClassPathResource("data.txt");
File file = classPathResource.getFile();
//假设lines 是从数据库读取出来的每一行数据
List<String> lines = Files.readLines(file, Charset.defaultCharset());
FastByIDMap<LinkedList<Preference>> linkedListFastByIDMap = new FastByIDMap<>();
int size = lines.size();
for (int i = 0; i < size; i++) {
String lien = lines.get(i);
if (lien.length() == 0) continue;
//按逗号分隔
String[] split = lien.split(",");
//用户ID
Long uId = Long.valueOf(split[0]);
//Item Id
Long itemId = Long.valueOf(split[1]);
//喜好程度
Float value = Float.valueOf(split[2]);
//创建该用户下的item list
LinkedList<Preference> preferences = linkedListFastByIDMap.get(uId);
if (preferences == null) preferences = new LinkedList<>();
//构建偏好
GenericPreference genericPreference = new GenericPreference(uId, itemId, value);
preferences.add(genericPreference);
//重新写入map集合
linkedListFastByIDMap.put(uId, preferences);
}
//遍历map集合
Set<Map.Entry<Long, LinkedList<Preference>>> entries = linkedListFastByIDMap.entrySet();
// 构建偏好集合
FastByIDMap<PreferenceArray> arrayFastByIDMap = new FastByIDMap<>();
for (Map.Entry<Long, LinkedList<Preference>> entry : entries) {
Long key = entry.getKey();
LinkedList<Preference> value = entry.getValue();
arrayFastByIDMap.put(key, new GenericUserPreferenceArray(value));
}
//最后 构建数据模型
DataModel dataModel = new GenericDataModel(arrayFastByIDMap);
System.out.println("进入方法");*/
//首先拿到从数据库中输出到本地的数据模型
DataModel dataModel = null;
{
try {
dataModel = new FileDataModel(new File("D:/data1/data.txt"));
} catch (IOException e) {
e.printStackTrace();
}
}
UserSimilarity userSimilarity = null;
try {
userSimilarity = new PearsonCorrelationSimilarity(dataModel);
} catch (TasteException e) {
e.printStackTrace();
}
//构建 近邻对象 threshold 是相似阈值 这个数值越高 推荐精准越高 但是推荐的数据也越少 最高为 你给用户设置的喜好值最高值 也就是preference的最高值
float threshold = 0f;
UserNeighborhood neighborhood = new ThresholdUserNeighborhood(threshold, userSimilarity, dataModel);
//构建推荐器
UserBasedRecommender recommender = new GenericUserBasedRecommender(dataModel, neighborhood, userSimilarity);
//开始推荐 一参数为用户ID 二参数为 要推荐Item数量
//我随便在用户列表里拿一个ID试试
List<RecommendedItem> recommend = null;
try {
recommend = recommender.recommend(id, 10);
} catch (TasteException e) {
e.printStackTrace();
}
System.out.println("走到了这里");
return recommend;
}
}