神经网络初始化实例化的维度与调用输入数据的维度

神经网络初始化实例化的维度与调用输入数据的维度

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

#from agents.helpers import SinusoidalPosEmb
class SinusoidalPosEmb(nn.Module):
def init(self, dim=16): #dim为初始化需要设置的参数 比如默认为16 计算后会升维
super().init()
self.dim = dim

def forward(self, x):
    device = x.device
    half_dim = self.dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
    emb = x[:, None] * emb[None, :]
    emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
    return emb

class MLP(nn.Module):
“”"
MLP Model
“”"
def init(self, ##初始化以及参数
state_dim=2, #####初始化实例化定义的维度这个对看懂代码很关键! 与后面调用函数的输入数据的维度一般需要一致
action_dim=2,
device=“cpu”,
t_dim=16):

    super(MLP, self).__init__()
    self.device = "cpu"

    self.time_mlp = nn.Sequential(
        SinusoidalPosEmb(t_dim),  #  这个是初始化为16维度
        nn.Linear(t_dim, t_dim * 2),
        nn.Mish(),
        nn.Linear(t_dim * 2, t_dim),
    )

    input_dim = state_dim + action_dim + t_dim ## 2+2+16=20  维度数量
    self.mid_layer = nn.Sequential(nn.Linear(input_dim, 256),
                                   nn.Mish(),
                                   nn.Linear(256, 256),
                                   nn.Mish(),
                                   nn.Linear(256, 256),
                                   nn.Mish())

    self.final_layer = nn.Linear(256, action_dim)  ##输出2维度

def forward(self, x, time, state):  ##定义个方法

    t = self.time_mlp(time)
    x = torch.cat([x, t, state], dim=1)  ###第二个维度以后要一致
    x = self.mid_layer(x)

    return self.final_layer(x)

MLPinstance=MLP()#初始化一个实例
MLPinstance

######################
x = torch.rand(5, 1, 2) # [10, 1, 8] #bath_size,1,2维度
time=5 #标量不对
state=torch.rand(5, 1, 2)
x.shape
torch.tensor(range(5)).unsqueeze(1).unsqueeze(2).shape,torch.rand(5, 1, 2).shape #有什么区别?

这两个 torch.tensor 的操作创建了不同形状和内容的张量。

1. torch.tensor(range(5)).unsqueeze(1).unsqueeze(2)

- 首先,torch.tensor(range(5)) 创建了一个一维张量,内容为 [0, 1, 2, 3, 4]

- 然后,.unsqueeze(1) 在第1个维度(现在是一维张量的唯一维度,等同于插入一个新的列维度)添加一个维度,使得形状变为 (5, 1)

- 接着,.unsqueeze(2) 再次在新维度的后面添加一个维度,最终形状变为 (5, 1, 1)。因此,这个操作的结果是一个形状为 (5, 1, 1) 的张量,每个元素都是从0到4的数字,每一行重复同一个数字,并且在最后两个维度上只有一个单位。

2. torch.rand(5, 1, 2)

- 这个操作直接创建了一个形状为 (5, 1, 2) 的三维张量,其中的所有元素都是从0到1之间的随机数(均匀分布)。这意味着你得到的是一个有5行,每行包含一个大小为1的子列表,每个子列表内有2个随机数的张量。

总结:

- 形状不同:前者形状为 (5, 1, 1),主要由连续的整数构成;后者形状为 (5, 1, 2),由随机浮点数构成。

- 内容不同:前者的内容是确定的,是0到4的整数,每个数字沿最后一个维度重复;后者的内容是随机的,范围在0到1之间。

- 数据类型不同:默认情况下,前者(基于 range)会是整数类型(除非显式转换),而后者明确是浮点数类型,因为使用了 torch.rand

torch.tensor(range(5)).unsqueeze(1).unsqueeze(1).shape在那个维度后面升维度,torch.tensor(range(5)).unsqueeze(1).unsqueeze(2).shape

######################################

x = torch.rand(2, 1) # [10, 1, 8]
time=torch.tensor([5]) ##一维度的可以
state=torch.rand(2, 1)
x.shape

