TopN Hot News Calculation

Flink is used to compute the top-N hot news in real time. The main program is adapted from http://wuchong.me/blog/2018/11/07/use-flink-calculate-hot-items/ with some light reorganization; the sample data can be downloaded from that link.

import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.java.io.PojoCsvInputFormat;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.functions.timestamps.AscendingTimestampExtractor;
import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
import org.apache.flink.util.Collector;

import java.io.File;
import java.net.URL;
import java.sql.Timestamp;
import java.util.*;
import java.util.concurrent.PriorityBlockingQueue;

/**
 * Every 1 minute, output the top N news items with the most clicks over the past 5 minutes.
 */
public class HotNews {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);
        // The field order extracted by Java reflection is not deterministic, so it must be specified explicitly
        String[] fileOrder = new String[] {"userID", "itemID", "categoryID", "behavior", "timestamp"};
        String path = "userbehavior.csv";
        DataStream<UserBehavior> dataSource = getSource(env, path, fileOrder, UserBehavior.class);
        // Explicitly switch to event-time processing
        env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
        DataStream<UserBehavior> timeData = dataSource.assignTimestampsAndWatermarks(new AscendingTimestampExtractor<UserBehavior>() {
            // Real business data is usually out of order, so BoundedOutOfOrdernessTimestampExtractor is normally used instead (see the sketch after this class)
            @Override
            public long extractAscendingTimestamp(UserBehavior userBehavior) {
                return userBehavior.getTimestamp() * 1000; // seconds -> milliseconds
            }
        });
        // Keep only click events
        DataStream<UserBehavior> clickData = timeData.filter(new FilterFunction<UserBehavior>() {
            @Override
            public boolean filter(UserBehavior userBehavior) throws Exception {
                return userBehavior.getBehavior().equals("click");
            }
        });
        // Click count of each news item in each window
        DataStream<ItemViewCount> windowData = clickData
                .keyBy("itemID")
                .timeWindow(Time.minutes(5), Time.minutes(1)) // every 1 minute, count clicks per item over the last 5 minutes
                .aggregate(new CountAgg(), new WindowResultFunction());
        // Top-N hot news of each window
        DataStream<String> topItems = windowData
                .keyBy("windowEnd")
                .process(new TopNHotNews(5));
        topItems.print();
        env.execute("Hot news Job!");
    }

    public static DataStream<UserBehavior> getSource(StreamExecutionEnvironment env, String path,
                                                     String[] fileOrder, Class<UserBehavior> type) {
        // Local file path
        URL fileUrl = HotNews.class.getClassLoader().getResource(path);
        Path filePath = Path.fromLocalFile(new File(fileUrl.getPath()));
        // Extract the TypeInformation; it is a PojoTypeInfo
        PojoTypeInfo<UserBehavior> pojoType = (PojoTypeInfo<UserBehavior>) TypeExtractor.createTypeInfo(type);
        // The field order extracted by Java reflection is not deterministic, so it is passed in explicitly
        // Create the PojoCsvInputFormat
        PojoCsvInputFormat<UserBehavior> csvInput = new PojoCsvInputFormat<>(filePath, pojoType, fileOrder);
        return env.createInput(csvInput, pojoType);
    }
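    // For reference (illustrative only, not taken from the real dataset): with the fileOrder above,
    // each CSV row is expected to contain the five fields in that order, i.e.
    //   <userID>,<itemID>,<categoryID>,<behavior>,<timestamp in seconds>
    // e.g. a hypothetical row: 1001,2002,300,click,1511658000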

    /** Aggregate function for COUNT: add 1 for every record that arrives. */
    public static class CountAgg implements AggregateFunction<UserBehavior, Long, Long> {
        @Override
        public Long createAccumulator() {
            return 0L;
        }
        @Override
        public Long add(UserBehavior userBehavior, Long acc) {
            return acc + 1;
        }
        @Override
        public Long getResult(Long acc) {
            return acc;
        }
        @Override
        public Long merge(Long acc1, Long acc2) {
            return acc1 + acc2;
        }
    }

    public static class WindowResultFunction implements WindowFunction<Long, ItemViewCount, Tuple, TimeWindow> {
        @Override
        public void apply(Tuple key, TimeWindow window, Iterable<Long> aggregateResult,
                          Collector<ItemViewCount> collector) throws Exception {
            long itemID = ((Tuple1<Long>) key).f0;
            long count = aggregateResult.iterator().next();
            collector.collect(ItemViewCount.of(itemID, window.getEnd(), count));
        }
    }

    /** Computes the top N hot news in a window; the key is the window end timestamp, the output is the TopN result string. */
    public static class TopNHotNews extends KeyedProcessFunction<Tuple, ItemViewCount, String> {
        private final int topSize;
        public TopNHotNews(int topSize) {
            this.topSize = topSize;
        }
        // State holding the item/click-count pairs; TopN is computed only after all data of one window has been collected
        private ListState<ItemViewCount> itemState;
        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            // Register the state
            ListStateDescriptor<ItemViewCount> itemsStateDesc =
                    new ListStateDescriptor<>("itemState-state", ItemViewCount.class);
            itemState = getRuntimeContext().getListState(itemsStateDesc);
        }
        @Override
        public void processElement(ItemViewCount input, Context context, Collector<String> collector) throws Exception {
            // Save every record into state
            itemState.add(input);
            // Register an EventTime timer at windowEnd + 1; when it fires, all items belonging to the windowEnd window have been collected
            context.timerService().registerEventTimeTimer(input.getWindowEnd() + 1);
        }
        @Override
        public void onTimer(long timestamp, OnTimerContext ctx, Collector<String> out) throws Exception {
            // Use a priority queue (min-heap) to find the N largest counts
            PriorityBlockingQueue<ItemViewCount> topNItems =
                    new PriorityBlockingQueue<>(topSize, new Comparator<ItemViewCount>() {
                        @Override
                        public int compare(ItemViewCount o1, ItemViewCount o2) {
                            return Long.compare(o1.getViewCount(), o2.getViewCount());
                        }
                    });
            for (ItemViewCount item : itemState.get()) {
                if (topNItems.size() < topSize) {
                    topNItems.offer(item);
                } else if (topNItems.peek().getViewCount() < item.getViewCount()) {
                    topNItems.poll();
                    topNItems.offer(item);
                }
            }
            List<ItemViewCount> list = InitCollect.newArrayList(topNItems.size());
            list.addAll(topNItems);
            Collections.sort(list, new Comparator<ItemViewCount>() {
                @Override
                public int compare(ItemViewCount o1, ItemViewCount o2) {
                    return Long.compare(o2.getViewCount(), o1.getViewCount());
                }
            });
            itemState.clear();
            // Format the ranking as a String for easy printing
            StringBuilder result = new StringBuilder();
            result.append("=============================\n");
            result.append("Window end time: ").append(new Timestamp(timestamp - 1)).append("\n");
            for (ItemViewCount item : list) {
                result.append(item.getItemID()).append(" : ").append(item.getViewCount()).append("\n");
            }
            out.collect(result.toString());
        }
    }
}
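As the comment in main() notes, real-world event streams are usually out of order, so AscendingTimestampExtractor is typically swapped for BoundedOutOfOrdernessTimestampExtractor. A minimal sketch of that variant, assuming a 10-second out-of-orderness bound (the bound is an illustrative choice, not part of the original program):

import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor;

// Watermarks lag the maximum timestamp seen so far by the configured bound (10 s here, assumed)
DataStream<UserBehavior> timeData = dataSource.assignTimestampsAndWatermarks(
        new BoundedOutOfOrdernessTimestampExtractor<UserBehavior>(Time.seconds(10)) {
            @Override
            public long extractTimestamp(UserBehavior userBehavior) {
                return userBehavior.getTimestamp() * 1000; // seconds -> milliseconds
            }
        });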
public class UserBehavior {
    private long userID;      // user id
    private long itemID;      // news id
    private int categoryID;   // news category
    private String behavior;  // user behavior: click, exposure, comment, forward
    private long timestamp;   // timestamp of the behavior, in seconds

    public long getUserID() {
        return userID;
    }

    public void setUserID(long userID) {
        this.userID = userID;
    }

    public long getItemID() {
        return itemID;
    }

    public void setItemID(long itemID) {
        this.itemID = itemID;
    }

    public int getCategoryID() {
        return categoryID;
    }

    public void setCategoryID(int categoryID) {
        this.categoryID = categoryID;
    }

    public String getBehavior() {
        return behavior;
    }

    public void setBehavior(String behavior) {
        this.behavior = behavior;
    }

    public long getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(long timestamp) {
        this.timestamp = timestamp;
    }
}
/**
 * Click count of a news item; the output type of the window operation.
 */
public class ItemViewCount {
    private long itemID;     // news id
    private long windowEnd;  // window end timestamp
    private long viewCount;  // click count of the news item

    public long getItemID() {
        return itemID;
    }

    public void setItemID(long itemID) {
        this.itemID = itemID;
    }

    public long getWindowEnd() {
        return windowEnd;
    }

    public void setWindowEnd(long windowEnd) {
        this.windowEnd = windowEnd;
    }

    public long getViewCount() {
        return viewCount;
    }

    public void setViewCount(long viewCount) {
        this.viewCount = viewCount;
    }

    public static ItemViewCount of(long itemID, long windowEnd, long viewCount) {
        ItemViewCount itemViewCount = new ItemViewCount();
        itemViewCount.itemID = itemID;
        itemViewCount.windowEnd = windowEnd;
        itemViewCount.viewCount = viewCount;
        return itemViewCount;
    }
}
import java.util.*;

public class InitCollect {

    public static <T> List<T> newArrayList() {
        return new ArrayList<>();
    }

    public static <T> List<T> newArrayList(int size) {
        return new ArrayList<>(size);
    }

    public static <T> List<T> newLinkedList() {
        return new LinkedList<>();
    }

    public static <K, V> Map<K, V> newHashMap() {
        return new HashMap<>();
    }

    public static <T> Set<T> newHashSet() {
        return new HashSet<>();
    }
}
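To see how the min-heap plus descending-sort selection in TopNHotNews.onTimer behaves in isolation, here is a small self-contained sketch outside Flink. It reuses the ItemViewCount class above with plain java.util collections; the item IDs and counts are made-up values for illustration, not real results.

import java.util.*;

public class TopNHeapDemo {
    public static void main(String[] args) {
        int topSize = 3;
        // {itemID, viewCount} pairs, illustrative values only
        long[][] fake = {{1, 5}, {2, 9}, {3, 2}, {4, 7}, {5, 4}};
        // Min-heap ordered by view count, same idea as in TopNHotNews.onTimer
        PriorityQueue<ItemViewCount> heap =
                new PriorityQueue<>(topSize, Comparator.comparingLong(ItemViewCount::getViewCount));
        for (long[] f : fake) {
            ItemViewCount item = ItemViewCount.of(f[0], 0L, f[1]);
            if (heap.size() < topSize) {
                heap.offer(item);
            } else if (heap.peek().getViewCount() < item.getViewCount()) {
                heap.poll();      // evict the current smallest
                heap.offer(item); // keep the larger one
            }
        }
        // Sort the surviving topSize items in descending order of view count
        List<ItemViewCount> top = new ArrayList<>(heap);
        top.sort(Comparator.comparingLong(ItemViewCount::getViewCount).reversed());
        for (ItemViewCount item : top) {
            System.out.println(item.getItemID() + " : " + item.getViewCount());
        }
        // With the values above this prints: 2 : 9, then 4 : 7, then 1 : 5
    }
}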
