import arcpy
from arcpy import PointGeometry, Polyline
from spatial_utility import spatial_utility
import os
from arcpy.sa import *
from arcpy.da import SearchCursor, UpdateCursor, InsertCursor
from sys import exit


# Module-wide geoprocessing environment, applied as soon as this file is imported.
arcpy.env.overwriteOutput = True
arcpy.env.addOutputsToMap = 0
# The scratch geodatabase doubles as the current workspace; ws_path is also
# passed explicitly to CreateFeatureclass in gen_xs_alt.
ws_path = arcpy.env.scratchGDB
arcpy.env.workspace = ws_path
# NOTE(review): scratch_folder is not referenced by any code visible in this file.
scratch_folder = arcpy.env.scratchFolder


def main():
    """Run ``feature_dev_utility.gen_xs_alt`` with the hard-coded project inputs.

    The stream shapefile, width field, spacing, output directory and
    review/clean flag are all fixed below; the returned path is not used.
    """
    streams_shp = r"C:\Users\kristine.mosuela\OneDrive - Wood PLC\Upshur\Hydraulics\Reference_Data\xs_automation\masterstreamlines_xswidth.shp"
    out_dir = r"C:\Users\kristine.mosuela\OneDrive - Wood PLC\Upshur\Hydraulics\Reference_Data\xs_drafts_to_initialize_model"

    feature_dev_utility.gen_xs_alt(
        streams=streams_shp,
        xs_width_field="xs_width",
        xs_spacing_distance=750,
        out_directory=out_dir,
        review_clean=True,
    )