实例化SinusoidalPosEmb类

pos_emb = SinusoidalPosEmb(dim=16)

创建一个示例输入张量x

假设我们有一个序列长度为5,维度为16的输入

#x = torch.tensor(range(5)).unsqueeze(1).unsqueeze(2) # 形状为 [5, 1] bath size,以及对应的timestep
time = torch.tensor(time)
print(time.shape)

调用forward方法

positional_embedding = pos_emb(time)

##########################################################################

x = torch.rand(5, 1) # [10, 1, 8]
time= torch.tensor(range(5)).unsqueeze(1)#没必要这样 经过embedding后会增加一个维度
state=torch.rand(5, 1)
x.shape

#################################################################
x = torch.rand(5, 1)
time= torch.tensor([5])
state=torch.rand(5, 1)
x.shape,time.shape
#############################################成功调用的####################
x = torch.rand(5, 1) ####一般定义了2维度 所以一般输入的数据就是2个维度的多个元素 当然少于2维的运算后要能够升维度 或者运算后能够降低倒定义初始化中需要的的维度!!!!!!!!
time= torch.tensor(range(5))
state=torch.rand(5, 1)
x.shape,time.shape

torch.tensor([5]) 和torch.tensor(range(5))的维度有什么区别

torch.tensor([5]) 创建的是一个形状为 torch.Size([1]) 的张量,表示它是一个包含单个元素的一维张量。

torch.tensor(range(5)) 创建的是一个形状为 torch.Size([5]) 的张量,表示它是一个包含5个元素的一维张量。

MLPinstance(x,time,state) #成功调用

将位置索引 x 转换为形状 (5, 1),频率向量转换为形状 (1, 8) 是怎么理解请举例
当我们谈论将位置索引 x 转换为形状 (5, 1) 和频率向量转换为形状 (1, 8),我们实际上是在讨论在进行矩阵运算之前对张量(在PyTorch中,张量是多维数组)的形状调整,以便它们能够进行有效的点乘操作。这个过程通常称为“广播”(broadcasting),它允许不同形状的张量进行数学运算,只要它们在没有明确指定的维度上大小为1或者完全匹配。

例子说明:

位置索引 x

原始的位置索引 x 是一个一维张量,表示5个不同的位置:

x = torch.tensor([0, 1, 2, 3, 4])

形状是 (5,),表示有5个元素。

为了使其能与频率向量正确点乘,我们需要将其扩展成一个二维张量,形状变为 (5, 1)。这意味着每个位置现在被视为一个单独的行,每一行只有一个元素:

x_expanded = x.unsqueeze(1)
# x_expanded 的形状现在是 (5, 1),内容为:
# tensor([[0],
#         [1],
#         [2],
#         [3],
#         [4]])

通过 unsqueeze(1) 操作,我们在索引1的位置增加了一个新的维度,使得每个位置值变成一个单独的列向量。

频率向量

频率向量是基于一半的维度数(如果 dim=16,则 half_dim=8)创建的,假定我们已经计算了对应的频率值并存储在一个一维张量中:

freqs = torch.tensor([...])  # 假设这里填入了计算得到的频率值序列

其原始形状可能是 (8,),代表了8个不同频率的值。

为了点乘,我们需要将其调整为 (1, 8) 形状,意味着它变成了一个只有一行的矩阵:

freqs_expanded = freqs.unsqueeze(0)
# freqs_expanded 的形状现在是 (1, 8),内容保持原频率值不变,但结构上作为单行多列

通过 unsqueeze(0),我们在索引0的位置增加了一个新的维度,使频率值构成了一行向量。

结果

完成上述操作后,x_expandedfreqs_expanded 可以直接相乘,因为它们在第二个维度(列)上是匹配的(一个是1,另一个是8,但因为是点乘,一维会自动复制以匹配另一方的维度)。这样,每个位置的索引值会与所有的频率值进行一一对应乘法运算,从而生成最终的编码向量的一部分。

