import arcpy
import os
import json
import time

# Geoprocessing environment shared by every tool call in this module.
# Intermediate outputs are written by fixed names ("picked", "sptl_jn", ...),
# so they must be allowed to overwrite previous runs.
arcpy.env.overwriteOutput = True
# Keep intermediate geoprocessing results out of the ArcGIS Pro map
# (addOutputsToMap is a Boolean property).
arcpy.env.addOutputsToMap = False
# Relative output names resolve into the scratch geodatabase.
arcpy.env.workspace = arcpy.env.scratchGDB

def test(fc=r"C:\Users\jacob.bates\Desktop\ble_hydra\_test_data\upper_AL\StreamNetwork\Scope.shp"):
    """Benchmark the two coordinate-extraction strategies on ``fc``.

    Prints the elapsed wall time in seconds for
    ``spatial_utility.get_feature_coordinates`` first and
    ``spatial_utility.get_feature_coordinates_alt`` second.

    Args:
        fc: Feature class to read; defaults to the original test dataset.
    """
    # Describe once, outside the timed regions, so only extraction is measured.
    oid_field = arcpy.Describe(fc).OIDFieldName
    extractors = (
        spatial_utility.get_feature_coordinates,
        spatial_utility.get_feature_coordinates_alt,
    )
    for extract in extractors:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        extract(fc, oid_field)
        print(time.perf_counter() - start)

