栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

pyspark --- 将df按照某一列展开

pyspark --- 将df按照某一列展开

from pyspark import SparkContext, SQLContext, SparkConf
from pyspark.sql import SparkSession
import warnings
from pyspark.sql import functions as fn
from pyspark.sql.types import StructField, LongType

# NOTE(review): `sc` is referenced below but its creation is commented out —
# this script assumes a SparkContext named `sc` already exists (e.g. an
# interactive shell session). Uncomment the next line to run standalone.
# sc = SparkContext(appName="AAA")
sqlContext = SQLContext(sc)

# Build (or reuse) the session. The original chained `.config(...)` calls
# across lines with no continuation characters (a SyntaxError); wrapping the
# chain in parentheses fixes that. `SparkSession.builder` is the documented
# entry point — the original `SparkSession(sc).builder` constructed a session
# only to discard it.
ss = (
    SparkSession.builder
    .config('spark.sql.shuffle.partitions', 2000)
    .config('spark.executor.memoryOverhead', 8192)
    .config('spark.driver.memoryOverhead', 8192)
    .config('spark.dynamicAllocation.enabled', 'true')
    .getOrCreate()
)

warnings.filterwarnings("ignore", category=DeprecationWarning)
sc.setLogLevel('ERROR')

# Toy input: column C holds a repeat count for each (A, B) pair.
# The correct API name is createDataFrame (capital F); the original
# `createDataframe` raises AttributeError at runtime.
data = ss.createDataFrame([
    {'A': 'X', 'B': 'X', 'C': 3}, {'A': 'Y', 'B': 'Y', 'C': 1}, {'A': 'L', 'B': 'L', 'C': 1}
])
data.show()

def row_dealwith(row):
    """Expand one row into comma-joined string columns, one entry per count.

    Parameters
    ----------
    row : a (A, B, C) triple where C is a positive int repeat count.

    Returns
    -------
    A 3-tuple of strings: A repeated C times, B repeated C times, and '1'
    repeated C times, each joined with ','. Downstream code splits these
    strings on ',' and explodes them back into C rows.

    Fix: the original returned the raw (a, b, c) tuple when c == 1, so
    column C was sometimes an int and sometimes a string — an inconsistent
    schema for toDF(). All rows now yield strings.
    """
    a, b, c = row[0], row[1], row[2]
    # Joining a repeated single-element list covers the c == 1 case naturally,
    # so no special branch is needed.
    return (
        ','.join([str(a)] * c),
        ','.join([str(b)] * c),
        ','.join(['1'] * c),
    )

# Rebuild a DataFrame from the per-row expanded tuples; each column now holds
# the comma-joined strings produced by row_dealwith.
newdata = data.rdd.map(row_dealwith).toDF(schema=['A', 'B', 'C'])
newdata.show()

# Split each comma-joined column on ',' and explode it into one row per
# element. Each column is exploded into its own single-column DataFrame, so
# the three results must be re-joined by position below.
dfA = newdata.withColumn('A', fn.explode(fn.split(newdata.A, ','))).select('A')
dfB = newdata.withColumn('B', fn.explode(fn.split(newdata.B, ','))).select('B')
dfC = newdata.withColumn('C', fn.explode(fn.split(newdata.C, ','))).select('C')

# NOTE(review): mkdf_tojoin (and its helper flat) are defined further down in
# this file — run top-to-bottom as a script, these calls raise NameError.
# Move those definitions above this point, or confirm this snippet is pasted
# into a session where they already exist.
dfA = mkdf_tojoin(dfA, ss)
dfB = mkdf_tojoin(dfB, ss)
dfC = mkdf_tojoin(dfC, ss)

# Re-join the exploded columns on the synthetic tmpid index, then drop it.
# NOTE(review): this relies on zipWithIndex assigning matching positions
# across three separately exploded DataFrames — i.e. on row order being
# preserved through the explode; verify before relying on this at scale.
dfres = dfA.join(dfB, on=['tmpid'], how='left').join(dfC, on=['tmpid'], how='left').drop('tmpid')
dfres.show()

def flat(l):
    """Recursively yield the leaf values of an arbitrarily nested list/tuple."""
    for item in l:
        if isinstance(item, (list, tuple)):
            # Nested container: recurse and re-yield its leaves.
            yield from flat(item)
        else:
            yield item
            
def mkdf_tojoin(df, ss):
    """Append a LongType 'tmpid' index column to *df* for positional joins.

    Parameters
    ----------
    df : input DataFrame (here, a single exploded column).
    ss : active SparkSession used to rebuild the DataFrame.

    Returns
    -------
    A new DataFrame with df's original columns plus 'tmpid' (the row index),
    so separately exploded columns can be joined back together by position.
    """
    # Extend the existing schema with the index column.
    schema = df.schema.add(StructField("tmpid", LongType()))
    # zipWithIndex yields (Row, index) pairs; flatten each pair into one
    # flat list of column values followed by the index.
    rdd = df.rdd.zipWithIndex().map(lambda pair: list(flat(pair)))
    # Fix: the correct API name is createDataFrame (capital F); the original
    # `createDataframe` raises AttributeError at call time.
    return ss.createDataFrame(rdd, schema)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/632728.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号