a为(5,3,1)的张量
b为(5,16)的张量 能否把b变为(5,3,16)的维度?
要将形状为 (5, 16) 的张量 b 转换成形状为 (5, 3, 16),以便它可以与形状为 (5, 3, 1) 的张量 a 在某些维度上进行运算,我们需要沿着一个新维度重复 b 中的每一行数据三次(因为 a 的第二个维度是 3)。这可以通过使用 torch.repeat_interleave 函数来完成,具体是沿着第1维度(在PyTorch中,维度索引从0开始计数,所以这是第二个维度)进行重复。

下面是操作的示例代码:

import torch

# 假设 b 是形状为 (5, 16) 的张量
b = torch.randn(5, 16)

# 将 b 调整为形状为 (5, 3, 16)
b_expanded = b.unsqueeze(1).repeat(1, 3, 1)

print("Original shape of b:", b.shape)
print("Expanded shape of b:", b_expanded.shape)

这段代码首先使用 unsqueeze(1)b 添加一个新的维度(变成 (5, 1, 16)),然后使用 repeat(1, 3, 1) 沿着新添加的维度复制数据三次,最终得到形状为 (5, 3, 16) 的张量。

这样调整后,b_expanded 就可以和形状为 (5, 3, 1) 的张量 a 在相应的维度上进行操作了,比如逐元素相乘等运算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/678580.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

《python》poetry install下载缓慢,网络问题断开连接--poetry换源镜像下载+国内镜像

在使用打包工具poetry进行打包的是出现了一个问题就是,在使用poetry进行打包的时候出现了,连接断开这样的问题,这个问题是可以通过换源,通过国内的镜像来解决这个问题就可以了。 找到项目中的pyoroject。toml文件这个文件中写了一…

ESP8266 01sWiFi模块保姆级教程 烧录和联网,连接华为云

前言 写在前面。 这个esp01s联网真的是折磨人啊,浪费了我三四天的时间,网上各种教程叫天天不灵,叫地地不灵,所以才有了这篇教程,致力于帮助像我一样的小白少踩坑,我可以说是把能踩的坑都塌了一遍。 烧录…

Spring运维之boot项目多环境(yaml 多文件 proerties)及分组管理与开发控制

多环境开发(yaml文件版) 我们在自己的开发中是自己环境 测试 生产的环境都不同 多环境分为 两个步骤 设置环境 生产环境 开发环境 测试环境 手搓三个环境 设置应用环境 应用pro配置 # 应用环境 spring:profiles:active: pro--- # 设置环境 # 生产环境 spring:profiles: p…

MySQL 存储过程(一)

本篇主要介绍MySQL存储过程的相关内容 目录 一、什么是存储过程? 二、基本语法 创建存储过程 调用存储过程 查看存储过程 删除存储过程 三、变量 系统变量 用户自定义变量 局部变量 四、存储过程的参数 in out inout 一、什么是存储过程&#xff1f…

No module named _sqlite3解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

RedHat9 | 控制启动过程

1、Linux系统启动引导流程 加电自检(POST)寻找启动顺序(BIOS/UEFI)读取启动加载程序(MBR->Bootloader)加载内核与内存文件系统(kernel-initramfs)加载硬件及驱动(/lib/modules或/lib64/modules)初始化系…

9 -力扣高频 SQL 50 题(基础版)

9 - 上升的温度 -- 找出与之前(昨天的)日期相比温度更高的所有日期的 id -- DATEDIFF(2007-12-31,2007-12-30); # 1 -- DATEDIFF(2010-12-30,2010-12-31); # -1select w1.id from Weather w1, Weather w2 wheredatediff(w1.recordDate,w2.recordDat…

数组的详细介绍

数组是一组相同类型元素的集合,也就是说:数组至少包含两个及以上的元素,且元素类型相同。 数组包括一维数组和多维数组,其中二维数组最常见。下面我们一一介绍。 一维数组: 格式:type name [常量值]&…

微信短视频怎么收藏?成都鼎茂宏升文化传媒公司