class spatial_utility(object):
    """Static arcpy geoprocessing helpers.

    Intermediate datasets are written by plain name (e.g. "picked",
    "sptl_jn", "point_slice") and therefore land in ``arcpy.env.workspace``;
    repeated calls rely on ``arcpy.env.overwriteOutput`` to replace them.
    """

    # ------------------------------------------------------------------
    # SQL helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _sql_literal(value) -> str:
        """Return ``value`` rendered as an SQL literal for an arcpy where clause.

        ``None`` becomes ``NULL``, booleans become ``0``/``1``, ints and floats
        are emitted unquoted, and anything else is single-quoted with embedded
        single quotes doubled.
        """
        if value is None:
            return "NULL"
        if isinstance(value, bool):
            return str(int(value))
        if isinstance(value, (int, float)):
            return repr(value)
        escaped = str(value).replace("'", "''")
        return f"'{escaped}'"

    @staticmethod
    def _sql_value_list(values) -> str:
        """Return ``values`` as a comma-separated list of SQL literals."""
        return ", ".join(spatial_utility._sql_literal(v) for v in values)

    @staticmethod
    def _sql_equals(datasource, field, value) -> str:
        """Return a where clause matching rows of ``datasource`` whose ``field`` equals ``value``."""
        field_sql = arcpy.AddFieldDelimiters(datasource, field)
        if value is None:
            # "= NULL" never matches in SQL; NULL needs IS NULL.
            return f"{field_sql} IS NULL"
        return f"{field_sql} = {spatial_utility._sql_literal(value)}"

    # ------------------------------------------------------------------
    # Attribute / feature management
    # ------------------------------------------------------------------
    @staticmethod
    def transfer_field_by_spatial_association(src_lyr: str = None, dest_lyr: str = None, field_name: str = None):
        """Copy ``field_name`` from ``src_lyr`` onto intersecting features of ``dest_lyr``.

        Any existing ``field_name`` on ``dest_lyr`` is dropped first. Features
        are matched with a one-to-one INTERSECT spatial join (1 ft search
        radius); when several source features intersect, the field map's
        "first" merge rule picks the value.

        Raises:
            ValueError: ``field_name`` does not exist in ``src_lyr``. This is
                checked before ``dest_lyr`` is modified.
        """
        if not arcpy.ListFields(src_lyr, field_name):
            raise ValueError(f"Field '{field_name}' not found in {src_lyr}")

        if arcpy.ListFields(dest_lyr, field_name):
            arcpy.DeleteField_management(dest_lyr, field_name)

        # A stable copy of the OID lets the join result be matched back to dest_lyr.
        spatial_utility.add_unique_ID(dest_lyr, "tmp_ID")
        try:
            fms = arcpy.FieldMappings()
            fms.addTable(dest_lyr)
            fms.addTable(src_lyr)
            field_index = fms.findFieldMapIndex(field_name)
            fieldmap = fms.getFieldMap(field_index)
            fieldmap.mergeRule = "first"
            fms.replaceFieldMap(field_index, fieldmap)
            sptl_jn = arcpy.SpatialJoin_analysis(
                dest_lyr, src_lyr, "sptl_jn", "JOIN_ONE_TO_ONE", "KEEP_ALL", fms, "INTERSECT", "1 Feet"
            )
            arcpy.JoinField_management(dest_lyr, "tmp_ID", sptl_jn, "tmp_ID", field_name)
            arcpy.Delete_management(sptl_jn)
        finally:
            # Never leave the helper key behind on the caller's data, even on failure.
            arcpy.DeleteField_management(dest_lyr, "tmp_ID")

    @staticmethod
    def insert_features(new_features: str = None, existing_features: str = None, query_field: str = None) -> None:
        """Replace features in ``existing_features`` with those from ``new_features``.

        Existing rows whose ``query_field`` value occurs in ``new_features``
        are dropped, every row of ``new_features`` is appended (schema not
        tested), and the result overwrites ``existing_features``. The dataset
        is removed from and re-added to the active ArcGIS Pro map on a
        best-effort basis.
        """
        spatial_utility.remove_feature_from_map(existing_features)
        try:
            # Distinct keys (dict preserves first-seen order, O(1) membership).
            # NULL keys are left out: SQL "NOT IN (..., NULL)" would select nothing.
            keys = {}
            with arcpy.da.SearchCursor(new_features, [query_field]) as sc:
                for (key,) in sc:
                    if key is not None:
                        keys[key] = None

            if keys:
                field_sql = arcpy.AddFieldDelimiters(existing_features, query_field)
                query = f"{field_sql} NOT IN ({spatial_utility._sql_value_list(keys)})"
                picked = arcpy.Select_analysis(existing_features, "picked", query)
            else:
                # Nothing to replace: keep every existing feature. This used to be
                # reached via a bare except, which also hid real query errors and
                # silently duplicated rows.
                picked = arcpy.CopyFeatures_management(existing_features, "picked")

            arcpy.Append_management([new_features], picked, "NO_TEST")
            arcpy.Delete_management(existing_features)
            arcpy.CopyFeatures_management(picked, existing_features)
        finally:
            spatial_utility.add_feature_to_map(existing_features)

    @staticmethod
    def remove_feature_from_map(layer: str = None) -> None:
        """Best-effort: remove every active-map layer whose data source is ``layer``.

        Does nothing when not running inside ArcGIS Pro, when there is no
        active map, or when ``layer`` cannot be described.
        """
        try:
            target = os.path.normcase(os.path.normpath(arcpy.Describe(layer).catalogPath))
            active_map = arcpy.mp.ArcGISProject("CURRENT").activeMap
            if active_map is None:
                return
            # Map.removeLayer takes a Layer object, not a name, so find the
            # matching layers first.
            for lyr in active_map.listLayers():
                if lyr.supports("DATASOURCE") and os.path.normcase(os.path.normpath(lyr.dataSource)) == target:
                    active_map.removeLayer(lyr)
        except Exception:
            # Map housekeeping is a UI convenience only; never fail the
            # geoprocessing over it (e.g. "CURRENT" is unavailable outside Pro).
            pass

    @staticmethod
    def add_feature_to_map(layer: str = None) -> None:
        """Best-effort: add the dataset behind ``layer`` to the active ArcGIS Pro map.

        Does nothing when not running inside ArcGIS Pro, when there is no
        active map, or when ``layer`` cannot be described.
        """
        try:
            layer_path = arcpy.Describe(layer).catalogPath
            active_map = arcpy.mp.ArcGISProject("CURRENT").activeMap
            if active_map is not None:
                active_map.addDataFromPath(layer_path)
        except Exception:
            # Map housekeeping is a UI convenience only; see remove_feature_from_map.
            pass

    @staticmethod
    def add_lengths(features: str = None, field_name: str = None, length_unit: str = "FEET_US") -> None:
        """(Re)create DOUBLE field ``field_name`` holding each feature's length.

        Args:
            length_unit: CalculateGeometryAttributes length unit keyword;
                defaults to "FEET_US".
        """
        if arcpy.ListFields(features, field_name):
            arcpy.DeleteField_management(features, field_name)
        arcpy.AddField_management(features, field_name, "DOUBLE")
        arcpy.CalculateGeometryAttributes_management(features, [[field_name, "LENGTH"]], length_unit)

    @staticmethod
    def print_fields(features: str = None) -> None:
        """Print the list of field names of ``features``."""
        print(str([f.name for f in arcpy.ListFields(features)]))

    # ------------------------------------------------------------------
    # Coordinate extraction
    # ------------------------------------------------------------------
    @staticmethod
    def get_feature_coordinates(feature_class, id_field, rounding=3):
        """Return ``{id: [(x1, y1), (x2, y2), ...]}`` for the first part of each feature.

        Coordinates are rounded to ``rounding`` decimals. Features with a null
        or empty geometry map to ``[]``. If IDs repeat, the last feature wins.
        """
        coordinate_dict = {}
        with arcpy.da.SearchCursor(feature_class, [id_field, "SHAPE@"]) as sc:
            for feature_id, geometry in sc:
                coords = []
                if geometry is not None and geometry.partCount > 0:
                    for vertex in geometry.getPart(0):
                        # Polygon parts contain None between exterior and interior rings.
                        if vertex is None:
                            continue
                        # NOTE(review): abs() drops the sign of every coordinate
                        # (e.g. negative elevations/longitudes); confirm intended.
                        coords.append((abs(round(vertex.X, rounding)), abs(round(vertex.Y, rounding))))
                coordinate_dict[feature_id] = coords
        return coordinate_dict

    @staticmethod
    def get_feature_coordinates_alt(feature_class, id_field):
        """GeoJSON-export variant of ``get_feature_coordinates`` (measured slower).

        Returns ``{id: geometry["coordinates"]}`` as parsed from GeoJSON written
        to the scratch folder. Coordinates are neither rounded nor made
        absolute, and multipart structure is kept. Null geometries map to ``[]``.
        """
        f_path = os.path.join(arcpy.env.scratchFolder, "features.geojson")
        if arcpy.Exists(f_path):
            arcpy.Delete_management(f_path)

        result = arcpy.FeaturesToJSON_conversion(feature_class, f_path, "", "", "", "GEOJSON")
        gj_file = arcpy.Describe(result).catalogPath
        # GeoJSON is UTF-8 (RFC 7946); don't depend on the platform default codec.
        with open(gj_file, "r", encoding="utf-8") as reader:
            json_obj = json.load(reader)

        coordinate_dict = {}
        for feat in json_obj["features"]:
            geometry = feat["geometry"]
            coordinate_dict[feat["properties"][id_field]] = geometry["coordinates"] if geometry else []
        return coordinate_dict

    @staticmethod
    def filter_verts(feature_class, max_vertex_count=450, tolerance_step=0.1, max_attempts=10):
        """Generalize geometries in place until they have at most ``max_vertex_count`` points.

        Each attempt generalizes the previous result with a tolerance that grows
        by ``tolerance_step`` (in the data's units). Stops after
        ``max_attempts`` attempts, so a geometry may still exceed the limit.
        Null geometries and rows already under the limit are left untouched.
        """
        with arcpy.da.UpdateCursor(feature_class, ["SHAPE@"]) as uc:
            for row in uc:
                geometry = row[0]
                if geometry is None:
                    continue
                tolerance = tolerance_step
                attempts = 0
                while geometry.pointCount > max_vertex_count and attempts < max_attempts:
                    geometry = geometry.generalize(tolerance)
                    tolerance += tolerance_step
                    attempts += 1
                if attempts:
                    row[0] = geometry
                    uc.updateRow(row)

    @staticmethod
    def get_terrain_profiles(line_feature_class, id_field, dem, ws_path=None, rounding=3):
        """Sample ``dem`` along each line; return ``{id: [(station, elevation), ...]}``.

        Lines are copied, densified (3 data units), and profiled with 3D Analyst
        StackProfile. Each profile is then stored as a polyline in
        ``<ws_path>/profiles`` (x = station, y = elevation), generalized by
        0.1 ft, and vertex-limited with ``filter_verts`` before being read back
        via ``get_feature_coordinates`` (which rounds and applies abs()).

        Args:
            ws_path: Workspace for the intermediates; ``arcpy.env.scratchGDB``
                when falsy (resolved at call time).
        """
        if not ws_path:
            ws_path = arcpy.env.scratchGDB

        line_clone = arcpy.CopyFeatures_management(line_feature_class, "line_feature_class_clone")

        # StackProfile labels each sample with the OID of its line (LINE_ID);
        # map those OIDs back to the caller's id_field values.
        id_lookup = {}
        with arcpy.da.SearchCursor(line_clone, ["OID@", id_field]) as lines:
            for oid, line_id in lines:
                id_lookup[oid] = line_id

        # Text-like IDs need a TEXT key field (inserting them into DOUBLE fails);
        # everything else keeps the DOUBLE storage used previously.
        src_fields = arcpy.ListFields(line_feature_class, id_field)
        if src_fields and src_fields[0].type in ("String", "Guid", "GlobalID"):
            key_type, key_length = "TEXT", max(src_fields[0].length, 255)
        else:
            key_type, key_length = "DOUBLE", None

        profile_table = arcpy.CreateUniqueName("profile_table", ws_path)
        arcpy.Densify_edit(line_clone, "", 3)
        arcpy.ddd.StackProfile(line_clone, dem, profile_table)

        samples = {}
        with arcpy.da.SearchCursor(profile_table, ["LINE_ID", "FIRST_DIST", "FIRST_Z"]) as sc:
            for line_oid, station, elevation in sc:
                # Skip samples without an elevation instead of building a vertex
                # with a None y (presumably NoData where the line leaves the DEM).
                if elevation is None:
                    continue
                samples.setdefault(id_lookup[line_oid], []).append((station, elevation))

        spatial_ref = arcpy.Describe(line_clone).spatialReference
        arcpy.management.CreateFeatureclass(ws_path, "profiles", "POLYLINE", "", "", "", spatial_ref)
        profiles = os.path.join(ws_path, "profiles")
        if arcpy.ListFields(profiles, id_field):
            arcpy.DeleteField_management(profiles, id_field)
        arcpy.management.AddField(profiles, id_field, key_type, field_length=key_length)

        with arcpy.da.InsertCursor(profiles, [id_field, "SHAPE@"]) as ic:
            for key, pairs in samples.items():
                profile_line = arcpy.Polyline(arcpy.Array([arcpy.Point(x, y) for x, y in pairs]))
                ic.insertRow([key, profile_line])

        arcpy.Generalize_edit(profiles, "0.1 Feet")
        spatial_utility.filter_verts(profiles)

        return spatial_utility.get_feature_coordinates(profiles, id_field, rounding)

    # ------------------------------------------------------------------
    # Point / line relationships
    # ------------------------------------------------------------------
    @staticmethod
    def snap_to_line(
        point_feature_class, point_id_field, line_feature_class, line_id_field, ws_path=None, output_points_name="snapped_points", snap_distance=100000
    ):
        """Snap each point to the edge of the line(s) that share its ID.

        Creates ``<ws_path>/<output_points_name>`` from the point schema,
        spatial reference and M/Z settings, then appends the snapped points one
        line ID at a time. Points whose ID matches no line are not copied.

        Args:
            ws_path: Output workspace; ``arcpy.env.scratchGDB`` when falsy.
            snap_distance: Snap search distance passed to Snap (data units
                unless given as a string with units).

        Returns:
            Path to the output point feature class.
        """
        if not ws_path:
            ws_path = arcpy.env.scratchGDB

        # CreateFeatureclass takes only the attribute schema from a template,
        # so pass its coordinate system and M/Z flags explicitly.
        arcpy.CreateFeatureclass_management(
            ws_path,
            output_points_name,
            "POINT",
            point_feature_class,
            "SAME_AS_TEMPLATE",
            "SAME_AS_TEMPLATE",
            arcpy.Describe(point_feature_class).spatialReference,
        )
        snapped_points = os.path.join(ws_path, output_points_name)

        point_slice = "point_slice"
        line_slice = "line_slice"
        # Distinct line IDs in first-seen order.
        line_ids = {}
        with arcpy.da.SearchCursor(line_feature_class, [line_id_field]) as sc:
            for (line_id,) in sc:
                line_ids[line_id] = None

        for line_id in line_ids:
            arcpy.Select_analysis(
                line_feature_class,
                line_slice,
                spatial_utility._sql_equals(line_feature_class, line_id_field, line_id),
            )
            arcpy.Select_analysis(
                point_feature_class,
                point_slice,
                spatial_utility._sql_equals(point_feature_class, point_id_field, line_id),
            )
            arcpy.Snap_edit(point_slice, [[line_slice, "EDGE", snap_distance]])
            arcpy.Append_management(point_slice, snapped_points)

        return snapped_points

    @staticmethod
    def add_stationings(
        point_feature_class, point_id_field, line_feature_class, line_id_field, station_field_name="Station", measure_from_start_of_line: bool = False
    ):
        """Write each point's distance along its matching line into ``station_field_name``.

        ``station_field_name`` is (re)created as DOUBLE. A point matches the
        line whose ``line_id_field`` equals (Python ``==``) its
        ``point_id_field``; if several lines share an ID, the last one read
        wins. Stations are measured from the line's end unless
        ``measure_from_start_of_line`` is True. Points without a matching line
        or geometry keep a NULL station.
        """
        if arcpy.ListFields(point_feature_class, station_field_name):
            arcpy.DeleteField_management(point_feature_class, station_field_name)
        arcpy.AddField_management(point_feature_class, station_field_name, "DOUBLE")

        lines_by_id = {}
        with arcpy.da.SearchCursor(line_feature_class, ["SHAPE@", line_id_field]) as sc:
            for geometry, line_id in sc:
                if geometry is not None:
                    lines_by_id[line_id] = geometry

        # One pass over the points with an O(1) line lookup, instead of one
        # filtered cursor per line (which also needed correctly quoted SQL).
        with arcpy.da.UpdateCursor(point_feature_class, ["SHAPE@", point_id_field, station_field_name]) as uc:
            for row in uc:
                line = lines_by_id.get(row[1])
                if line is None or row[0] is None:
                    continue
                measure = line.measureOnLine(row[0])
                row[2] = measure if measure_from_start_of_line else line.length - measure
                uc.updateRow(row)

    @staticmethod
    def add_unique_ID(in_features, ID_field_name):
        """(Re)create LONG field ``ID_field_name`` holding a copy of each row's OID."""
        if arcpy.ListFields(in_features, ID_field_name):
            arcpy.DeleteField_management(in_features, ID_field_name)
        arcpy.AddField_management(in_features, ID_field_name, "LONG")
        arcpy.CalculateField_management(
            in_features,
            ID_field_name,
            "!{}!".format(arcpy.Describe(in_features).OIDFieldName),
            "PYTHON3",
        )


if __name__ == "__main__":
    # Running the module directly executes the coordinate-extraction benchmark.
    test()