package org.apache.flink.table.runtime.operators.aggregate; import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.functions.KeyedProcessFunction; import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore; import org.apache.flink.table.runtime.generated.GeneratedTableAggsHandleFunction; import org.apache.flink.table.runtime.generated.TableAggsHandleFunction; import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.util.Collector; import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg; import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig; public class GroupTableAggFunction extends KeyedProcessFunction{ private static final long serialVersionUID = 1L; private final GeneratedTableAggsHandleFunction genAggsHandler; private final LogicalType[] accTypes; private final RecordCounter recordCounter; private final boolean generateUpdateBefore; private final long stateRetentionTime; // function used to handle all table aggregates private transient TableAggsHandleFunction function = null; // stores the accumulators private transient ValueState accState = null; public GroupTableAggFunction( GeneratedTableAggsHandleFunction genAggsHandler, LogicalType[] accTypes, int indexOfCountStar, boolean generateUpdateBefore, long stateRetentionTime) { this.genAggsHandler = genAggsHandler; this.accTypes = accTypes; this.recordCounter = RecordCounter.of(indexOfCountStar); this.generateUpdateBefore = generateUpdateBefore; this.stateRetentionTime = stateRetentionTime; } @Override public void open(Configuration parameters) throws Exception { super.open(parameters); // instantiate function StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime); function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader()); function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig)); InternalTypeInfo accTypeInfo = InternalTypeInfo.ofFields(accTypes); ValueStateDescriptor accDesc = new ValueStateDescriptor<>("accState", accTypeInfo); if (ttlConfig.isEnabled()) { accDesc.enableTimeToLive(ttlConfig); } accState = getRuntimeContext().getState(accDesc); } @Override public void processElement(RowData input, Context ctx, Collector out) throws Exception { RowData currentKey = ctx.getCurrentKey(); boolean firstRow; RowData accumulators = accState.value(); if (null == accumulators) { firstRow = true; accumulators = function.createAccumulators(); } else { firstRow = false; } // set accumulators to handler first function.setAccumulators(accumulators); if (!firstRow && generateUpdateBefore) { function.emitValue(out, currentKey, true); } // update aggregate result and set to the newRow if (isAccumulateMsg(input)) { // accumulate input function.accumulate(input); } else { // retract input function.retract(input); } // get accumulator accumulators = function.getAccumulators(); if (!recordCounter.recordCountIsZero(accumulators)) { function.emitValue(out, currentKey, false); // update the state accState.update(accumulators); } else { // and clear all state accState.clear(); // cleanup dataview under current key function.cleanup(); } } @Override public void close() throws Exception { if (function != null) { function.close(); } } }



