package org.apache.flink.graph.library.linkanalysis;

import java.util.Iterator;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.aggregators.DoubleSumAggregator;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.base.ReduceOperatorBase;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.asm.degree.annotate.directed.EdgeSourceDegrees;
import org.apache.flink.graph.asm.degree.annotate.directed.VertexDegrees;
import org.apache.flink.graph.asm.result.PrintableResult;
import org.apache.flink.graph.asm.result.UnaryResultBase;
import org.apache.flink.graph.library.linkanalysis.Functions;
import org.apache.flink.graph.utils.GraphUtils;
import org.apache.flink.graph.utils.MurmurHash;
import org.apache.flink.graph.utils.proxy.GraphAlgorithmWrappingBase;
import org.apache.flink.graph.utils.proxy.GraphAlgorithmWrappingDataSet;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.LongValue;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.apache.jena.atlas.lib.Chars;

/* loaded from: input_file:org/apache/flink/graph/library/linkanalysis/PageRank.class */
public class PageRank<K, VV, EV> extends GraphAlgorithmWrappingDataSet<K, VV, EV, Result<K>> {
    // Broadcast-variable name under which the vertex count is handed to the rich functions.
    private static final String VERTEX_COUNT = "vertex count";
    // Broadcast-variable name under which the reduced sum of vertex scores is handed to AdjustScores.
    private static final String SUM_OF_SCORES = "sum of scores";
    // Name of the iteration aggregator that accumulates the per-iteration change in scores.
    private static final String CHANGE_IN_SCORES = "change in scores";
    // Damping factor; the three-argument constructor requires it to lie strictly between zero and one.
    private final double dampingFactor;
    // Iteration limit; mergeConfiguration may raise it to the larger of two merged configurations.
    private int maxIterations;
    // Threshold compared against the aggregated change in scores; Double.MAX_VALUE disables the convergence check in runInternal.
    private double convergenceThreshold;
    // Passed to VertexDegrees.setIncludeZeroDegreeVertices in runInternal; the constructor sets it to false.
    private boolean includeZeroDegreeVertices;

    /**
     * Applies the damping factor to each summed score and adds the share of score that is
     * spread uniformly over all vertices.
     *
     * <p>The uniform share is computed once in {@link #open(Configuration)} from the
     * {@code SUM_OF_SCORES} and {@code VERTEX_COUNT} broadcast variables. Each input tuple is
     * updated in place and returned.
     *
     * @param <T> ID type
     */
    @FunctionAnnotation.ForwardedFields({"0"})
    private static class AdjustScores<T> extends RichMapFunction<Tuple2<T, DoubleValue>, Tuple2<T, DoubleValue>> {
        private final double dampingFactor;
        private long vertexCount;
        private double uniformlyDistributedScore;

        /**
         * @param dampingFactor factor multiplied onto every incoming score
         */
        public AdjustScores(double dampingFactor) {
            this.dampingFactor = dampingFactor;
        }

        /**
         * Reads the broadcast sum of scores and vertex count and derives the score added
         * uniformly to every vertex.
         */
        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);

            Iterator<Tuple2<T, DoubleValue>> sumOfScores = getRuntimeContext()
                .<Tuple2<T, DoubleValue>>getBroadcastVariable(PageRank.SUM_OF_SCORES)
                .iterator();
            // Score missing from a total of one; an empty broadcast set is treated as a sum of zero.
            double missingScore = 1.0 - (sumOfScores.hasNext() ? sumOfScores.next().f1.getValue() : 0.0);

            Iterator<LongValue> count = getRuntimeContext()
                .<LongValue>getBroadcastVariable(PageRank.VERTEX_COUNT)
                .iterator();
            this.vertexCount = count.hasNext() ? count.next().getValue() : 0L;

