首頁 > 軟體

基於numpy實現邏輯迴歸

2022-07-30 14:01:04

本文範例為大家分享了基於numpy實現邏輯迴歸的具體程式碼,供大家參考,具體內容如下

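模型使用 sigmoid 激勵函數 $\sigma(z) = \frac{1}{1 + e^{-z}}$,並以交叉熵作為損失函數:$J(W, b) = -\frac{1}{m}\sum_{i=1}^{m}\left[y_i \log a_i + (1 - y_i)\log(1 - a_i)\right]$,其中 $a_i = \sigma(x_i W + b)$,$m$ 為訓練樣本數。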
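基於 numpy 的邏輯迴歸程式如下(注意:scikit-learn 0.24 起已移除 `sklearn.datasets.samples_generator` 模組,在新版中請改用 `from sklearn.datasets import make_classification`):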

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets.samples_generator import make_classification

class logistic_regression():
    """Binary logistic regression trained with full-batch gradient descent.

    The model computes ``a = sigmoid(X @ W + b)`` and is optimised with the
    mean cross-entropy loss.  Methods expect ``X`` as an
    ``(n_samples, n_features)`` array and labels ``y`` as an
    ``(n_samples, 1)`` column of 0/1 values (the shape ``create_data``
    returns).
    """

    def __init__(self):
        pass

    def sigmoid(self, x):
        """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

        Evaluated piecewise so ``exp`` only ever sees non-positive arguments:
        this avoids overflow warnings for large ``|x|`` while giving the same
        values as the naive formula.
        """
        x = np.asarray(x, dtype=float)
        e = np.exp(-np.abs(x))
        # x >= 0: 1 / (1 + e^{-x});  x < 0: e^{x} / (1 + e^{x}).
        return np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

    def initialize_params(self, dims):
        """Return zero weights of shape ``(dims, 1)`` and a zero bias."""
        W = np.zeros((dims, 1))
        b = 0
        return W, b

    def logistic(self, X, y, W, b):
        """Run one forward pass and compute the gradients.

        Returns:
            a: sigmoid activations, same shape as ``X @ W``.
            cost: mean cross-entropy over the batch (squeezed to a scalar).
            dW: gradient of the cost w.r.t. ``W`` (same shape as ``W``).
            db: gradient of the cost w.r.t. ``b`` (scalar).
        """
        num_train = X.shape[0]
        a = self.sigmoid(np.dot(X, W) + b)
        # Clip only inside the logs: a saturated sigmoid (a == 0 or 1) would
        # otherwise give log(0) = -inf and a nan cost.  Gradients use raw a.
        eps = 1e-15
        a_safe = np.clip(a, eps, 1 - eps)
        cost = -1 / num_train * np.sum(y * np.log(a_safe) + (1 - y) * np.log(1 - a_safe))
        dW = np.dot(X.T, (a - y)) / num_train
        db = np.sum(a - y) / num_train
        cost = np.squeeze(cost)  # collapse to a scalar so the history is easy to plot
        return a, cost, dW, db

    def logistic_train(self, X, y, learning_rate, epochs, log_every=100):
        """Fit the model with ``epochs`` steps of gradient descent.

        Every ``log_every`` epochs (starting at epoch 0) the current cost is
        appended to the returned history and printed.

        Returns:
            (cost_list, params, grads) where ``params`` holds the final
            ``'W'``/``'b'`` and ``grads`` the last computed ``'dW'``/``'db'``
            (zeros when ``epochs`` is 0).

        Raises:
            ValueError: if ``log_every`` is not positive.
        """
        if log_every <= 0:
            raise ValueError('log_every must be a positive integer')
        W, b = self.initialize_params(X.shape[1])
        cost_list = []
        # Bind the gradients up front so epochs == 0 does not leave them undefined.
        dW, db = np.zeros_like(W), 0.0
        for i in range(epochs):
            _, cost, dW, db = self.logistic(X, y, W, b)
            W = W - learning_rate * dW
            b = b - learning_rate * db
            if i % log_every == 0:
                cost_list.append(cost)
                print('epoch %d cost %f' % (i, cost))
        params = {
            'W': W,
            'b': b
        }
        grads = {
            'dW': dW,
            'db': db
        }
        return cost_list, params, grads

    def predict(self, X, params, threshold=0.5):
        """Return hard 0/1 predictions (float array) for the rows of ``X``.

        A sample is labelled 1 when its predicted probability is strictly
        greater than ``threshold``, and 0 otherwise.
        """
        probs = self.sigmoid(np.dot(X, params['W']) + params['b'])
        return (probs > threshold).astype(float)

    def accuracy(self, y_test, y_pred):
        """Return the fraction of positions where ``y_pred`` equals ``y_test``.

        Both inputs are flattened, so ``(n,)`` and ``(n, 1)`` shapes may be
        mixed freely.

        Raises:
            ValueError: if ``y_test`` is empty or the lengths differ.
        """
        y_true = np.ravel(y_test)
        y_hat = np.ravel(y_pred)
        if y_true.size == 0:
            raise ValueError('y_test must not be empty')
        if y_true.size != y_hat.size:
            raise ValueError('y_test and y_pred must have the same length')
        correct_count = np.count_nonzero(y_true == y_hat)
        return correct_count / y_true.size

    def create_data(self, n_samples=100, train_ratio=0.9, random_state=None):
        """Generate a random 2-feature binary dataset and split it.

        The first ``int(n_samples * train_ratio)`` rows become the training
        set and the rest the test set.  Labels are returned as ``(n, 1)``
        columns.  Pass ``random_state`` for a reproducible dataset.
        """
        X, labels = make_classification(n_samples=n_samples, n_features=2, n_redundant=0,
                                        n_informative=2, random_state=random_state)
        labels = labels.reshape((-1, 1))
        offset = int(X.shape[0] * train_ratio)
        # Train/test split.
        X_train, y_train = X[:offset], labels[:offset]
        X_test, y_test = X[offset:], labels[offset:]
        return X_train, y_train, X_test, y_test

    def plot_logistic(self, X_train, y_train, params):
        """Scatter a 2-feature dataset and draw the learned decision boundary.

        Points labelled 1 are red and all others are green.  The boundary
        ``W[0]*x1 + W[1]*x2 + b = 0`` spans the observed range of the first
        feature.  It becomes a vertical line when ``W[1] == 0``, and is not
        drawn when both weights are zero.
        """
        labels = np.ravel(y_train)
        pos = X_train[labels == 1]  # class 1
        neg = X_train[labels != 1]  # class 0 (everything else)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.scatter(pos[:, 0], pos[:, 1], s=32, c='red')
        ax.scatter(neg[:, 0], neg[:, 1], s=32, c='green')
        w = np.ravel(params['W'])
        b = float(np.squeeze(params['b']))
        if len(X_train):
            lo, hi = X_train[:, 0].min(), X_train[:, 0].max()
        else:
            lo, hi = -1.5, 3.0  # fall back to the original fixed range
        if w[1] != 0:
            x = np.linspace(lo, hi, 100)
            ax.plot(x, (-b - w[0] * x) / w[1])
        elif w[0] != 0:
            ax.axvline(-b / w[0])
        plt.xlabel('X1')
        plt.ylabel('X2')
        plt.show()


if __name__ == "__main__":
    # Build the model and a random 90/10 train/test split.
    model = logistic_regression()
    splits = model.create_data()
    X_train, y_train, X_test, y_test = splits
    print(*(part.shape for part in splits))
    # expected: (90, 2) (90, 1) (10, 2) (10, 1)

    # Train: learning rate 0.01, 1000 epochs of full-batch gradient descent.
    cost_list, params, grads = model.logistic_train(X_train, y_train, 0.01, 1000)
    print(params)

    # Report accuracy on the training split first, then on the held-out split.
    for caption, features, targets in (
        ('train accuracy is:', X_train, y_train),
        ('test accuracy is:', X_test, y_test),
    ):
        predicted = model.predict(features, params)
        print(caption, model.accuracy(targets, predicted))

    model.plot_logistic(X_train, y_train, params)

執行結果如下所示(程式每 100 個 epoch 輸出一次損失值,接著印出訓練集與測試集的準確率,最後畫出訓練資料與決策邊界):

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支援it145.com。


IT145.com E-mail:sddin#qq.com