Torch Basics (Part 1)

  1. Tensor

    • A Tensor is a multi-dimensional matrix; for higher-dimensional tensors the sizes can be given as a LongStorage

       --- creation of a 4D-tensor 4x5x6x2
       z = torch.Tensor(4,5,6,2)
       --- for more dimensions, (here a 6D tensor) one can do:
       s = torch.LongStorage(6)
       s[1] = 4; s[2] = 5; s[3] = 6; s[4] = 2; s[5] = 7; s[6] = 3;
       x = torch.Tensor(s)
       --- The number of dimensions of a Tensor can be queried by nDimension() or dim()
       > x:nDimension()
         6
       --- Size of the i-th dimension is returned by size(i) (note that i starts at 1). A LongStorage containing all the dimensions can be returned by size().
       > x:size()
         4
         5
         6
         2
         7
         3
         [torch.LongStorage of size 6]
    • The actual data is stored in a Storage
    --- The actual data of a Tensor is contained into a Storage. It can be accessed using storage(). While the memory of a Tensor has to be contained in this unique Storage, it might not be contiguous: the first position used in the Storage is given by storageOffset() (starting at 1). And the jump needed to go from one element to another element in the i-th dimension is given by stride(i)
    
    > x = torch.Tensor(7,7,7) 
    --- accessing the element (3,4,5) can be done by
     > x[3][4][5]
     0
    
    --- or equivalently (but slowly!)
    
     > x:storage()[x:storageOffset()+(3-1)*x:stride(1)+(4-1)*x:stride(2)+(5-1)*x:stride(3)]
     0
    
    --- One could say that a Tensor is a particular way of viewing a Storage: a Storage only represents a chunk of memory, while the Tensor interprets this chunk of memory as having dimensions 
    
    > x = torch.Tensor(4,5)
    > s = x:storage()
     > for i=1,s:size() do -- fill up the Storage
         s[i] = i
       end
     > x -- s is interpreted by x as a 2D matrix
      1   2   3   4   5
      6   7   8   9  10
     11  12  13  14  15
     16  17  18  19  20
    [torch.DoubleTensor of dimension 4x5] 
    
    --- Note also that in Torch7 elements in the same row [elements along the last dimension] are contiguous in memory for a matrix [tensor] 
    
    > x = torch.Tensor(4,5)
    > i = 0
    > x:apply(function()
      i = i + 1
      return i
      end)
    > x
      1   2   3   4   5
      6   7   8   9  10
     11  12  13  14  15
     16  17  18  19  20
    [torch.DoubleTensor of dimension 4x5]
    > x:stride()
     5
      1  -- elements in the last dimension are contiguous!
    [torch.LongStorage of size 2]
    
    • Tensor types: DoubleTensor and FloatTensor are used most often. You can write type-independent scripts with torch.Tensor and pick the concrete type at runtime, e.g. torch.setdefaulttensortype('torch.FloatTensor') (see the sketch after the list below).
    ByteTensor -- contains unsigned chars
    CharTensor -- contains signed chars
    ShortTensor -- contains shorts
    IntTensor -- contains ints
    LongTensor -- contains longs
    FloatTensor -- contains floats
    DoubleTensor -- contains doubles
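    A minimal sketch of switching the default type and converting between types (assumes the default type starts as torch.DoubleTensor):

    torch.setdefaulttensortype('torch.FloatTensor')
    x = torch.Tensor(2,3):fill(1)   -- now a FloatTensor
    y = x:double()                  -- explicit conversion (copies the data)
    z = x:typeAs(y)                 -- convert x to y's type (DoubleTensor)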
    • Efficient memory management
    --- All tensor operations in this class do not make any memory copy. All these methods transform the existing tensor, or return a new tensor referencing the same storage. This magical behavior is internally obtained by good usage of the stride() and storageOffset() 
    
    > x = torch.Tensor(5):zero()
    > x
    0
    0
    0
    0
    0
    [torch.DoubleTensor of dimension 5]
    > x:narrow(1, 2, 3):fill(1)
    > x
     0
     1
     1
     1
     0
     [torch.DoubleTensor of dimension 5]
    
     --- If you really need to copy a Tensor, you can use the copy() method or the convenience method clone():
    > y = torch.Tensor(x:size()):copy(x)
    > y = x:clone()
    > y
     0
     1
     1
     1
     0
     [torch.DoubleTensor of dimension 5]
    • Constructing Tensors (a short sketch follows this list)
      • torch.Tensor()
      • torch.Tensor(tensor)
      • torch.Tensor(sz1 [,sz2 [,sz3 [,sz4]]])
      • torch.Tensor(sizes, [strides])
      • torch.Tensor(storage, [storageOffset, sizes, [strides]])
      • torch.Tensor(storage, [storageOffset, sz1 [, st1 … [, sz4 [, st4]]]])
      • torch.Tensor(table)
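      A minimal sketch exercising a few of these constructor variants:

      a = torch.Tensor()                     -- empty tensor, no storage yet
      b = torch.Tensor(3,4)                  -- uninitialized 3x4
      c = torch.Tensor{{1,2,3},{4,5,6}}      -- 2x3 built from a Lua table
      s = torch.DoubleStorage(8):fill(0)
      d = torch.Tensor(s, 1, torch.LongStorage{2,4}) -- view s as a 2x4 matrix
      d:fill(7)                              -- also changes the contents of s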
    • Methods

      • Cloning
        • clone()
        • contiguous()
        • type(type)
        • typeAs(tensor)
        • isTensor(object)
        • byte(), char(), short(), int(), long(), float(), double()
      • Querying the size and structure
        • nDimension()
        • dim()
        • size(dim)
        • size()
        • #self
        • stride(dim)
        • stride()
        • storage()
        • isContiguous()
        • isSize(storage)
        • isSameSizeAs(tensor)
        • nElement()
        • storageOffset()
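        A minimal sketch of the cloning and size-query methods above (default DoubleTensor assumed):

        x = torch.Tensor(4,5)
        y = x:clone()                 -- deep copy: new storage
        z = x:t()                     -- transposed view: same storage
        print(z:isContiguous())       -- false
        zc = z:contiguous()           -- makes a compact copy when needed
        print(x:nElement())           -- 20
        print(x:size(1), x:stride(1)) -- 4    5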
      • Accessing elements
      x = torch.Tensor(3,3)
      i = 0; x:apply(function() i = i + 1; return i end)
      > x
       1  2  3
       4  5  6
       7  8  9
      [torch.DoubleTensor of dimension 3x3]
      
      > x[2] -- returns row 2
       4
       5
       6
      [torch.DoubleTensor of dimension 3]
      
      > x[2][3] -- returns row 2, column 3
      6
      
      > x[{2,3}] -- another way to return row 2, column 3
      6
      
      > x[torch.LongStorage{2,3}] -- yet another way to return row 2, column 3
      6
      
      > x[torch.le(x,3)] -- torch.le returns a ByteTensor that acts as a mask
       1
       2
       3
      [torch.DoubleTensor of dimension 3]
      • Referencing a tensor to an existing tensor or chunk of memory
        • set(tensor)
        • isSetTo(tensor)
        • set(storage, [storageOffset, sizes, [strides]])
        • set(storage, [storageOffset, sz1 [, st1 … [, sz4 [, st4]]]])
      • Copying and initializing
        • copy(tensor)
        • fill(value)
        • zero()
      • Resizing
        • resizeAs(tensor)
        • resize(sizes)
        • resize(sz1 [,sz2 [,sz3 [,sz4]]])
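        A minimal sketch covering the referencing, copying, and resizing methods above (set() shares storage; copy() does not):

        x = torch.Tensor(2,5):fill(3)
        y = torch.Tensor()
        y:set(x)                       -- y now views x's storage
        y:fill(0)                      -- ... so x is zeroed as well
        z = torch.Tensor(2,5):copy(x)  -- an independent copy
        z:resize(10)                   -- same 10 elements, seen as a vector
        z:resizeAs(x)                  -- back to 2x5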
      • Extracting sub-tensors
        • narrow(dim, index, size)
        • sub(dim1s, dim1e … [, dim4s [, dim4e]])
        • select(dim, index)
        • [{ dim1,dim2,… }] or [{ {dim1s,dim1e}, {dim2s,dim2e} }]
        • index(dim, index)
        • indexCopy(dim, index, tensor)
        • indexAdd(dim, index, tensor)
        • indexFill(dim, index, val)
        • gather(dim, index)
        • scatter(dim, index, src|val)
        • maskedSelect(mask)
        • maskedCopy(mask, tensor)
        • maskedFill(mask, val)
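        A minimal sketch of the extraction methods; narrow/select/sub return views into the same storage, while the masked variants copy:

        x = torch.Tensor(5,6):zero()
        x:narrow(1, 2, 3):fill(1)          -- rows 2..4 (a view, no copy)
        row = x:select(1, 2)               -- row 2 as a 1D view
        sub = x:sub(2,4, 3,5)              -- rows 2..4, columns 3..5
        sub2 = x[{ {2,4}, {3,5} }]         -- the same thing via [] syntax
        m = x:maskedSelect(torch.eq(x, 1)) -- copies the masked elements out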
      • Search
        • nonzero(tensor)
      • Expanding/Replicating/Squeezing Tensors
        • expand([result,] sizes)
        • expandAs([result,] tensor)
        • repeatTensor([result,] sizes)
        • squeeze([dim])
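        A minimal sketch: expand() creates a stride-0 view (no copy) along singleton dimensions, and squeeze() removes them:

        v = torch.Tensor{{1}, {2}, {3}}    -- 3x1
        e = v:expand(3, 4)                 -- 3x4 view; column stride is 0
        e[1][1] = 9                        -- writes through to v, so the
                                           -- whole first row of e reads 9
        r = torch.rand(1, 3, 1)
        print(r:squeeze():size())          -- a 1D tensor of size 3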
      • Manipulating the tensor view
        • view([result,] tensor, sizes)
        • viewAs([result,] tensor, template)
        • transpose(dim1, dim2)
        • t()
        • permute(dim1, dim2, …, dimn)
        • unfold(dim, size, step)
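        A minimal sketch of the view-manipulation methods; all of these share storage with x, and view() requires a contiguous tensor:

        x = torch.range(1, 12)               -- 1D tensor holding 1..12
        m = x:view(3, 4)                     -- seen as 3x4, same storage
        t = m:t()                            -- 4x3 transposed view
        p = m:view(2, 2, 3):permute(3, 1, 2) -- dimensions reordered to 3x2x2
        u = m:unfold(2, 2, 2)                -- 3x2x2: windows of size 2, step 2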
      • Applying a function to a tensor
        • apply(function)
        • map(tensor, function(xs, xt))
        • map2(tensor1, tensor2, function(x, xt1, xt2))
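        A minimal sketch of map(), which walks two same-sized tensors in lockstep and overwrites the first:

        a = torch.range(1, 5)
        b = torch.range(10, 50, 10)
        a:map(b, function(xa, xb) return xa + 2 * xb end) -- a[i] = a[i] + 2*b[i]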
      • Dividing a tensor into a table of tensors
        • split([result,] tensor, size, [dim])
        • chunk([result,] tensor, n, [dim])
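        A minimal sketch of split() and chunk(); both return a Lua table of views:

        x = torch.range(1, 10)
        parts = x:split(3)      -- four pieces of sizes 3, 3, 3, 1
        halves = x:chunk(2)     -- two pieces of size 5
        print(#parts, #halves)  -- 4    2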
      • LuaJIT FFI access
        • data(tensor, [asnumber])
        • cdata(tensor, [asnumber])
      • Reference counting
        • retain()
        • free()
  2. Math operations
    • Construction or extraction functions
      • torch.cat( [res,] x_1, x_2, [dimension] )
      • torch.cat( [res,] {x_1, x_2, …}, [dimension] )
      • torch.diag([res,] x [,k])
      • torch.eye([res,] n [,m])
      • torch.histc([res,] x [,nbins, min_value, max_value])
      • torch.linspace([res,] x1, x2, [,n])
      • torch.logspace([res,] x1, x2, [,n])
      • torch.multinomial([res,], p, n, [,replacement])
      • torch.ones([res,] m [,n…])
      • torch.rand([res,] [gen,] m [,n…])
      • torch.randn([res,] [gen,] m [,n…])
      • torch.range([res,] x, y [,step])
      • torch.randperm([res,] [gen,] n)
      • torch.reshape([res,] x, m [,n…])
      • torch.tril([res,] x [,k])
      • torch.triu([res,] x [,k])
      • torch.zeros([res,] x)
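      A minimal sketch of a few construction functions:

      i = torch.eye(3)             -- 3x3 identity matrix
      r = torch.range(2, 10, 2)    -- 2, 4, 6, 8, 10
      l = torch.linspace(0, 1, 5)  -- 0, 0.25, 0.5, 0.75, 1
      z = torch.zeros(2, 3)
      o = torch.ones(2, 3)
      c = torch.cat(z, o, 1)       -- 4x3: o stacked below z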
    • Element-wise Mathematical Operations
      • torch.abs([res,] x)
      • torch.sign([res,] x)
      • torch.acos([res,] x)
      • torch.asin([res,] x)
      • torch.atan([res,] x)
      • torch.ceil([res,] x)
      • torch.cos([res,] x)
      • torch.cosh([res,] x)
      • torch.exp([res,] x)
      • torch.floor([res,] x)
      • torch.log([res,] x)
      • torch.log1p([res,] x)
      • x:neg()
      • x:cinv()
      • torch.pow([res,] x, n)
      • torch.round([res,] x)
      • torch.sin([res,] x)
      • torch.sinh([res,] x)
      • torch.sqrt([res,] x)
      • torch.rsqrt([res,] x)
      • torch.tan([res,] x)
      • torch.tanh([res,] x)
      • torch.sigmoid([res,] x)
      • torch.trunc([res,] x)
      • torch.frac([res,] x)
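      A minimal sketch; each function also has an in-place method form (x:abs(), x:sqrt(), ...):

      x = torch.Tensor{-4, -1, 0, 1, 4}
      print(torch.abs(x))    -- 4, 1, 0, 1, 4
      print(torch.pow(x, 2)) -- 16, 1, 0, 1, 16
      print(torch.sign(x))   -- -1, -1, 0, 1, 1
      x:abs()                -- the in-place variant modifies x itself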
    • Basic operations
      • equal([tensor1,] tensor2)
      • torch.add([res,] tensor, value)
      • torch.add([res,] tensor1, tensor2)
      • torch.add([res,] tensor1, value, tensor2)
      • tensor:csub(value)
      • tensor1:csub(tensor2)
      • torch.mul([res,] tensor1, value)
      • torch.clamp([res,] tensor, min_value, max_value)
      • torch.cmul([res,] tensor1, tensor2)
      • torch.cpow([res,] tensor1, tensor2)
      • torch.addcmul([res,] x [,value], tensor1, tensor2)
      • torch.div([res,] tensor, value)
      • torch.cdiv([res,] tensor1, tensor2)
      • torch.addcdiv([res,] x [,value], tensor1, tensor2)
      • torch.fmod([res,] tensor, value)
      • torch.remainder([res,] tensor, value)
      • torch.mod([res,] tensor, value)
      • torch.cfmod([res,] tensor1, tensor2)
      • torch.cremainder([res,] tensor1, tensor2)
      • torch.cmod([res,] tensor1, tensor2)
      • torch.dot(tensor1, tensor2)
      • torch.addmv([res,] [beta,] [v1,] vec1, [v2,] mat, vec2)
      • torch.addr([res,] [v1,] mat, [v2,] vec1, vec2)
      • torch.addmm([res,] [beta,] [v1,] M, [v2,] mat1, mat2)
      • torch.addbmm([res,] [v1,] M, [v2,] batch1, batch2)
      • torch.baddbmm([res,] [v1,] M, [v2,] batch1, batch2)
      • torch.mv([res,] mat, vec)
      • torch.mm([res,] mat1, mat2)
      • torch.bmm([res,] batch1, batch2)
      • torch.ger([res,] vec1, vec2)
      • torch.lerp([res,] a, b, weight)
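      A minimal sketch of the vector/matrix products:

      M = torch.Tensor{{1, 2}, {3, 4}}
      v = torch.Tensor{1, 1}
      print(torch.mv(M, v))   -- 3, 7  (matrix-vector product)
      print(torch.mm(M, M))   -- the 2x2 matrix product
      print(torch.dot(v, v))  -- 2     (inner product)
      print(torch.addmv(torch.zeros(2), M, v)) -- 0 + M*v, same as mv here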
    • Overloaded operators
      • Addition and subtraction
      • Negation
      • Multiplication
      • Division and Modulo (remainder)
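      A minimal sketch of the operator overloads; * dispatches on the operand shapes:

      a = torch.Tensor{1, 2}
      b = torch.Tensor{3, 4}
      M = torch.Tensor{{1, 0}, {0, 1}}
      print(a + b)   -- 4, 6
      print(-a)      -- -1, -2
      print(a * b)   -- 11  (vector * vector is the dot product)
      print(M * a)   -- 1, 2  (matrix * vector)
      print(a / 2)   -- 0.5, 1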
    • Column or row-wise operations (dimension-wise operations)
      • torch.cross([res,] a, b [,n])
      • torch.cumprod([res,] x [,dim])
      • torch.cumsum([res,] x [,dim])
      • torch.max([resval, resind,] x [,dim])
      • torch.mean([res,] x [,dim])
      • torch.min([resval, resind,] x [,dim])
      • torch.cmax([res,] tensor1, tensor2)
      • torch.cmax([res,] tensor, value)
      • torch.cmin([res,] tensor1, tensor2)
      • torch.cmin([res,] tensor, value)
      • torch.median([resval, resind,] x [,dim])
      • torch.mode([resval, resind,] x [,dim])
      • torch.kthvalue([resval, resind,] x, k [,dim])
      • torch.prod([res,] x [,n])
      • torch.sort([resval, resind,] x [,d] [,flag])
      • torch.topk([resval, resind,] x, k, [,dim] [,dir] [,sort])
      • torch.std([res,] x, [,dim] [,flag])
      • torch.sum([res,] x)
      • torch.var([res,] x [,dim] [,flag])
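      A minimal sketch: with a dim argument the reductions work per row or column, and max/min/sort also return indices:

      x = torch.Tensor{{1, 5, 2}, {7, 0, 9}}
      print(torch.sum(x, 2))      -- row sums 8 and 16 (a 2x1 tensor)
      vals, idx = torch.max(x, 2) -- per-row maxima and their column indices
      print(vals)                 -- 5 and 9
      print(idx)                  -- 2 and 3
      sorted = torch.sort(x, 2)   -- each row sorted in ascending order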
    • Matrix-wide operations
      • torch.norm(x [,p] [,dim])
      • torch.renorm([res], x, p, dim, maxnorm)
      • torch.dist(x, y)
      • torch.numel(x)
      • torch.trace(x)
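      A minimal sketch of the matrix-wide reductions:

      x = torch.Tensor{{3, 0}, {0, 4}}
      print(torch.norm(x))   -- 5 (2-norm over all elements)
      print(torch.trace(x))  -- 7
      print(torch.numel(x))  -- 4
      print(torch.dist(x, torch.zeros(2, 2))) -- 5 (norm of the difference)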
    • Convolution Operations
      • torch.conv2([res,] x, k [, 'F' or 'V'])
      • torch.xcorr2([res,] x, k [, 'F' or 'V'])
      • torch.conv3([res,] x, k [, 'F' or 'V'])
      • torch.xcorr3([res,] x, k [, 'F' or 'V'])
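      A minimal sketch: 'V' (valid, the default) shrinks the output, 'F' (full) pads it:

      x = torch.rand(10, 10)
      k = torch.rand(3, 3)
      print(torch.conv2(x, k):size())      -- 8x8   (valid convolution)
      print(torch.conv2(x, k, 'F'):size()) -- 12x12 (full convolution)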
    • Eigenvalues, SVD, Linear System Solution
      • torch.gesv([resb, resa,] B, A)
      • torch.trtrs([resb, resa,] b, a [, 'U' or 'L'] [, 'N' or 'T'] [, 'N' or 'U'])
      • torch.potrf([res,] A [, 'U' or 'L'])
      • torch.pstrf([res, piv,] A [, 'U' or 'L'])
      • torch.potrs([res,] B, chol [, 'U' or 'L'])
      • torch.potri([res,] chol [, 'U' or 'L'])
      • torch.gels([resb, resa,] b, a)
      • torch.symeig([rese, resv,] a [, 'N' or 'V'] [, 'U' or 'L'])
      • torch.eig([rese, resv,] a [, 'N' or 'V'])
      • torch.svd([resu, ress, resv,] a [, 'S' or 'A'])
      • torch.inverse([res,] x)
      • torch.qr([q, r], x)
      • torch.geqrf([m, tau], a)
      • torch.orgqr([q], m, tau)
      • torch.ormqr([res], m, tau, mat [, 'L' or 'R'] [, 'N' or 'T'])
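      A minimal sketch: gesv() solves AX = B, svd() factorizes:

      A = torch.Tensor{{2, 0}, {0, 4}}
      B = torch.Tensor{{2}, {8}}
      X = torch.gesv(B, A)    -- solves A*X = B; here X = {{1}, {2}}
      u, s, v = torch.svd(A)  -- A = u * diag(s) * v^T
      print(s)                -- singular values 4, 2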
    • Logical Operations on Tensors
      • torch.lt(a, b)
      • torch.le(a, b)
      • torch.gt(a, b)
      • torch.ge(a, b)
      • torch.eq(a, b)
      • torch.ne(a, b)
      • torch.all(a)
      • torch.any(a)
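      A minimal sketch: the comparisons return a ByteTensor that can be used as a mask:

      x = torch.Tensor{1, 2, 3, 4}
      mask = torch.lt(x, 3)   -- ByteTensor {1, 1, 0, 0}
      print(x[mask])          -- 1, 2 (masked selection, as in section 1)
      print(torch.any(mask))  -- true
      print(torch.all(mask))  -- false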
  3. The Storage interface
    • Constructors and Access Methods
      • torch.TYPEStorage([size [, ptr]])
      • torch.TYPEStorage(table)
      • torch.TYPEStorage(storage [, offset [, size]])
      • torch.TYPEStorage(filename [, shared [, size [, sharedMem]]])
      • self[index]
      • copy(storage)
      • fill(value)
      • resize(size)
      • size()
      • string(str)
      • string()
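      A minimal sketch of the Storage constructors and accessors:

      s = torch.IntStorage{10, 20, 30}  -- built from a Lua table
      print(s[2], s:size())             -- 20    3 (1-based indexing)
      s:fill(0)                         -- every element set to 0
      s:resize(5)                       -- grow; new slots are uninitialized
      t = torch.IntStorage(5):copy(s)   -- element-wise copy, sizes must match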
    • Reference counting methods
      • retain()
      • free()
  4. File
    • Read methods
    • Write methods
    • Serialization methods
      • readObject()
      • writeObject(object)
      • readString(format)
      • writeString(str)
    • General Access and Control Methods
      • ascii()
      • autoSpacing()
      • binary()
      • clearError()
      • close()
      • noAutoSpacing()
      • synchronize()
      • pedantic()
      • position()
      • quiet()
      • seek(position)
      • seekEnd()
    • File state query
      • hasError()
      • isQuiet()
      • isReadable()
      • isWritable()
      • isAutoSpacing()
      • referenced(ref)
      • isReferenced()
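      A minimal sketch of serialization with DiskFile ('model.t7' is a hypothetical filename; torch.save/torch.load wrap the same calls):

      obj = {weights = torch.rand(3, 3), epoch = 7}
      f = torch.DiskFile('model.t7', 'w'):binary()
      f:writeObject(obj)
      f:close()
      g = torch.DiskFile('model.t7', 'r'):binary()
      restored = g:readObject()
      g:close()
      print(restored.epoch)   -- 7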
  5. Tester

    • Tester()
      • torch.Tester()
      • add(f, 'name')
      • run(testNames)
      • disable(testNames)
      • assert(condition [, message])
      • assertGeneralEq(got, expected [, tolerance] [, message])
      • eq(got, expected [, tolerance] [, message])
      • assertGeneralNe(got, unexpected [, tolerance] [, message])
      • ne(got, unexpected [, tolerance] [, message])
      • assertlt(a, b [, message])
      • assertgt(a, b [, message])
      • assertle(a, b [, message])
      • assertge(a, b [, message])
      • asserteq(a, b [, message])
      • assertne(a, b [, message])
      • assertalmosteq(a, b [, tolerance] [, message])
      • assertTensorEq(ta, tb [, tolerance] [, message])
      • assertTensorNe(ta, tb [, tolerance] [, message])
      • assertTableEq(ta, tb [, tolerance] [, message])
      • assertTableNe(ta, tb [, tolerance] [, message])
      • assertError(f [, message])
      • assertNoError(f [, message])
      • assertErrorMsg(f, errmsg [, message])
      • assertErrorPattern(f, errPattern [, message])
      • assertErrorObj(f, errcomp [, message])
      • setEarlyAbort(earlyAbort)
      • setRethrowErrors(rethrowErrors)
      • setSummaryOnly(summaryOnly)
    • TestSuite (a table of named test functions; defining the same name twice raises an error, as below)

      > test = torch.TestSuite()
      >
      > function test.myTest()
      >    -- ...
      > end
      >
      > -- ...
      >
      > function test.myTest()
      >    -- ...
      > end
      torch/TestSuite.lua:16: Test myTest is already defined.
    • A worked example

      local mytest = torch.TestSuite()
      
      local tester = torch.Tester()
      
      function mytest.testA()
         local a = torch.Tensor{1, 2, 3}
         local b = torch.Tensor{1, 2, 4}
         tester:eq(a, b, "a and b should be equal")
      end
      
      function mytest.testB()
         local a = {2, torch.Tensor{1, 2, 2}}
         local b = {2, torch.Tensor{1, 2, 2.001}}
         tester:eq(a, b, 0.01, "a and b should be approximately equal")
      end
      
      function mytest.testC()
         local function myfunc()
            return "hello " .. world
         end
         tester:assertNoError(myfunc, "myfunc shouldn't give an error")
      end
      
      tester:add(mytest)
      tester:run()
      Running 3 tests
      1/3 testB ............................................................... [PASS]
      2/3 testA ............................................................... [FAIL]
      3/3 testC ............................................................... [FAIL]
      
      Completed 3 asserts in 3 tests with 2 failures and 0 errors
      --------------------------------------------------------------------------------
      
      testA
      a and b should be equal
      TensorEQ(==) violation: max diff=1, tolerance=0
      stack traceback:
              ./test.lua:8: in function <./test.lua:5>
      
      --------------------------------------------------------------------------------
      testC
      myfunc shouldn't give an error
      ERROR violation: err=./test.lua:19: attempt to concatenate global 'world' (a nil value)
      stack traceback:
              ./test.lua:21: in function <./test.lua:17>
      
      --------------------------------------------------------------------------------
      torch/torch/Tester.lua:383: An error was found while running tests!
      stack traceback:
              [C]: in function 'assert'
              torch/torch/Tester.lua:383: in function 'run'
              ./test.lua:25: in main chunk
  6. CmdLine

    • addTime([name] [,format])
    • log(filename, parameter_table)
    • option(name, default, help)
    • parse(arg)
    • silent()
    • string(prefix, params, ignore)
    • text(string)
    • Example
    cmd = torch.CmdLine()
    cmd:text()
    cmd:text()
    cmd:text('Training a simple network')
    cmd:text()
    cmd:text('Options')
    cmd:option('-seed',123,'initial random seed')
    cmd:option('-booloption',false,'boolean option')
    cmd:option('-stroption','mystring','string option')
    cmd:text()
    
    -- parse input params
    params = cmd:parse(arg)
    
    params.rundir = cmd:string('experiment', params, {dir=true})
    paths.mkdir(params.rundir)
    
    -- create log file
    cmd:log(params.rundir .. '/log', params)

    Running th myscript.lua produces output like the following:

    [program started on Tue Jan 10 15:33:49 2012]
    [command line arguments]
    booloption  false
    seed    123
    rundir  experiment
    stroption   mystring
    [----------------------]
    booloption  false
    seed    123
    rundir  experiment
    stroption   mystring

    Running th myscript.lua -seed 456 -stroption mycustomstring produces:

    [program started on Tue Jan 10 15:36:55 2012]
    [command line arguments]
    booloption  false
    seed    456
    rundir  experiment,seed=456,stroption=mycustomstring
    stroption   mycustomstring
    [----------------------]
    booloption  false
    seed    456
    rundir  experiment,seed=456,stroption=mycustomstring
    stroption   mycustomstring
  7. Random

    • Generator handling
    • Seed Handling
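    A minimal sketch of seed and generator handling:

    torch.manualSeed(123)     -- seed the global generator
    a = torch.rand(3)
    torch.manualSeed(123)
    b = torch.rand(3)         -- identical to a

    gen = torch.Generator()   -- an independent generator
    torch.manualSeed(gen, 42)
    c = torch.rand(gen, 2, 2) -- draws from gen, not the global state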
  8. Utility