微信短视频怎么收藏?一文教你轻松掌握 随着微信功能的不断升级,微信短视频已经成为我们日常生活中不可或缺的一部分。无论是朋友分享的生活点滴,还是公众号推送的精彩内容,短视频都以其直观、生动的形式,吸引着我们的…

Qt——控件

目录 概念 QWidget核心属性 enabled geometry WindowFrame的影响 windowTitle windowIcon qrc的使用 windowOpacity cursor font toolTip focusPolicy ​编辑 styleSheet 按钮类控件 PushButton RadioButton CheckBox 显示类控件 Label textFormat pixm…

什么牌子的洗地机好?高端旗舰洗地机,清洁力强的洗地机品牌

科技水平的不断进步,人们对生活环境的要求日益提高,洗地机作为一种高效,便捷的清洁设备,在家务清洁中,越来越受重视,洗地机不仅在吸尘、拖地和深度清洁等方面表现出色,可以帮助用户轻松应对各种…

Swagger教程:【Swagger】让你的API文档焕然一新!

Swagger(现称为OpenAPI Specification)是一种用于描述RESTful API接口的规范。它允许您以机器可读和人类可读的方式定义服务,使得开发、测试、维护和文档化API变得更加高效。下面整理了一个基础的Swagger教程,包括其重要组成部分和…

2021 hnust 湖科大 计组课设 包含multisim14连线文件,报告,指导书

2021 hnust 湖科大 计组课设 包含multisim14连线文件,报告,指导书 描述 hnust计组课设要用到的东西都在里面了 下载链接 https://pan.baidu.com/s/1tHooJmhkrwX47JCqsg379g?pwd1111

计网期末复习指南(五):运输层(可靠传输原理、TCP协议、UDP协议、端口)

前言:本系列文章旨在通过TCP/IP协议簇自下而上的梳理大致的知识点,从计算机网络体系结构出发到应用层,每一个协议层通过一篇文章进行总结,本系列正在持续更新中... 计网期末复习指南(一):计算机…

【计算机毕设】基于SpringBoot的民宿在线预定平台设计与实现 - 源码免费(私信领取)

免费领取源码 | 项目完整可运行 | v:chengn7890 诚招源码校园代理! 1. 研究目的 本研究旨在设计并实现一个基于SpringBoot的民宿在线预定平台。通过信息化手段提高民宿预定效率,方便用户查询房源、预定房间、在线支付和…

OBS+nginx+nginx-http-flv-module实现阿里云的推流和拉流

背景:需要将球机视频推送到阿里云nginx,使用网页和移动端进行播放,以前视频格式为RTMP,但是在网页上面播放RTMP格式需要安装flash插件,chrome浏览器不给安装,调研后发现可以使用nginx的模块nginx-http-flv-…

LlamaIndex介绍

LlamaIndex LangChain v0.2 教程分成以下部分: 1、入门 2、学习 3、用例 4、示例 5、高级 6、组件指南 RAG 用额外的信息来提高回答的质量。 分为 5个阶段: (1)loading 加载原始文件,LlamaHub 提供数百种连…

借助调试工具理解BLE协议_1.蓝牙简介和BLE工作流程

1.蓝牙简介 蓝牙是一种近距离无线通信技术,运行在2.4GHz免费频段,目前已大量应用于各种移动终端,物联网,健康医疗,智能家居等行业。蓝牙4.0以后的版本分为两种模式,单模蓝牙和双模蓝牙。 单模蓝牙&#xf…

聊聊测试的右移

这是鼎叔的第九十九篇原创文章。行业大牛和刚毕业的小白,都可以进来聊聊。 欢迎关注本公众号《敏捷测试转型》,星标收藏,大量原创思考文章陆续推出。本人新书《无测试组织-测试团队的敏捷转型》已出版(机械工业出版社&#xff09…

体育赛事直播系统开发源码搭建

随着体育产业的蓬勃发展,体育赛事直播已成为广大观众获取赛事信息的重要途径。为了满足观众日益增长的需求,开发一套专业的体育赛事直播系统成为当务之急。本文将围绕体育赛事直播系统开发源码搭建进行深入探讨,从技术选型、系统架构、安全防…