栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

禁忌搜索算法求解 TSP 问题的代码示例

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

禁忌搜索算法求解 TSP 问题的代码示例

之前笔者曾经简要介绍过模拟退火算法求解 TSP 问题的代码示例:
模拟退火算法求解 TSP 问题的代码示例

本文以 TSP 问题为例,通过具体代码,说明禁忌搜索算法的迭代过程。


禁忌搜索算法流程图



TSP 问题的代码示例

关于 TSP 问题的介绍从略,算法模块的代码(algorithm.py)如下,注释中已说明算法的迭代过程:

from datetime import datetime
from typing import Tuple, List, Set
import random

import pandas as pd


class TabuSearch(object):
    """
    禁忌搜索
    """
    def __init__(self, num_point: int, mat_dist: List[List[float]],
                 num_iter: int = 100, len_tabu: int = 10, size_neighbour: int = 50):
        """
        禁忌搜索,数据初始化

        :param num_point:  TSP 节点数量
        :param mat_dist:  距离矩阵
        :param num_iter:  迭代次数
        :param len_tabu:  禁忌表长度
        :param size_neighbour:  邻域搜索的次数
        """

        # 问题参数
        self.num_point = num_point
        self.mat_dist = mat_dist

        # 算法参数
        self.num_iter = num_iter
        self.len_tabu = len_tabu
        self.size_neighbour = size_neighbour

        # 结果
        self.route_opt, self.distance_opt = [], None  # 最优路径、最优路径距离
        self.route_res, self.distance_res = [], None  # 结果路径、结果路径距离

    def run(self):
        """
        算法运行
        :return: 无
        """

        dts = datetime.now()
        random.seed(1024)

        # 初始解
        route = self._init_solution()
        obj = self._get_distance(route=route)
        self.route_opt, self.distance_opt = route, obj
        print("初始解: {}".format(self.route_opt))
        print("初始解的路径距离: {}".format(self.distance_opt), 'n')

        # 禁忌表
        list_tabu = []

        # loop 1: 迭代次数 NG
        for i in range(self.num_iter):
            print("当前迭代次数: {}".format(i + 1), 'n')

            # loop 2: 邻域搜索 S
            df_neighbour = pd.Dataframe(columns=["route", "change", "distance"])
            for _ in range(self.size_neighbour):
                route_, set_change = self._create_new_solution(route=route)
                distance = self._get_distance(route=route_)
                tmp_df = pd.Dataframe({"route": [route_], "change": [set_change], "distance": [distance]})
                df_neighbour = df_neighbour.append(tmp_df)
            df_neighbour = df_neighbour.sort_values(by="distance", ascending=True)
            df_neighbour.reset_index(drop=True, inplace=True)

            # 更新解
            for _, df in df_neighbour.iterrows():
                # case 1: 当前邻域最优解可优化全局最优解
                if df["distance"] < self.distance_opt:
                    route = df["route"]
                    self.route_opt, self.distance_opt = route, df["distance"]
                    print("发现可优化全局最优解的新解: {}".format(self.route_opt))
                    print("变化节点: {0},  路径距离: {1}".format(df["change"], self.distance_opt), 'n')
                    list_tabu.append(df["change"])  # 更新禁忌表
                    break

                # case 2: 当前邻域最优解被禁忌
                elif df["change"] in list_tabu:
                    print("当前邻域最优解被禁忌,节点变化: {}".format(df["change"]), 'n')
                    continue

                # case 3: 接收劣解
                else:
                    route = df["route"]
                    print("接收劣解: {}".format(route))
                    print("变化节点: {0},  路径距离: {1}".format(df["change"], df["distance"]), 'n')
                    list_tabu.append(df["change"])  # 更新禁忌表
                    break

            # 禁忌表长度
            if len(list_tabu) > self.len_tabu:
                print("禁忌表长度过大,释放元素: {}".format(list_tabu[: len(list_tabu) - self.len_tabu]), 'n')
                list_tabu = list_tabu[-self.len_tabu:]

        # 运行结果
        self.route_res = route.copy()
        self.distance_res = obj
        print("结果路径: {}".format(self.route_res))
        print("结果路径距离: {}".format(self.distance_res), 'n')

        print("最优路径: {}".format(self.route_opt))
        print("最优路径距离: {}".format(self.distance_opt), 'n')

        dte = datetime.now()
        tm = round((dte - dts).seconds + (dte - dts).microseconds / (10 ** 6), 3)
        print("算法运行时间: {} s".format(tm), 'n')

    def _init_solution(self):
        """
        初始解

        :return: route:  初始路径
        """

        route = [i for i in range(self.num_point)]

        return route

    def _get_distance(self, route: List[int]) -> float:
        """
        计算路径距离
        :param route:  路径
        :return:  距离
        """

        distance = sum(self.mat_dist[route[i]][route[i + 1]] for i in range(len(route) - 1))

        return distance

    def _create_new_solution(self, route: List[int]) -> Tuple[List[int], Set]:
        """
        产生一个新解

        :param route:  当前解

        :return: route_:  生成的新解
        :return: set_change:  交换位置的节点
        """

        route_ = route.copy()

        # 通过随机交换两个位置的方式产生新解
        pos1, pos2 = random.randint(0, self.num_point - 1), random.randint(0, self.num_point - 1)
        while pos1 == pos2:
            pos1, pos2 = random.randint(0, self.num_point - 1), random.randint(0, self.num_point - 1)
        tmp, route_[pos1] = route_[pos1], route_[pos2]
        route_[pos2] = tmp

        # 交换位置的节点
        set_change = {route_[pos1], route_[pos2]}

        return route_, set_change

生成随机算例,并调用算法模块进行求解的主程序代码(main.py)如下:

from datetime import datetime
import math
import random

from algorithm import TabuSearch


dts = datetime.now()


""" 参数 """

# 地点数量
num_point = 20

# 坐标范围、边界宽度
ran_coo = (0, 100)
edge = 1

# 坐标列表、距离矩阵
random.seed(1024)
list_coo = [(random.randint(ran_coo[0] + edge, ran_coo[1] - edge),
             random.randint(ran_coo[0] + edge, ran_coo[1] - edge)) for _ in range(num_point)]
mat_dist = [[math.sqrt((list_coo[i][0] - list_coo[j][0]) ** 2 + (list_coo[i][1] - list_coo[j][1]) ** 2)
             for j in range(num_point)] for i in range(num_point)]

""" 算法 """

num_iter, len_tabu, size_neighbour = 1000, 10, 50
tabu_search = TabuSearch(num_point=num_point, mat_dist=mat_dist,
                         num_iter=num_iter, len_tabu=len_tabu, size_neighbour=size_neighbour)
tabu_search.run()


dte = datetime.now()
tm = round((dte - dts).seconds + (dte - dts).microseconds / (10 ** 6), 3)
print("程序运行总时间: {} s".format(tm), 'n')

参考资料

汪定伟《智能优化方法》

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/314171.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号