class feature_dev_utility(object):

    @staticmethod
    def gen_xs(streams: os.PathLike, xs_width_field: str, xs_spacing_distance: int = 500, out_directory: os.PathLike = r"C:\temp\xsautomation", review_clean: bool = True) -> os.PathLike:
        """Generate cross-section (XS) lines along streams using buffer "rails".

        Points are generated along a flipped copy of the streams and buffered
        by half of ``xs_width_field``, dissolved per stream. The buffer
        outlines, with the areas around the stream ends erased, become left
        and right "rails". Copies of the points are snapped to those rails,
        and each centre point plus its two snapped copies is converted into
        one XS line. Every intermediate dataset is written to the current
        arcpy workspace under a fixed name.

        Args:
            streams: Stream line features to build cross sections for.
            xs_width_field: Field whose value, halved, is used as the buffer
                distance around each generated point. Must not be "buffer",
                because a working field of that name is created here.
                (Presumably the field is carried from the streams onto the
                generated points -- confirm against the input schema.)
            xs_spacing_distance: Distance passed to GeneratePointsAlongLines.
            out_directory: Location the result dataset is written to.
            review_clean: When True the XSs are passed through ``clean_xs``;
                otherwise they are copied to ``xs_initial`` in
                ``out_directory``.

        Returns:
            The path returned by ``clean_xs`` when ``review_clean`` is True,
            otherwise the catalog path of the ``xs_initial`` copy.

        Raises:
            SystemExit: If ``xs_width_field`` is "buffer" (exit status 1).
        """
        arcpy.AddMessage("generating XS...")

        # check data: a working field called "buffer" is (re)created below
        if xs_width_field == "buffer":
            arcpy.AddError("xs width field cannot be named 'buffer'")
            # non-zero status so callers/shells can tell the run failed
            exit(1)

        # add unique stream id
        streams_clone = arcpy.CopyFeatures_management(streams, "streams_clone")
        spatial_utility.add_unique_ID(streams_clone, "StreamID")

        # create flipped streams layer
        flipped_streams = arcpy.CopyFeatures_management(streams_clone, "flipped_streams")
        arcpy.FlipLine_edit(flipped_streams)

        # generate points along flipped streams
        xs_cps = arcpy.GeneratePointsAlongLines_management(flipped_streams, "xs_cps", "DISTANCE", xs_spacing_distance)

        # add unique point id
        spatial_utility.add_unique_ID(xs_cps, "xsID")

        # compute buffer distance (half of the requested XS width)
        if arcpy.ListFields(xs_cps, "buffer"):
            arcpy.DeleteField_management(xs_cps, "buffer")
        arcpy.AddField_management(xs_cps, "buffer", "LONG")
        arcpy.CalculateField_management(xs_cps, "buffer", f"!{xs_width_field}!/2")

        # buffer and dissolve using stream id -> clouds
        clouds = arcpy.Buffer_analysis(xs_cps, "clouds", "buffer", "", "", "LIST", ["StreamID"])

        # feature to line on clouds -> linings
        linings = arcpy.FeatureToLine_management(clouds, "linings")

        # feature verts to points both ends on initial streams -> ends
        ends = arcpy.FeatureVerticesToPoints_management(streams_clone, "ends", "BOTH_ENDS")

        # for every stream end, store 1.3 x the distance from the end to the
        # lining of the same stream in the buffer_end field
        if arcpy.ListFields(ends, "buffer_end"):
            arcpy.DeleteField_management(ends, "buffer_end")
        arcpy.AddField_management(ends, "buffer_end", "long")
        linings_d = arcpy.Dissolve_management(linings, "linings_d", "StreamID")

        # read the dissolved linings once instead of opening a filtered cursor
        # per end; setdefault keeps the first geometry seen for a stream id
        lining_by_stream = {}
        with SearchCursor(linings_d, ["SHAPE@", "StreamID"]) as sc:
            for lining_geom, stream_id in sc:
                lining_by_stream.setdefault(stream_id, lining_geom)
        with UpdateCursor(ends, ["SHAPE@", "StreamID", "buffer_end"]) as uc:
            for end in uc:
                lining_geom = lining_by_stream.get(end[1])
                if lining_geom is not None:
                    end[2] = lining_geom.distanceTo(end[0]) * 1.3
                # ends without a matching lining keep a null buffer_end
                uc.updateRow(end)

        # buffer ends -> erasers
        erasers = arcpy.Buffer_analysis(ends, "erasers", "buffer_end")

        # erase linings with erasers -> links
        links = arcpy.Erase_analysis(linings, erasers, "links")

        # dissolve links based on stream id -> rails
        rails = arcpy.MultipartToSinglepart_management(arcpy.Dissolve_management(links, "rails_multi", "StreamID"), "rails")

        # assign left or right to rails
        spatial_utility.add_unique_ID(rails, "railID")
        beacons = arcpy.FeatureToPoint_management(rails, "beacons", "INSIDE")
        if arcpy.ListFields(beacons, "side"):
            arcpy.DeleteField_management(beacons, "side")
        arcpy.AddField_management(beacons, "side", "TEXT")

        # same one-pass lookup for the stream geometries
        stream_by_id = {}
        with SearchCursor(streams_clone, ["SHAPE@", "StreamID"]) as sc:
            for stream_geom, stream_id in sc:
                stream_by_id.setdefault(stream_id, stream_geom)
        with UpdateCursor(beacons, ["SHAPE@", "StreamID", "side"]) as uc:
            for beacon in uc:
                stream_geom = stream_by_id.get(beacon[1])
                if stream_geom is not None:
                    # item 3 of queryPointAndDistance is the right-side flag
                    on_right = stream_geom.queryPointAndDistance(beacon[0])[3]
                    beacon[2] = "r" if on_right else "l"
                uc.updateRow(beacon)
        arcpy.JoinField_management(rails, "railID", beacons, "railID", ["side"])

        # create points to snap to left and right rails and snap;
        # seq orders the three vertices of each XS: -1 left, 0 centre, 1 right
        if arcpy.ListFields(xs_cps, "seq"):
            arcpy.DeleteField_management(xs_cps, "seq")
        arcpy.AddField_management(xs_cps, "seq", "short")
        arcpy.CalculateField_management(xs_cps, "seq", 0)
        lefts = arcpy.CopyFeatures_management(xs_cps, "lefts")
        arcpy.CalculateField_management(lefts, "seq", -1)
        rights = arcpy.CopyFeatures_management(xs_cps, "rights")
        arcpy.CalculateField_management(rights, "seq", 1)

        left_rails = arcpy.Select_analysis(rails, "left_rails", "side = 'l'")
        right_rails = arcpy.Select_analysis(rails, "right_rails", "side = 'r'")

        snap_lefts = spatial_utility.snap_to_line(lefts, "StreamID", left_rails, "StreamID", arcpy.env.workspace, "snap_lefts")
        snap_rights = spatial_utility.snap_to_line(rights, "StreamID", right_rails, "StreamID", arcpy.env.workspace, "snap_rights")

        # Disabled step, kept for reference (flowacc is not defined here):
        # # snap pour point using dist of 300ft and convert raster to points -> anodes
        # pp = arcpy.Merge_management([snap_lefts, snap_rights], "pp")
        # spp = SnapPourPoint(pp, flowacc, "300 Feet")
        # anodes = arcpy.RasterToPoint_conversion(spp, "anodes")

        # # snap left and right points to anodes using 300ft dist
        # arcpy.Snap_edit(snap_lefts, [[anodes, "VERTEX", "300 Feet"]])
        # arcpy.Snap_edit(snap_rights, [[anodes, "VERTEX", "300 Feet"]])

        # points to lines -> XS
        xs_verts = arcpy.Merge_management([xs_cps, snap_lefts, snap_rights], "xs_verts")
        xs_lines = arcpy.PointsToLine_management(xs_verts, "xs_lines", "xsID", "seq")
        XS_j = arcpy.SpatialJoin_analysis(xs_lines, xs_cps, "XS_j", "JOIN_ONE_TO_ONE", "KEEP_ALL", "", "INTERSECT")
        arcpy.AddMessage("initial XS generated. reviewing and cleaning up...")

        # snip ends: erase a buffer of distance 1 around both tips of every XS
        snips = arcpy.Buffer_analysis(arcpy.FeatureVerticesToPoints_management(XS_j, "tips", "BOTH_ENDS"), "snips", 1, "", "", "ALL")
        XS_snipped = arcpy.Erase_analysis(XS_j, snips, "XS_snipped")

        if review_clean:
            out_XS_path = feature_dev_utility.clean_xs(streams_clone, XS_snipped, out_directory)
        else:
            out_XSs = arcpy.CopyFeatures_management(XS_snipped, os.path.join(out_directory, "xs_initial"))
            out_XS_path = arcpy.Describe(out_XSs).catalogPath

        arcpy.AddMessage("XS generation complete")
        return out_XS_path


    @staticmethod
    def clean_xs(streams: os.PathLike, XS: os.PathLike, out_directory: os.PathLike = r"C:\temp\xsautomation") -> os.PathLike:
        """Remove problem cross sections and write the survivors to ``xs_cleaned``.

        Two filters are applied to a working copy of the XSs that intersect
        the streams:

        1. Only XSs with exactly one intersection point with the streams are
           kept (skipped with a warning if that selection fails).
        2. Walking the XSs ordered by StreamID (and by Station, descending,
           when that field exists), any XS that crosses or overlaps the most
           recently kept XS is deleted.

        Args:
            streams: Stream line features; copied to a uniquely named dataset
                in the current workspace.
            XS: Cross-section line features to clean. Must carry a
                ``StreamID`` field.
            out_directory: Location the ``xs_cleaned`` dataset is written to.

        Returns:
            ``os.path.join(out_directory, "xs_cleaned")``.
        """
        arcpy.AddMessage("cleaning XSs...")

        # unique name: the caller may pass a dataset that is itself called
        # "streams_clone" in this workspace
        streams_clone = arcpy.CreateUniqueName("streams_clone", arcpy.env.workspace)
        streams_clone = arcpy.CopyFeatures_management(streams, streams_clone)
        xs_clone = arcpy.CopyFeatures_management(arcpy.SelectLayerByLocation_management(XS, "INTERSECT", streams_clone, "", "NEW_SELECTION"), "xs_clone")

        # check for multiple stream ints: count intersection points per XS
        spatial_utility.add_unique_ID(xs_clone, "xsID")
        ints_multi = arcpy.Intersect_analysis([xs_clone, streams_clone], "ints_multi", "", "", "POINT")
        ints = arcpy.MultipartToSinglepart_management(ints_multi, "ints")
        if arcpy.ListFields(ints, "int_cnt"):
            arcpy.DeleteField_management(ints, "int_cnt")
        arcpy.AddField_management(ints, "int_cnt", "short")
        arcpy.CalculateField_management(ints, "int_cnt", 1)
        ints_diss = arcpy.Dissolve_management(ints, "ints_diss", ["xsID"], [["int_cnt", "SUM"]])
        arcpy.JoinField_management(xs_clone, "xsID", ints_diss, "xsID", "SUM_int_cnt")
        try:
            XS_cleaned = arcpy.Select_analysis(xs_clone, "XS_cleaned", "SUM_int_cnt = 1")
        except arcpy.ExecuteError:
            # best effort: keep going with the unfiltered copy, but say so
            arcpy.AddWarning("could not select XSs with a single stream intersection; continuing with all XSs\n" + arcpy.GetMessages(2))
            XS_cleaned = xs_clone

        # check xs ints per stream and delete XSs that intersect another.
        # Station is only present when the caller added stationing, so it is
        # included in the ordering only if the field exists.
        order_terms = ["StreamID ASC"]
        if arcpy.ListFields(XS_cleaned, "Station"):
            order_terms.append("Station DESC")
        else:
            arcpy.AddWarning("no 'Station' field found; XSs are compared in StreamID order only")
        sql_clause = (None, "ORDER BY " + ", ".join(order_terms))

        # prev_xs is the last XS that was kept; it is not reset when the
        # StreamID changes, so the first XS of a stream is still compared
        # against the last kept XS of the previous stream.
        prev_xs = None
        with UpdateCursor(XS_cleaned, ["SHAPE@", "StreamID"], sql_clause=sql_clause) as uc:
            for xs in uc:
                shape = xs[0]
                if prev_xs is not None and (shape.crosses(prev_xs) or shape.overlaps(prev_xs)):
                    uc.deleteRow()
                else:
                    prev_xs = shape

        out_cleaned_xs_path = os.path.join(out_directory, "xs_cleaned")
        arcpy.CopyFeatures_management(XS_cleaned, out_cleaned_xs_path)
        arcpy.AddMessage("XS cleaning complete")
        return out_cleaned_xs_path


    @staticmethod
    def gen_xs_alt(streams: os.PathLike, xs_width_field: str, xs_spacing_distance: int = 500, out_directory: os.PathLike = r"C:\temp\xsautomation", review_clean: bool = True) -> os.PathLike:
        """Generate cross-section (XS) lines by stepping outward through stream buffers.

        Points are generated along a flipped copy of the streams. For each
        stream, buffers of 100, 200, ... up to 50000 are built; the buffer
        outline, with the areas around both stream ends removed, is split
        into a left and a right line. For every XS point the previous left
        and right vertex is snapped to the new outline and moved back halfway
        towards the previous vertex, and the result is stored as the next
        vertex. Stepping stops after the first buffer distance that exceeds
        half of the stream's ``xs_width_field`` value. The vertices are then
        joined into one line per xsID and stationed along the streams.

        A copy of the lines is always written to ``xs_initial`` in
        ``out_directory``.

        Args:
            streams: Stream line features; must contain ``xs_width_field``.
            xs_width_field: Field on the streams holding the XS width. Must
                not be named "buffer".
            xs_spacing_distance: Distance passed to GeneratePointsAlongLines.
            out_directory: Location the result datasets are written to.
            review_clean: When True the XSs are passed through ``clean_xs``.

        Returns:
            The path returned by ``clean_xs`` when ``review_clean`` is True,
            otherwise the catalog path of the ``xs_initial`` copy.

        Raises:
            SystemExit: If ``xs_width_field`` is "buffer" (exit status 1).
        """
        arcpy.AddMessage("generating XS...")

        # check data
        if xs_width_field == "buffer":
            arcpy.AddError("xs width field cannot be named 'buffer'")
            # non-zero status so callers/shells can tell the run failed
            exit(1)

        # add unique stream id
        streams_clone = arcpy.CopyFeatures_management(streams, "streams_clone")
        spatial_utility.add_unique_ID(streams_clone, "StreamID")

        # create flipped streams layer
        flipped_streams = arcpy.CopyFeatures_management(streams_clone, "flipped_streams")
        arcpy.FlipLine_edit(flipped_streams)

        # generate points along flipped streams
        xs_cps = arcpy.GeneratePointsAlongLines_management(flipped_streams, "xs_cps", "DISTANCE", xs_spacing_distance)

        # add unique point id
        spatial_utility.add_unique_ID(xs_cps, "xsID")

        # {StreamID: [[last left vertex, last right vertex, xsID], ...]}
        # both sides start at the point on the stream and are overwritten with
        # the newest vertex on every buffer step
        xs_pnts_dict = {}
        with SearchCursor(xs_cps, ["SHAPE@", "StreamID", "xsID"]) as sc:
            for pnt in sc:
                xs_pnts_dict.setdefault(pnt[1], []).append([pnt[0], pnt[0], pnt[2]])

        # vert_order counts down on the left and up on the right of 1000, so
        # sorting on it runs from the outermost left to the outermost right vertex
        left_id = 1000
        right_id = 1000
        buff_dists = range(100, 50001, 100)
        xs_verts = arcpy.CreateFeatureclass_management(ws_path, "xs_verts", "POINT", spatial_reference=arcpy.Describe(streams_clone).spatialReference)
        arcpy.AddField_management(xs_verts, "StreamID", "LONG")
        arcpy.AddField_management(xs_verts, "xsID", "LONG")
        arcpy.AddField_management(xs_verts, "vert_order", "LONG")
        stream_count = arcpy.GetCount_management(streams_clone)[0]
        with SearchCursor(streams_clone, ["SHAPE@", "StreamID", xs_width_field]) as sc, InsertCursor(xs_verts, ["SHAPE@", "StreamID", "xsID", "vert_order"]) as ic:
            for stream_counter, (stream_geom, stream_id, xs_width) in enumerate(sc, start=1):
                arcpy.AddMessage(f"building XSs for stream {stream_counter} of {stream_count}...")

                xs_pnts_data = xs_pnts_dict.get(stream_id)
                if not xs_pnts_data:
                    # no points were generated along this stream
                    arcpy.AddWarning(f"no XS points found for StreamID {stream_id}; skipping stream")
                    continue
                if xs_width is None:
                    arcpy.AddWarning(f"'{xs_width_field}' is null for StreamID {stream_id}; skipping stream")
                    continue

                # the stream end points do not change between buffer steps
                start_pnt = PointGeometry(stream_geom.firstPoint)
                end_pnt = PointGeometry(stream_geom.lastPoint)

                for buff_dist in buff_dists:
                    arcpy.AddMessage(f"...using buffer distance {buff_dist}...")
                    left_id -= 1
                    right_id += 1

                    # buffer outline with the parts around both stream ends cut
                    # away (end buffers are 1 larger than the line buffer)
                    line_buff_bound = stream_geom.buffer(buff_dist).boundary()
                    start_buff = start_pnt.buffer(buff_dist + 1)
                    end_buff = end_pnt.buffer(buff_dist + 1)
                    buff_bound_erased = line_buff_bound.difference(start_buff).difference(end_buff)
                    if buff_bound_erased.partCount < 2:
                        # a left and a right line are both required
                        arcpy.AddWarning(f"buffer outline at distance {buff_dist} does not split into two sides for StreamID {stream_id}; stopping this stream")
                        break

                    first_part = Polyline(buff_bound_erased[0])
                    second_part = Polyline(buff_bound_erased[1])
                    # item 3 of queryPointAndDistance is the right-side flag
                    if stream_geom.queryPointAndDistance(first_part.centroid)[3]:
                        right_line, left_line = first_part, second_part
                    else:
                        left_line, right_line = first_part, second_part

                    for row in xs_pnts_data:
                        # left vertex first, then right, for every XS point
                        for side, side_line, vert_order in ((0, left_line, left_id), (1, right_line, right_id)):
                            snapped = side_line.snapToLine(row[side])
                            angle, dist = snapped.angleAndDistanceTo(row[side], "PLANAR")
                            # pull the snapped point halfway back towards the previous vertex
                            vert_geom = snapped.pointFromAngleAndDistance(angle, dist*0.5, "PLANAR")
                            row[side] = vert_geom
                            ic.insertRow([vert_geom, stream_id, row[2], vert_order])

                    # checked after the step, so one buffer beyond half the
                    # width is still processed
                    if buff_dist > xs_width/2:
                        break

        # points to lines -> XS
        xs_lines = arcpy.PointsToLine_management(xs_verts, "xs_lines", "xsID", "vert_order")

        # add stationings
        if arcpy.ListFields(xs_lines, "Station"):
            arcpy.DeleteField_management(xs_lines, "Station")
        xs_pnts = arcpy.MultipartToSinglepart_management(arcpy.Intersect_analysis([xs_lines, streams_clone], "int", "", "", "POINT"), "xs_pnts")
        spatial_utility.add_stationings(xs_pnts, "StreamID", streams_clone, "StreamID")
        arcpy.JoinField_management(xs_lines, "xsID", xs_pnts, "xsID", ["Station", "StreamID"])
        arcpy.CalculateField_management(xs_lines, "Station", "round(!Station!, 2)")

        out_XSs = arcpy.CopyFeatures_management(xs_lines, os.path.join(out_directory, "xs_initial"))

        if review_clean:
            out_XS_path = feature_dev_utility.clean_xs(streams_clone, xs_lines, out_directory)
        else:
            out_XS_path = arcpy.Describe(out_XSs).catalogPath

        arcpy.AddMessage("XS generation complete")
        return out_XS_path


# Run the hard-coded XS generation only when this file is executed as a script.
if __name__ == "__main__":
    main()