            // With a vertex count of zero this division yields a non-finite value, as in the original.
            this.uniformlyDistributedScore =
                ((1.0 - this.dampingFactor) + (this.dampingFactor * missingScore)) / this.vertexCount;
        }

        /**
         * Replaces the score in {@code value} with
         * {@code uniformlyDistributedScore + dampingFactor * score} and returns the same tuple.
         */
        @Override
        public Tuple2<T, DoubleValue> map(Tuple2<T, DoubleValue> value) throws Exception {
            value.f1.setValue(this.uniformlyDistributedScore + (this.dampingFactor * value.f1.getValue()));
            return value;
        }
    }

    /**
     * Joins the previous scores (first input) with the newly computed scores (second input),
     * forwards the new score unchanged, and accumulates the absolute difference between the two.
     *
     * <p>The accumulated difference is reset in {@link #open(Configuration)} and reported to the
     * {@code CHANGE_IN_SCORES} iteration aggregator in {@link #close()}.
     *
     * @param <T> ID type
     */
    @FunctionAnnotation.ForwardedFieldsFirst({"0"})
    @FunctionAnnotation.ForwardedFieldsSecond({"*"})
    private static class ChangeInScores<T> extends RichJoinFunction<Tuple2<T, DoubleValue>, Tuple2<T, DoubleValue>, Tuple2<T, DoubleValue>> {
        private double changeInScores;

        private ChangeInScores() {
        }

        /** Resets the running total of score changes. */
        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);
            this.changeInScores = 0.0;
        }

        /** Hands the running total to the {@code CHANGE_IN_SCORES} aggregator. */
        @Override
        public void close() throws Exception {
            super.close();
            DoubleSumAggregator aggregator =
                (DoubleSumAggregator) getIterationRuntimeContext().getIterationAggregator(PageRank.CHANGE_IN_SCORES);
            aggregator.aggregate(this.changeInScores);
        }

        /**
         * @param previous score tuple from the first join input
         * @param current score tuple from the second join input
         * @return {@code current}, unmodified
         */
        @Override
        public Tuple2<T, DoubleValue> join(Tuple2<T, DoubleValue> previous, Tuple2<T, DoubleValue> current) throws Exception {
            this.changeInScores += Math.abs(current.f1.getValue() - previous.f1.getValue());
            return current;
        }
    }

    /**
     * Reduces an edge annotated with {@code (edge value, source degrees)} to an edge whose value
     * is only the out-degree taken from those degrees.
     *
     * <p>A single output edge is reused across calls.
     *
     * @param <T> ID type
     * @param <ET> original edge value type, discarded by this function
     */
    @FunctionAnnotation.ForwardedFields({"0; 1"})
    private static class ExtractSourceDegree<T, ET> implements MapFunction<Edge<T, Tuple2<ET, VertexDegrees.Degrees>>, Edge<T, LongValue>> {
        Edge<T, LongValue> output = new Edge<>();

        private ExtractSourceDegree() {
        }

        /**
         * @param edge edge carrying a tuple whose second field holds the degrees
         * @return the reused output edge with the same endpoints and the out-degree as value
         */
        @Override
        public Edge<T, LongValue> map(Edge<T, Tuple2<ET, VertexDegrees.Degrees>> edge) throws Exception {
            Tuple2<ET, VertexDegrees.Degrees> annotation = edge.f2;

            output.f0 = edge.f0;
            output.f1 = edge.f1;
            output.f2 = annotation.f1.getOutDegree();
            return output;
        }
    }

    /**
     * Emits a {@code (vertex ID, 0.0)} tuple for every vertex whose in-degree is zero and nothing
     * for all other vertices.
     *
     * <p>A single output tuple, whose score is fixed at zero, is reused across calls.
     *
     * @param <T> ID type
     */
    @FunctionAnnotation.ForwardedFields({"0"})
    private static class InitializeSourceVertices<T> implements FlatMapFunction<Vertex<T, VertexDegrees.Degrees>, Tuple2<T, DoubleValue>> {
        private Tuple2<T, DoubleValue> output;

        private InitializeSourceVertices() {
            this.output = new Tuple2<>(null, new DoubleValue(0.0));
        }

        /**
         * @param vertex vertex annotated with its degrees
         * @param collector receives the zero-score tuple when the vertex has no in-edges
         */
        @Override
        public void flatMap(Vertex<T, VertexDegrees.Degrees> vertex, Collector<Tuple2<T, DoubleValue>> collector) throws Exception {
            if (vertex.f1.getInDegree().getValue() == 0) {
                this.output.f0 = vertex.f0;
                collector.collect(this.output);
            }
        }
    }

    /**
     * Assigns every vertex the same initial score of {@code 1 / vertex count}.
     *
     * <p>The score is computed once in {@link #open(Configuration)} from the {@code VERTEX_COUNT}
     * broadcast variable; if that variable is empty the score is {@code NaN}. A single output
     * tuple is reused across calls.
     *
     * @param <T> ID type
     */
    @FunctionAnnotation.ForwardedFields({"0"})
    private static class InitializeVertexScores<T> extends RichMapFunction<Vertex<T, VertexDegrees.Degrees>, Tuple2<T, DoubleValue>> {
        private Tuple2<T, DoubleValue> output;

        private InitializeVertexScores() {
            this.output = new Tuple2<>();
        }

        /** Reads the broadcast vertex count and stores the initial score in the reused tuple. */
        @Override
        public void open(Configuration configuration) throws Exception {
            super.open(configuration);

            // Bind the iterator to a local so the element read is the one hasNext() reported.
            Iterator<LongValue> vertexCount = getRuntimeContext()
                .<LongValue>getBroadcastVariable(PageRank.VERTEX_COUNT)
                .iterator();
            double initialScore = vertexCount.hasNext() ? 1.0 / vertexCount.next().getValue() : Double.NaN;
            this.output.f1 = new DoubleValue(initialScore);
        }

        /**
         * @param vertex vertex annotated with its degrees; only the ID is used
         * @return the reused tuple holding the vertex ID and the initial score
         */
        @Override
        public Tuple2<T, DoubleValue> map(Vertex<T, VertexDegrees.Degrees> vertex) throws Exception {
            this.output.f0 = vertex.f0;
            return this.output;
        }
    }

    /**
     * Result type pairing a vertex ID with its PageRank score.
     *
     * @param <T> ID type
     */
    public static class Result<T> extends UnaryResultBase<T> implements PrintableResult {
        private DoubleValue pageRankScore;
        public static final int HASH_SEED = 1074835241;
        // Created lazily in hashCode(); transient, so it is not serialized with the result.
        private transient MurmurHash hasher;

        /**
         * @return the PageRank score stored in this result
         */
        public DoubleValue getPageRankScore() {
            return this.pageRankScore;
        }

        /**
         * @param doubleValue the PageRank score to store; the reference is kept, not copied
         */
        public void setPageRankScore(DoubleValue doubleValue) {
            this.pageRankScore = doubleValue;
        }

        /**
         * @return {@code "(<vertex ID>,<score>)"}
         */
        @Override
        public String toString() {
            return "(" + getVertexId0() + "," + this.pageRankScore + ")";
        }

        @Override
        public String toPrintableString() {
            return "Vertex ID: " + getVertexId0() + ", PageRank score: " + this.pageRankScore;
        }

        /**
         * Hashes the vertex ID's hash code followed by the score value, seeded with
         * {@link #HASH_SEED}.
         */
        @Override
        public int hashCode() {
            if (this.hasher == null) {
                this.hasher = new MurmurHash(HASH_SEED);
            }
            return this.hasher.reset()
                .hash(getVertexId0().hashCode())
                .hash(this.pageRankScore.getValue())
                .hash();
        }
    }

    /**
     * Convergence criterion that reports convergence once the aggregated value is no larger
     * than a fixed threshold.
     */
    private static class ScoreConvergence implements ConvergenceCriterion<DoubleValue> {
        private double convergenceThreshold;

        /**
         * @param threshold largest aggregate value that still counts as converged
         */
        public ScoreConvergence(double threshold) {
            this.convergenceThreshold = threshold;
        }

        /**
         * @param iteration iteration number; not used
         * @param aggregate aggregated value for the iteration
         * @return {@code true} when the aggregate does not exceed the threshold
         */
        @Override
        public boolean isConverged(int iteration, DoubleValue aggregate) {
            double change = aggregate.getValue();
            return this.convergenceThreshold >= change;
        }
    }

    /**
     * For one source vertex, emits {@code (target ID, score / edge value)} for every edge in the
     * group, where the edge value is read from the first edge of the group.
     *
     * <p>Groups without edges emit nothing. A single output tuple is reused; the score is set
     * once and only the target ID changes between emissions.
     *
     * @param <T> ID type
     */
    @FunctionAnnotation.ForwardedFieldsSecond({"1->0"})
    private static class SendScore<T> implements CoGroupFunction<Tuple2<T, DoubleValue>, Edge<T, LongValue>, Tuple2<T, DoubleValue>> {
        private Tuple2<T, DoubleValue> output;

        private SendScore() {
            this.output = new Tuple2<>(null, new DoubleValue());
        }

        /**
         * @param vertex score tuples of the group; only the first element is read
         * @param edges edges of the group, each carrying a {@code LongValue} used as divisor
         * @param collector receives one tuple per edge
         */
        @Override
        public void coGroup(Iterable<Tuple2<T, DoubleValue>> vertex, Iterable<Edge<T, LongValue>> edges, Collector<Tuple2<T, DoubleValue>> collector) throws Exception {
            Iterator<Edge<T, LongValue>> edgeIterator = edges.iterator();
            if (!edgeIterator.hasNext()) {
                return;
            }

            // The first edge supplies the divisor for the whole group.
            Edge<T, LongValue> edge = edgeIterator.next();
            this.output.f0 = edge.f1;
            // NOTE(review): next() throws if the group has edges but no score tuple — presumably
            // every edge source has a score; confirm against the caller.
            this.output.f1.setValue(vertex.iterator().next().f1.getValue() / edge.f2.getValue());
            collector.collect(this.output);

            while (edgeIterator.hasNext()) {
                this.output.f0 = edgeIterator.next().f1;
                collector.collect(this.output);
            }
        }
    }

    /**
     * Reduce function that adds the score of the second tuple onto the first tuple and returns
     * the first tuple; the ID of the first tuple is kept.
     *
     * @param <T> ID type
     */
    @FunctionAnnotation.ForwardedFields({"0"})
    private static class SumVertexScores<T> implements ReduceFunction<Tuple2<T, DoubleValue>> {
        private SumVertexScores() {
        }

        /**
         * @param accumulated tuple whose score is updated in place
         * @param next tuple whose score is added
         * @return {@code accumulated}
         */
        @Override
        public Tuple2<T, DoubleValue> reduce(Tuple2<T, DoubleValue> accumulated, Tuple2<T, DoubleValue> next) throws Exception {
            double total = accumulated.f1.getValue() + next.f1.getValue();
            accumulated.f1.setValue(total);
            return accumulated;
        }
    }

    /**
     * Copies the ID and score of a {@code (vertex ID, score)} tuple into a reused
     * {@link Result} instance.
     *
     * @param <T> ID type
     */
    @FunctionAnnotation.ForwardedFields({"0->vertexId0; 1->pageRankScore"})
    private static class TranslateResult<T> implements MapFunction<Tuple2<T, DoubleValue>, Result<T>> {
        private Result<T> output = new Result<>();

        private TranslateResult() {
        }

        /**
         * @param scored tuple of vertex ID and score
         * @return the reused result populated from {@code scored}
         */
        @Override
        public Result<T> map(Tuple2<T, DoubleValue> scored) throws Exception {
            Result<T> result = this.output;
            result.setVertexId0(scored.f0);
            result.setPageRankScore(scored.f1);
            return result;
        }
    }

    /**
     * Creates a PageRank configuration limited only by an iteration count.
     *
     * <p>Delegates to the three-argument constructor with a convergence threshold of
     * {@code Double.MAX_VALUE}.
     *
     * @param d damping factor
     * @param i maximum number of iterations
     */
    public PageRank(double d, int i) {
        this(d, i, Double.MAX_VALUE);
    }

    /**
     * Creates a PageRank configuration limited only by a convergence threshold.
     *
     * <p>Delegates to the three-argument constructor with an iteration limit of
     * {@code Integer.MAX_VALUE}.
     *
     * @param d damping factor
     * @param d2 convergence threshold
     */
    public PageRank(double d, double d2) {
        this(d, Integer.MAX_VALUE, d2);
    }

    /**
     * Creates a PageRank configuration with both an iteration limit and a convergence threshold.
     *
     * <p>Zero-degree vertices are excluded until {@code setIncludeZeroDegreeVertices(true)} is
     * called.
     *
     * @param d damping factor; must be strictly between zero and one
     * @param i maximum number of iterations; must be greater than zero
     * @param d2 convergence threshold; must be greater than zero
     * @throws IllegalArgumentException if any argument is outside its permitted range
     */
    public PageRank(double d, int i, double d2) {
        this.includeZeroDegreeVertices = false;

        // Each comparison is false for NaN, so NaN arguments are rejected as well.
        Preconditions.checkArgument(0.0 < d && d < 1.0, "Damping factor must be between zero and one");
        Preconditions.checkArgument(i > 0, "Number of iterations must be greater than zero");
        Preconditions.checkArgument(d2 > 0.0, "Convergence threshold must be greater than zero");

        this.dampingFactor = d;
        this.maxIterations = i;
        this.convergenceThreshold = d2;
    }

    /**
     * Sets the flag that runInternal passes to {@code VertexDegrees.setIncludeZeroDegreeVertices}.
     *
     * @param z the new flag value
     * @return this instance, to allow call chaining
     */
    public PageRank<K, VV, EV> setIncludeZeroDegreeVertices(boolean z) {
        this.includeZeroDegreeVertices = z;
        return this;
    }

    /**
     * Reports whether this configuration can be merged with {@code graphAlgorithmWrappingBase}.
     *
     * <p>Merging is possible when the superclass allows it and both instances use the same
     * damping factor and the same zero-degree-vertex setting. The argument is cast to
     * {@code PageRank} only after the superclass check has passed.
     *
     * @param graphAlgorithmWrappingBase the other algorithm configuration
     * @return {@code true} if the two configurations may be merged
     */
    @Override
    public boolean canMergeConfigurationWith(GraphAlgorithmWrappingBase graphAlgorithmWrappingBase) {
        if (super.canMergeConfigurationWith(graphAlgorithmWrappingBase)) {
            PageRank<?, ?, ?> rhs = (PageRank<?, ?, ?>) graphAlgorithmWrappingBase;
            boolean sameDampingFactor = this.dampingFactor == rhs.dampingFactor;
            boolean sameZeroDegreeSetting = this.includeZeroDegreeVertices == rhs.includeZeroDegreeVertices;
            return sameDampingFactor && sameZeroDegreeSetting;
        }
        return false;
    }

    /**
     * Merges the other configuration into this one by keeping the larger iteration limit and
     * the smaller convergence threshold, after the superclass has merged its own settings.
     *
     * @param graphAlgorithmWrappingBase the other configuration; cast to {@code PageRank}
     */
    @Override
    public void mergeConfiguration(GraphAlgorithmWrappingBase graphAlgorithmWrappingBase) {
        super.mergeConfiguration(graphAlgorithmWrappingBase);

        PageRank<?, ?, ?> rhs = (PageRank<?, ?, ?>) graphAlgorithmWrappingBase;

        int mergedIterations = Math.max(this.maxIterations, rhs.maxIterations);
        double mergedThreshold = Math.min(this.convergenceThreshold, rhs.convergenceThreshold);

        this.maxIterations = mergedIterations;
        this.convergenceThreshold = mergedThreshold;
    }

    /**
     * Builds the PageRank data flow for {@code graph} as a bulk iteration over
     * {@code (vertex ID, score)} tuples.
     *
     * <p>Each iteration sends scores along edges, sums them per target, unions in the vertices
     * that have no in-edges, and adjusts the result with the damping factor. When the
     * convergence threshold is below {@code Double.MAX_VALUE}, the change in scores is also
     * aggregated and registered as a convergence criterion.
     *
     * @param graph input graph
     * @return data set of {@link Result} values, one per scored vertex
     * @throws Exception if running a sub-algorithm on the graph fails
     */
    @Override
    public DataSet<Result<K>> runInternal(Graph<K, VV, EV> graph) throws Exception {
        // vertex ID annotated with its degrees
        DataSet<Vertex<K, VertexDegrees.Degrees>> vertexDegree = graph
            .run(new VertexDegrees<K, VV, EV>()
                .setIncludeZeroDegreeVertices(this.includeZeroDegreeVertices)
                .setParallelism(this.parallelism));

        // number of vertices in the degree-annotated data set
        DataSet<LongValue> vertexCount = GraphUtils.count(vertexDegree);

        // source, target, out-degree of source
        DataSet<Edge<K, LongValue>> edgeSourceDegree = graph
            .run(new EdgeSourceDegrees<K, VV, EV>()
                .setParallelism(this.parallelism))
            .map(new ExtractSourceDegree<K, EV>())
                .setParallelism(this.parallelism)
                .name("Extract source degree");

        // vertices without in-edges, each with a score of zero
        DataSet<Tuple2<K, DoubleValue>> sourceVertices = vertexDegree
            .flatMap(new InitializeSourceVertices<K>())
                .setParallelism(this.parallelism)
                .name("Initialize source vertex scores");

        // vertex ID, initial score
        DataSet<Tuple2<K, DoubleValue>> initialScores = vertexDegree
            .map(new InitializeVertexScores<K>())
            .withBroadcastSet(vertexCount, VERTEX_COUNT)
                .setParallelism(this.parallelism)
                .name("Initialize scores");

        IterativeDataSet<Tuple2<K, DoubleValue>> iterative = initialScores
            .iterate(this.maxIterations)
            .setParallelism(this.parallelism);

        // vertex ID, sum of the scores sent along its in-edges
        DataSet<Tuple2<K, DoubleValue>> vertexScores = iterative
            .coGroup(edgeSourceDegree)
            .where(0)
            .equalTo(0)
            .with(new SendScore<K>())
                .setParallelism(this.parallelism)
                .name("Send score")
            .groupBy(0)
            .reduce(new Functions.SumScore<K>())
            .setCombineHint(ReduceOperatorBase.CombineHint.HASH)
                .setParallelism(this.parallelism)
                .name("Sum");

        // single tuple holding the total of all summed scores (its ID field is not meaningful)
        DataSet<Tuple2<K, DoubleValue>> sumOfScores = vertexScores
            .reduce(new SumVertexScores<K>())
                .setParallelism(this.parallelism)
                .name("Sum");

        // vertex ID, adjusted score
        DataSet<Tuple2<K, DoubleValue>> adjustedScores = vertexScores
            .union(sourceVertices)
                .setParallelism(this.parallelism)
                .name("Union with source vertices")
            .map(new AdjustScores<K>(this.dampingFactor))
            .withBroadcastSet(sumOfScores, SUM_OF_SCORES)
            .withBroadcastSet(vertexCount, VERTEX_COUNT)
                .setParallelism(this.parallelism)
                .name("Adjust scores");

        DataSet<Tuple2<K, DoubleValue>> passThrough;

        if (this.convergenceThreshold < Double.MAX_VALUE) {
            // Join previous and new scores so the change can be aggregated for the convergence check.
            passThrough = iterative
                .join(adjustedScores)
                .where(0)
                .equalTo(0)
                .with(new ChangeInScores<K>())
                    .setParallelism(this.parallelism)
                    .name("Change in scores");

            iterative.registerAggregationConvergenceCriterion(
                CHANGE_IN_SCORES, new DoubleSumAggregator(), new ScoreConvergence(this.convergenceThreshold));
        } else {
            passThrough = adjustedScores;
        }

        return iterative
            .closeWith(passThrough)
            .map(new TranslateResult<K>())
                .setParallelism(this.parallelism)
                .name("Map result");
    }
}
